Refactor per-node-scan-start into a function
authorMatt Corallo <git@bluematt.me>
Sun, 19 May 2019 17:51:45 +0000 (13:51 -0400)
committerMatt Corallo <git@bluematt.me>
Sun, 19 May 2019 17:51:45 +0000 (13:51 -0400)
src/main.rs
src/peer.rs

index d58b1adc932d219751858386e60c841c79501292..c6a86a6845326a19205cb12d9f8221a0b8141261 100644 (file)
@@ -40,135 +40,139 @@ struct PeerState {
        request: (u64, sha256d::Hash),
 }
 
-fn scan_net() {
-       tokio::spawn(future::lazy(|| {
-               let store = unsafe { DATA_STORE.as_ref().unwrap() };
-               let printer = unsafe { PRINTER.as_ref().unwrap() };
+fn scan_node(scan_time: Instant, node: SocketAddr) {
+       let printer = unsafe { PRINTER.as_ref().unwrap() };
+       let store = unsafe { DATA_STORE.as_ref().unwrap() };
 
-               let mut scan_nodes = store.get_next_scan_nodes();
+       let peer_state = Arc::new(Mutex::new(PeerState {
+               recvd_version: false,
+               recvd_verack: false,
+               recvd_addrs: false,
+               recvd_block: false,
+               node_services: 0,
+               fail_reason: AddressState::Timeout,
+               request: (0, Default::default()),
+       }));
+       let final_peer_state = Arc::clone(&peer_state);
+
+       let peer = Delay::new(scan_time).then(move |_| {
+               printer.set_stat(Stat::NewConnection);
                let timeout = store.get_u64(U64Setting::RunTimeout);
-               let per_iter_time = Duration::from_millis(1000 / store.get_u64(U64Setting::ConnsPerSec));
-               let mut iter_time = Instant::now();
+               Peer::new(node.clone(), Duration::from_secs(timeout), printer) //TODO: timeout for total run
+       });
+       tokio::spawn(peer.and_then(move |conn_split| {
                let requested_height = unsafe { HIGHEST_HEADER.as_ref().unwrap() }.lock().unwrap().1 - 1008;
                let requested_block = unsafe { HEIGHT_MAP.as_ref().unwrap() }.lock().unwrap().get(&requested_height).unwrap().clone();
+               peer_state.lock().unwrap().request = (requested_height, requested_block);
 
-               for node in scan_nodes.drain(..) {
-                       let peer_state = Arc::new(Mutex::new(PeerState {
-                               recvd_version: false,
-                               recvd_verack: false,
-                               recvd_addrs: false,
-                               recvd_block: false,
-                               node_services: 0,
-                               fail_reason: AddressState::Timeout,
-                               request: (requested_height, requested_block),
-                       }));
-                       let final_peer_state = Arc::clone(&peer_state);
-                       let peer = Delay::new(iter_time).then(move |_| {
-                               printer.set_stat(Stat::NewConnection);
-                               Peer::new(node.clone(), Duration::from_secs(timeout), printer) //TODO: timeout for total run
-                       });
-                       iter_time += per_iter_time;
-                       tokio::spawn(peer.and_then(move |conn_split| {
-                               let (mut write, read) = conn_split;
-                               read.map_err(|_| { () }).for_each(move |msg| {
-                                       let mut state_lock = peer_state.lock().unwrap();
-                                       macro_rules! check_set_flag {
-                                               ($recvd_flag: ident, $msg: expr) => { {
-                                                       if state_lock.$recvd_flag {
-                                                               state_lock.fail_reason = AddressState::ProtocolViolation;
-                                                               printer.add_line(format!("Updating {} to ProtocolViolation due to dup {}", node, $msg), true);
-                                                               state_lock.$recvd_flag = false;
-                                                               return future::err(());
-                                                       }
-                                                       state_lock.$recvd_flag = true;
-                                               } }
+               let (mut write, read) = conn_split;
+               read.map_err(|_| { () }).for_each(move |msg| {
+                       let mut state_lock = peer_state.lock().unwrap();
+                       macro_rules! check_set_flag {
+                               ($recvd_flag: ident, $msg: expr) => { {
+                                       if state_lock.$recvd_flag {
+                                               state_lock.fail_reason = AddressState::ProtocolViolation;
+                                               printer.add_line(format!("Updating {} to ProtocolViolation due to dup {}", node, $msg), true);
+                                               state_lock.$recvd_flag = false;
+                                               return future::err(());
                                        }
-                                       state_lock.fail_reason = AddressState::TimeoutDuringRequest;
-                                       match msg {
-                                               NetworkMessage::Version(ver) => {
-                                                       if ver.start_height < 0 || ver.start_height as u64 > state_lock.request.0 + 1008*2 {
-                                                               state_lock.fail_reason = AddressState::HighBlockCount;
-                                                               return future::err(());
-                                                       }
-                                                       if (ver.start_height as u64) < state_lock.request.0 {
-                                                               printer.add_line(format!("Updating {} to LowBlockCount ({} < {})", node, ver.start_height, state_lock.request.0), true);
-                                                               state_lock.fail_reason = AddressState::LowBlockCount;
-                                                               return future::err(());
-                                                       }
-                                                       let min_version = store.get_u64(U64Setting::MinProtocolVersion);
-                                                       if (ver.version as u64) < min_version {
-                                                               printer.add_line(format!("Updating {} to LowVersion ({} < {})", node, ver.version, min_version), true);
-                                                               state_lock.fail_reason = AddressState::LowVersion;
-                                                               return future::err(());
-                                                       }
-                                                       if ver.services & 1 != 1 {
-                                                               printer.add_line(format!("Updating {} to NotFullNode (services {:x})", node, ver.services), true);
-                                                               state_lock.fail_reason = AddressState::NotFullNode;
-                                                               return future::err(());
-                                                       }
-                                                       check_set_flag!(recvd_version, "version");
-                                                       state_lock.node_services = ver.services;
-                                                       if let Err(_) = write.try_send(NetworkMessage::Verack) {
-                                                               return future::err(());
-                                                       }
-                                               },
-                                               NetworkMessage::Verack => {
-                                                       check_set_flag!(recvd_verack, "verack");
-                                                       if let Err(_) = write.try_send(NetworkMessage::GetAddr) {
-                                                               return future::err(());
-                                                       }
-                                                       if let Err(_) = write.try_send(NetworkMessage::GetData(vec![Inventory {
-                                                               inv_type: InvType::WitnessBlock,
-                                                               hash: state_lock.request.1,
-                                                       }])) {
-                                                               return future::err(());
-                                                       }
-                                               },
-                                               NetworkMessage::Ping(v) => {
-                                                       if let Err(_) = write.try_send(NetworkMessage::Pong(v)) {
-                                                               return future::err(())
-                                                       }
-                                               },
-                                               NetworkMessage::Addr(addrs) => {
-                                                       if addrs.len() > 1 {
-                                                               check_set_flag!(recvd_addrs, "addr");
-                                                               unsafe { DATA_STORE.as_ref().unwrap() }.add_fresh_nodes(&addrs);
-                                                       }
-                                               },
-                                               NetworkMessage::Block(block) => {
-                                                       if block.header.bitcoin_hash() != state_lock.request.1 ||
-                                                                       !block.check_merkle_root() || !block.check_witness_commitment() {
-                                                               state_lock.fail_reason = AddressState::ProtocolViolation;
-                                                               printer.add_line(format!("Updating {} to ProtocolViolation due to bad block", node), true);
-                                                               return future::err(());
-                                                       }
-                                                       check_set_flag!(recvd_block, "block");
-                                               },
-                                               _ => {},
+                                       state_lock.$recvd_flag = true;
+                               } }
+                       }
+                       state_lock.fail_reason = AddressState::TimeoutDuringRequest;
+                       match msg {
+                               NetworkMessage::Version(ver) => {
+                                       if ver.start_height < 0 || ver.start_height as u64 > state_lock.request.0 + 1008*2 {
+                                               state_lock.fail_reason = AddressState::HighBlockCount;
+                                               return future::err(());
+                                       }
+                                       if (ver.start_height as u64) < state_lock.request.0 {
+                                               printer.add_line(format!("Updating {} to LowBlockCount ({} < {})", node, ver.start_height, state_lock.request.0), true);
+                                               state_lock.fail_reason = AddressState::LowBlockCount;
+                                               return future::err(());
+                                       }
+                                       let min_version = store.get_u64(U64Setting::MinProtocolVersion);
+                                       if (ver.version as u64) < min_version {
+                                               printer.add_line(format!("Updating {} to LowVersion ({} < {})", node, ver.version, min_version), true);
+                                               state_lock.fail_reason = AddressState::LowVersion;
+                                               return future::err(());
+                                       }
+                                       if ver.services & 1 != 1 {
+                                               printer.add_line(format!("Updating {} to NotFullNode (services {:x})", node, ver.services), true);
+                                               state_lock.fail_reason = AddressState::NotFullNode;
+                                               return future::err(());
                                        }
-                                       future::ok(())
-                               }).then(|_| {
-                                       future::err(())
-                               })
-                       }).then(move |_: Result<(), ()>| {
-                               let printer = unsafe { PRINTER.as_ref().unwrap() };
-                               let store = unsafe { DATA_STORE.as_ref().unwrap() };
+                                       check_set_flag!(recvd_version, "version");
+                                       state_lock.node_services = ver.services;
+                                       if let Err(_) = write.try_send(NetworkMessage::Verack) {
+                                               return future::err(());
+                                       }
+                               },
+                               NetworkMessage::Verack => {
+                                       check_set_flag!(recvd_verack, "verack");
+                                       if let Err(_) = write.try_send(NetworkMessage::GetAddr) {
+                                               return future::err(());
+                                       }
+                                       if let Err(_) = write.try_send(NetworkMessage::GetData(vec![Inventory {
+                                               inv_type: InvType::WitnessBlock,
+                                               hash: state_lock.request.1,
+                                       }])) {
+                                               return future::err(());
+                                       }
+                               },
+                               NetworkMessage::Ping(v) => {
+                                       if let Err(_) = write.try_send(NetworkMessage::Pong(v)) {
+                                               return future::err(())
+                                       }
+                               },
+                               NetworkMessage::Addr(addrs) => {
+                                       if addrs.len() > 1 {
+                                               check_set_flag!(recvd_addrs, "addr");
+                                               unsafe { DATA_STORE.as_ref().unwrap() }.add_fresh_nodes(&addrs);
+                                       }
+                               },
+                               NetworkMessage::Block(block) => {
+                                       if block.header.bitcoin_hash() != state_lock.request.1 ||
+                                                       !block.check_merkle_root() || !block.check_witness_commitment() {
+                                               state_lock.fail_reason = AddressState::ProtocolViolation;
+                                               printer.add_line(format!("Updating {} to ProtocolViolation due to bad block", node), true);
+                                               return future::err(());
+                                       }
+                                       check_set_flag!(recvd_block, "block");
+                               },
+                               _ => {},
+                       }
+                       future::ok(())
+               }).then(|_| {
+                       future::err(())
+               })
+       }).then(move |_: Result<(), ()>| {
+               let printer = unsafe { PRINTER.as_ref().unwrap() };
+               let store = unsafe { DATA_STORE.as_ref().unwrap() };
+               printer.set_stat(Stat::ConnectionClosed);
 
-                               printer.set_stat(Stat::ConnectionClosed);
+               let state_lock = final_peer_state.lock().unwrap();
+               if state_lock.recvd_version && state_lock.recvd_verack &&
+                               state_lock.recvd_addrs && state_lock.recvd_block {
+                       store.set_node_state(node, AddressState::Good, state_lock.node_services);
+               } else {
+                       assert!(state_lock.fail_reason != AddressState::Good);
+                       store.set_node_state(node, state_lock.fail_reason, 0);
+               }
+               future::ok(())
+       }));
+}
 
-                               let state_lock = final_peer_state.lock().unwrap();
-                               if state_lock.recvd_version && state_lock.recvd_verack &&
-                                               state_lock.recvd_addrs && state_lock.recvd_block {
-                                       store.set_node_state(node, AddressState::Good, state_lock.node_services);
-                               } else {
-                                       if state_lock.fail_reason == AddressState::Timeout || state_lock.fail_reason == AddressState::TimeoutDuringRequest {
-                                               printer.add_line(format!("Updating {} to Timeout[DuringRequest]", node), true);
-                                       }
-                                       assert!(state_lock.fail_reason != AddressState::Good);
-                                       store.set_node_state(node, state_lock.fail_reason, 0);
-                               }
-                               future::ok(())
-                       }));
+fn scan_net() {
+       tokio::spawn(future::lazy(|| {
+               let store = unsafe { DATA_STORE.as_ref().unwrap() };
+               let mut scan_nodes = store.get_next_scan_nodes();
+               let per_iter_time = Duration::from_millis(1000 / store.get_u64(U64Setting::ConnsPerSec));
+               let mut iter_time = Instant::now();
+
+               for node in scan_nodes.drain(..) {
+                       scan_node(iter_time, node);
+                       iter_time += per_iter_time;
                }
                Delay::new(iter_time).then(|_| {
                        scan_net();
index 95028ed5f4bb3e68c97db5724b424a1ac2d71848..0d0fc3f9689be897d89b1d422c14cd269b30beff 100644 (file)
@@ -92,10 +92,7 @@ impl<'a> codec::Encoder for MsgCoder<'a> {
        }
 }
 
-pub struct Peer {
-       
-}
-
+pub struct Peer {}
 impl Peer {
        pub fn new(addr: SocketAddr, timeout: Duration, printer: &'static Printer) -> impl Future<Error=(), Item=(mpsc::Sender<NetworkMessage>, impl Stream<Item=NetworkMessage, Error=std::io::Error>)> {
                let connect_timeout = Delay::new(Instant::now() + timeout.clone()).then(|_| {