From c1b7bcd7a464506a551f618a5afbbaf6ecbc0f1c Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 26 Aug 2019 15:10:44 -0400 Subject: [PATCH] Use Option so short buffer and unknown msg are diff --- src/main.rs | 28 ++++++++++++++-------------- src/peer.rs | 10 +++++----- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/main.rs b/src/main.rs index c86529e..19dcb30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -108,7 +108,7 @@ pub fn scan_node(scan_time: Instant, node: SocketAddr, manual: bool) { } state_lock.fail_reason = AddressState::TimeoutDuringRequest; match msg { - NetworkMessage::Version(ver) => { + Some(NetworkMessage::Version(ver)) => { if ver.start_height < 0 || ver.start_height as u64 > state_lock.request.0 + 1008*2 { state_lock.fail_reason = AddressState::HighBlockCount; return future::err(()); @@ -142,18 +142,18 @@ pub fn scan_node(scan_time: Instant, node: SocketAddr, manual: bool) { return future::err(()); } }, - NetworkMessage::Verack => { + Some(NetworkMessage::Verack) => { check_set_flag!(recvd_verack, "verack"); if let Err(_) = write.try_send(NetworkMessage::Ping(state_lock.pong_nonce)) { return future::err(()); } }, - NetworkMessage::Ping(v) => { + Some(NetworkMessage::Ping(v)) => { if let Err(_) = write.try_send(NetworkMessage::Pong(v)) { return future::err(()) } }, - NetworkMessage::Pong(v) => { + Some(NetworkMessage::Pong(v)) => { if v != state_lock.pong_nonce { state_lock.fail_reason = AddressState::ProtocolViolation; state_lock.msg = ("due to invalid pong nonce".to_string(), true); @@ -164,7 +164,7 @@ pub fn scan_node(scan_time: Instant, node: SocketAddr, manual: bool) { return future::err(()); } }, - NetworkMessage::Addr(addrs) => { + Some(NetworkMessage::Addr(addrs)) => { if addrs.len() > 1000 { state_lock.fail_reason = AddressState::ProtocolViolation; state_lock.msg = (format!("due to oversized addr: {}", addrs.len()), true); @@ -184,7 +184,7 @@ pub fn scan_node(scan_time: Instant, node: SocketAddr, manual: bool) { } unsafe { DATA_STORE.as_ref().unwrap() }.add_fresh_nodes(&addrs); }, - NetworkMessage::Block(block) => { + Some(NetworkMessage::Block(block)) => { if block != state_lock.request.2 { state_lock.fail_reason = AddressState::ProtocolViolation; state_lock.msg = ("due to bad block".to_string(), true); @@ -193,7 +193,7 @@ pub fn scan_node(scan_time: Instant, node: SocketAddr, manual: bool) { check_set_flag!(recvd_block, "block"); return future::err(()); }, - NetworkMessage::Inv(invs) => { + Some(NetworkMessage::Inv(invs)) => { for inv in invs { if inv.inv_type == InvType::Transaction { state_lock.fail_reason = AddressState::EvilNode; @@ -202,7 +202,7 @@ pub fn scan_node(scan_time: Instant, node: SocketAddr, manual: bool) { } } }, - NetworkMessage::Tx(_) => { + Some(NetworkMessage::Tx(_)) => { state_lock.fail_reason = AddressState::EvilNode; state_lock.msg = ("due to unrequested transaction".to_string(), true); return future::err(()); @@ -310,13 +310,13 @@ fn make_trusted_conn(trusted_sockaddr: SocketAddr, bgp_client: Arc) { return future::err(()); } match msg { - NetworkMessage::Version(ver) => { + Some(NetworkMessage::Version(ver)) => { if let Err(_) = trusted_write.try_send(NetworkMessage::Verack) { return future::err(()) } starting_height = ver.start_height; }, - NetworkMessage::Verack => { + Some(NetworkMessage::Verack) => { if let Err(_) = trusted_write.try_send(NetworkMessage::SendHeaders) { return future::err(()); } @@ -331,10 +331,10 @@ fn make_trusted_conn(trusted_sockaddr: SocketAddr, bgp_client: Arc) { return future::err(()); } }, - NetworkMessage::Addr(addrs) => { + Some(NetworkMessage::Addr(addrs)) => { unsafe { DATA_STORE.as_ref().unwrap() }.add_fresh_nodes(&addrs); }, - NetworkMessage::Headers(headers) => { + Some(NetworkMessage::Headers(headers)) => { if headers.is_empty() { return future::ok(()); } @@ -376,7 +376,7 @@ fn make_trusted_conn(trusted_sockaddr: SocketAddr, bgp_client: Arc) { return future::err(()) } }, - NetworkMessage::Block(block) => { + Some(NetworkMessage::Block(block)) => { let hash = block.header.bitcoin_hash(); let header_map = unsafe { HEADER_MAP.as_ref().unwrap() }.lock().unwrap(); let height = *header_map.get(&hash).expect("Got loose block from trusted peer we coulnd't have requested"); @@ -388,7 +388,7 @@ fn make_trusted_conn(trusted_sockaddr: SocketAddr, bgp_client: Arc) { } } }, - NetworkMessage::Ping(v) => { + Some(NetworkMessage::Ping(v)) => { if let Err(_) = trusted_write.try_send(NetworkMessage::Pong(v)) { return future::err(()) } diff --git a/src/peer.rs b/src/peer.rs index 137d92d..aeeb2c2 100644 --- a/src/peer.rs +++ b/src/peer.rs @@ -45,10 +45,10 @@ impl<'a> std::io::Read for BytesDecoder<'a> { struct MsgCoder<'a>(&'a Printer); impl<'a> codec::Decoder for MsgCoder<'a> { - type Item = NetworkMessage; + type Item = Option; type Error = encode::Error; - fn decode(&mut self, bytes: &mut bytes::BytesMut) -> Result, encode::Error> { + fn decode(&mut self, bytes: &mut bytes::BytesMut) -> Result>, encode::Error> { let mut decoder = BytesDecoder { buf: bytes, pos: 0 @@ -57,7 +57,7 @@ impl<'a> codec::Decoder for MsgCoder<'a> { Ok(res) => { decoder.buf.advance(decoder.pos); if res.magic == Network::Bitcoin.magic() { - Ok(Some(res.payload)) + Ok(Some(Some(res.payload))) } else { Err(encode::Error::UnexpectedNetworkMagic { expected: Network::Bitcoin.magic(), @@ -72,7 +72,7 @@ impl<'a> codec::Decoder for MsgCoder<'a> { //XXX(fixthese): self.0.add_line(format!("rust-bitcoin doesn't support {}!", msg), true); if msg == "gnop" { Err(e) - } else { Ok(None) } + } else { Ok(Some(None)) } }, _ => { self.0.add_line(format!("Error decoding message: {:?}", e), true); @@ -148,7 +148,7 @@ macro_rules! try_write_small { pub struct Peer {} impl Peer { - pub fn new(addr: SocketAddr, tor_proxy: &SocketAddr, timeout: Duration, printer: &'static Printer) -> impl Future, impl Stream)> { + pub fn new(addr: SocketAddr, tor_proxy: &SocketAddr, timeout: Duration, printer: &'static Printer) -> impl Future, impl Stream, Error=encode::Error>)> { let connect_timeout = Delay::new(Instant::now() + timeout.clone()).then(|_| { future::err(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout reached")) }); -- 2.39.5