From: Matt Corallo Date: Sun, 19 May 2019 17:51:45 +0000 (-0400) Subject: Refactor per-node-scan-start into a function X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=a95c8b1e0aeea6a14f193599c6fdb8e0e1d60f60;p=dnsseed-rust Refactor per-node-scan-start into a function --- diff --git a/src/main.rs b/src/main.rs index d58b1ad..c6a86a6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,135 +40,139 @@ struct PeerState { request: (u64, sha256d::Hash), } -fn scan_net() { - tokio::spawn(future::lazy(|| { - let store = unsafe { DATA_STORE.as_ref().unwrap() }; - let printer = unsafe { PRINTER.as_ref().unwrap() }; +fn scan_node(scan_time: Instant, node: SocketAddr) { + let printer = unsafe { PRINTER.as_ref().unwrap() }; + let store = unsafe { DATA_STORE.as_ref().unwrap() }; - let mut scan_nodes = store.get_next_scan_nodes(); + let peer_state = Arc::new(Mutex::new(PeerState { + recvd_version: false, + recvd_verack: false, + recvd_addrs: false, + recvd_block: false, + node_services: 0, + fail_reason: AddressState::Timeout, + request: (0, Default::default()), + })); + let final_peer_state = Arc::clone(&peer_state); + + let peer = Delay::new(scan_time).then(move |_| { + printer.set_stat(Stat::NewConnection); let timeout = store.get_u64(U64Setting::RunTimeout); - let per_iter_time = Duration::from_millis(1000 / store.get_u64(U64Setting::ConnsPerSec)); - let mut iter_time = Instant::now(); + Peer::new(node.clone(), Duration::from_secs(timeout), printer) //TODO: timeout for total run + }); + tokio::spawn(peer.and_then(move |conn_split| { let requested_height = unsafe { HIGHEST_HEADER.as_ref().unwrap() }.lock().unwrap().1 - 1008; let requested_block = unsafe { HEIGHT_MAP.as_ref().unwrap() }.lock().unwrap().get(&requested_height).unwrap().clone(); + peer_state.lock().unwrap().request = (requested_height, requested_block); - for node in scan_nodes.drain(..) { - let peer_state = Arc::new(Mutex::new(PeerState { - recvd_version: false, - recvd_verack: false, - recvd_addrs: false, - recvd_block: false, - node_services: 0, - fail_reason: AddressState::Timeout, - request: (requested_height, requested_block), - })); - let final_peer_state = Arc::clone(&peer_state); - let peer = Delay::new(iter_time).then(move |_| { - printer.set_stat(Stat::NewConnection); - Peer::new(node.clone(), Duration::from_secs(timeout), printer) //TODO: timeout for total run - }); - iter_time += per_iter_time; - tokio::spawn(peer.and_then(move |conn_split| { - let (mut write, read) = conn_split; - read.map_err(|_| { () }).for_each(move |msg| { - let mut state_lock = peer_state.lock().unwrap(); - macro_rules! check_set_flag { - ($recvd_flag: ident, $msg: expr) => { { - if state_lock.$recvd_flag { - state_lock.fail_reason = AddressState::ProtocolViolation; - printer.add_line(format!("Updating {} to ProtocolViolation due to dup {}", node, $msg), true); - state_lock.$recvd_flag = false; - return future::err(()); - } - state_lock.$recvd_flag = true; - } } + let (mut write, read) = conn_split; + read.map_err(|_| { () }).for_each(move |msg| { + let mut state_lock = peer_state.lock().unwrap(); + macro_rules! check_set_flag { + ($recvd_flag: ident, $msg: expr) => { { + if state_lock.$recvd_flag { + state_lock.fail_reason = AddressState::ProtocolViolation; + printer.add_line(format!("Updating {} to ProtocolViolation due to dup {}", node, $msg), true); + state_lock.$recvd_flag = false; + return future::err(()); } - state_lock.fail_reason = AddressState::TimeoutDuringRequest; - match msg { - NetworkMessage::Version(ver) => { - if ver.start_height < 0 || ver.start_height as u64 > state_lock.request.0 + 1008*2 { - state_lock.fail_reason = AddressState::HighBlockCount; - return future::err(()); - } - if (ver.start_height as u64) < state_lock.request.0 { - printer.add_line(format!("Updating {} to LowBlockCount ({} < {})", node, ver.start_height, state_lock.request.0), true); - state_lock.fail_reason = AddressState::LowBlockCount; - return future::err(()); - } - let min_version = store.get_u64(U64Setting::MinProtocolVersion); - if (ver.version as u64) < min_version { - printer.add_line(format!("Updating {} to LowVersion ({} < {})", node, ver.version, min_version), true); - state_lock.fail_reason = AddressState::LowVersion; - return future::err(()); - } - if ver.services & 1 != 1 { - printer.add_line(format!("Updating {} to NotFullNode (services {:x})", node, ver.services), true); - state_lock.fail_reason = AddressState::NotFullNode; - return future::err(()); - } - check_set_flag!(recvd_version, "version"); - state_lock.node_services = ver.services; - if let Err(_) = write.try_send(NetworkMessage::Verack) { - return future::err(()); - } - }, - NetworkMessage::Verack => { - check_set_flag!(recvd_verack, "verack"); - if let Err(_) = write.try_send(NetworkMessage::GetAddr) { - return future::err(()); - } - if let Err(_) = write.try_send(NetworkMessage::GetData(vec![Inventory { - inv_type: InvType::WitnessBlock, - hash: state_lock.request.1, - }])) { - return future::err(()); - } - }, - NetworkMessage::Ping(v) => { - if let Err(_) = write.try_send(NetworkMessage::Pong(v)) { - return future::err(()) - } - }, - NetworkMessage::Addr(addrs) => { - if addrs.len() > 1 { - check_set_flag!(recvd_addrs, "addr"); - unsafe { DATA_STORE.as_ref().unwrap() }.add_fresh_nodes(&addrs); - } - }, - NetworkMessage::Block(block) => { - if block.header.bitcoin_hash() != state_lock.request.1 || - !block.check_merkle_root() || !block.check_witness_commitment() { - state_lock.fail_reason = AddressState::ProtocolViolation; - printer.add_line(format!("Updating {} to ProtocolViolation due to bad block", node), true); - return future::err(()); - } - check_set_flag!(recvd_block, "block"); - }, - _ => {}, + state_lock.$recvd_flag = true; + } } + } + state_lock.fail_reason = AddressState::TimeoutDuringRequest; + match msg { + NetworkMessage::Version(ver) => { + if ver.start_height < 0 || ver.start_height as u64 > state_lock.request.0 + 1008*2 { + state_lock.fail_reason = AddressState::HighBlockCount; + return future::err(()); + } + if (ver.start_height as u64) < state_lock.request.0 { + printer.add_line(format!("Updating {} to LowBlockCount ({} < {})", node, ver.start_height, state_lock.request.0), true); + state_lock.fail_reason = AddressState::LowBlockCount; + return future::err(()); + } + let min_version = store.get_u64(U64Setting::MinProtocolVersion); + if (ver.version as u64) < min_version { + printer.add_line(format!("Updating {} to LowVersion ({} < {})", node, ver.version, min_version), true); + state_lock.fail_reason = AddressState::LowVersion; + return future::err(()); + } + if ver.services & 1 != 1 { + printer.add_line(format!("Updating {} to NotFullNode (services {:x})", node, ver.services), true); + state_lock.fail_reason = AddressState::NotFullNode; + return future::err(()); } - future::ok(()) - }).then(|_| { - future::err(()) - }) - }).then(move |_: Result<(), ()>| { - let printer = unsafe { PRINTER.as_ref().unwrap() }; - let store = unsafe { DATA_STORE.as_ref().unwrap() }; + check_set_flag!(recvd_version, "version"); + state_lock.node_services = ver.services; + if let Err(_) = write.try_send(NetworkMessage::Verack) { + return future::err(()); + } + }, + NetworkMessage::Verack => { + check_set_flag!(recvd_verack, "verack"); + if let Err(_) = write.try_send(NetworkMessage::GetAddr) { + return future::err(()); + } + if let Err(_) = write.try_send(NetworkMessage::GetData(vec![Inventory { + inv_type: InvType::WitnessBlock, + hash: state_lock.request.1, + }])) { + return future::err(()); + } + }, + NetworkMessage::Ping(v) => { + if let Err(_) = write.try_send(NetworkMessage::Pong(v)) { + return future::err(()) + } + }, + NetworkMessage::Addr(addrs) => { + if addrs.len() > 1 { + check_set_flag!(recvd_addrs, "addr"); + unsafe { DATA_STORE.as_ref().unwrap() }.add_fresh_nodes(&addrs); + } + }, + NetworkMessage::Block(block) => { + if block.header.bitcoin_hash() != state_lock.request.1 || + !block.check_merkle_root() || !block.check_witness_commitment() { + state_lock.fail_reason = AddressState::ProtocolViolation; + printer.add_line(format!("Updating {} to ProtocolViolation due to bad block", node), true); + return future::err(()); + } + check_set_flag!(recvd_block, "block"); + }, + _ => {}, + } + future::ok(()) + }).then(|_| { + future::err(()) + }) + }).then(move |_: Result<(), ()>| { + let printer = unsafe { PRINTER.as_ref().unwrap() }; + let store = unsafe { DATA_STORE.as_ref().unwrap() }; + printer.set_stat(Stat::ConnectionClosed); - printer.set_stat(Stat::ConnectionClosed); + let state_lock = final_peer_state.lock().unwrap(); + if state_lock.recvd_version && state_lock.recvd_verack && + state_lock.recvd_addrs && state_lock.recvd_block { + store.set_node_state(node, AddressState::Good, state_lock.node_services); + } else { + assert!(state_lock.fail_reason != AddressState::Good); + store.set_node_state(node, state_lock.fail_reason, 0); + } + future::ok(()) + })); +} - let state_lock = final_peer_state.lock().unwrap(); - if state_lock.recvd_version && state_lock.recvd_verack && - state_lock.recvd_addrs && state_lock.recvd_block { - store.set_node_state(node, AddressState::Good, state_lock.node_services); - } else { - if state_lock.fail_reason == AddressState::Timeout || state_lock.fail_reason == AddressState::TimeoutDuringRequest { - printer.add_line(format!("Updating {} to Timeout[DuringRequest]", node), true); - } - assert!(state_lock.fail_reason != AddressState::Good); - store.set_node_state(node, state_lock.fail_reason, 0); - } - future::ok(()) - })); +fn scan_net() { + tokio::spawn(future::lazy(|| { + let store = unsafe { DATA_STORE.as_ref().unwrap() }; + let mut scan_nodes = store.get_next_scan_nodes(); + let per_iter_time = Duration::from_millis(1000 / store.get_u64(U64Setting::ConnsPerSec)); + let mut iter_time = Instant::now(); + + for node in scan_nodes.drain(..) { + scan_node(iter_time, node); + iter_time += per_iter_time; } Delay::new(iter_time).then(|_| { scan_net(); diff --git a/src/peer.rs b/src/peer.rs index 95028ed..0d0fc3f 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -92,10 +92,7 @@ impl<'a> codec::Encoder for MsgCoder<'a> { } } -pub struct Peer { - -} - +pub struct Peer {} impl Peer { pub fn new(addr: SocketAddr, timeout: Duration, printer: &'static Printer) -> impl Future, impl Stream)> { let connect_timeout = Delay::new(Instant::now() + timeout.clone()).then(|_| {