Fix ping message sent on timer ticks to be encrypted so the peer can decrypt it inste...
[rust-lightning] / lightning / src / ln / peer_handler.rs
index e25e50c28b5f1190edea81c429bae659b6a0c092..5fb2cdbdeab0e4c2c3ed7605dc722d7bcd594434 100644 (file)
@@ -12,13 +12,13 @@ use ln::features::InitFeatures;
 use ln::msgs;
 use ln::msgs::ChannelMessageHandler;
 use ln::channelmanager::{SimpleArcChannelManager, SimpleRefChannelManager};
+use util::ser::VecWriter;
 use ln::peer_channel_encryptor::{PeerChannelEncryptor,NextNoiseStep};
 use ln::wire;
 use ln::wire::Encode;
 use util::byte_utils;
 use util::events::{MessageSendEvent, MessageSendEventsProvider};
 use util::logger::Logger;
-use util::ser::Writer;
 
 use std::collections::{HashMap,hash_map,HashSet,LinkedList};
 use std::sync::{Arc, Mutex};
@@ -189,21 +189,9 @@ pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref> where CM::Target
        peer_counter_low: AtomicUsize,
        peer_counter_high: AtomicUsize,
 
-       initial_syncs_sent: AtomicUsize,
        logger: Arc<Logger>,
 }
 
-struct VecWriter(Vec<u8>);
-impl Writer for VecWriter {
-       fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
-               self.0.extend_from_slice(buf);
-               Ok(())
-       }
-       fn size_hint(&mut self, size: usize) {
-               self.0.reserve_exact(size);
-       }
-}
-
 macro_rules! encode_msg {
        ($msg: expr) => {{
                let mut buffer = VecWriter(Vec::new());
@@ -212,9 +200,6 @@ macro_rules! encode_msg {
        }}
 }
 
-//TODO: Really should do something smarter for this
-const INITIAL_SYNCS_TO_SEND: usize = 5;
-
 /// Manages and reacts to connection events. You probably want to use file descriptors as PeerIds.
 /// PeerIds may repeat, but only after disconnect_event() has been called.
 impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where CM::Target: msgs::ChannelMessageHandler {
@@ -236,7 +221,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                        ephemeral_key_midstate,
                        peer_counter_low: AtomicUsize::new(0),
                        peer_counter_high: AtomicUsize::new(0),
-                       initial_syncs_sent: AtomicUsize::new(0),
                        logger,
                }
        }
@@ -517,37 +501,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                                                                }
                                                        }
 
-                                                       macro_rules! try_potential_decodeerror {
-                                                               ($thing: expr) => {
-                                                                       match $thing {
-                                                                               Ok(x) => x,
-                                                                               Err(e) => {
-                                                                                       match e {
-                                                                                               msgs::DecodeError::UnknownVersion => return Err(PeerHandleError{ no_connection_possible: false }),
-                                                                                               msgs::DecodeError::UnknownRequiredFeature => {
-                                                                                                       log_debug!(self, "Got a channel/node announcement with an known required feature flag, you may want to update!");
-                                                                                                       continue;
-                                                                                               },
-                                                                                               msgs::DecodeError::InvalidValue => {
-                                                                                                       log_debug!(self, "Got an invalid value while deserializing message");
-                                                                                                       return Err(PeerHandleError{ no_connection_possible: false });
-                                                                                               },
-                                                                                               msgs::DecodeError::ShortRead => {
-                                                                                                       log_debug!(self, "Deserialization failed due to shortness of message");
-                                                                                                       return Err(PeerHandleError{ no_connection_possible: false });
-                                                                                               },
-                                                                                               msgs::DecodeError::ExtraAddressesPerType => {
-                                                                                                       log_debug!(self, "Error decoding message, ignoring due to lnd spec incompatibility. See https://github.com/lightningnetwork/lnd/issues/1407");
-                                                                                                       continue;
-                                                                                               },
-                                                                                               msgs::DecodeError::BadLengthDescriptor => return Err(PeerHandleError{ no_connection_possible: false }),
-                                                                                               msgs::DecodeError::Io(_) => return Err(PeerHandleError{ no_connection_possible: false }),
-                                                                                       }
-                                                                               }
-                                                                       };
-                                                               }
-                                                       }
-
                                                        macro_rules! insert_node_id {
                                                                () => {
                                                                        match peers.node_id_to_descriptor.entry(peer.their_node_id.unwrap()) {
@@ -580,8 +533,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                                                                        peer.their_node_id = Some(their_node_id);
                                                                        insert_node_id!();
                                                                        let mut features = InitFeatures::supported();
-                                                                       if self.initial_syncs_sent.load(Ordering::Acquire) < INITIAL_SYNCS_TO_SEND {
-                                                                               self.initial_syncs_sent.fetch_add(1, Ordering::AcqRel);
+                                                                       if self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
                                                                                features.set_initial_routing_sync();
                                                                        }
 
@@ -613,7 +565,34 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                                                                                peer.pending_read_is_header = true;
 
                                                                                let mut reader = ::std::io::Cursor::new(&msg_data[..]);
-                                                                               let message = try_potential_decodeerror!(wire::read(&mut reader));
+                                                                               let message_result = wire::read(&mut reader);
+                                                                               let message = match message_result {
+                                                                                       Ok(x) => x,
+                                                                                       Err(e) => {
+                                                                                               match e {
+                                                                                                       msgs::DecodeError::UnknownVersion => return Err(PeerHandleError { no_connection_possible: false }),
+                                                                                                       msgs::DecodeError::UnknownRequiredFeature => {
+                                                                                                               log_debug!(self, "Got a channel/node announcement with an known required feature flag, you may want to update!");
+                                                                                                               continue;
+                                                                                                       }
+                                                                                                       msgs::DecodeError::InvalidValue => {
+                                                                                                               log_debug!(self, "Got an invalid value while deserializing message");
+                                                                                                               return Err(PeerHandleError { no_connection_possible: false });
+                                                                                                       }
+                                                                                                       msgs::DecodeError::ShortRead => {
+                                                                                                               log_debug!(self, "Deserialization failed due to shortness of message");
+                                                                                                               return Err(PeerHandleError { no_connection_possible: false });
+                                                                                                       }
+                                                                                                       msgs::DecodeError::ExtraAddressesPerType => {
+                                                                                                               log_debug!(self, "Error decoding message, ignoring due to lnd spec incompatibility. See https://github.com/lightningnetwork/lnd/issues/1407");
+                                                                                                               continue;
+                                                                                                       }
+                                                                                                       msgs::DecodeError::BadLengthDescriptor => return Err(PeerHandleError { no_connection_possible: false }),
+                                                                                                       msgs::DecodeError::Io(_) => return Err(PeerHandleError { no_connection_possible: false }),
+                                                                                               }
+                                                                                       }
+                                                                               };
+
                                                                                log_trace!(self, "Received message of type {} from {}", message.type_id(), log_pubkey!(peer.their_node_id.unwrap()));
 
                                                                                // Need an Init as first message
@@ -652,8 +631,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
 
                                                                                                if !peer.outbound {
                                                                                                        let mut features = InitFeatures::supported();
-                                                                                                       if self.initial_syncs_sent.load(Ordering::Acquire) < INITIAL_SYNCS_TO_SEND {
-                                                                                                               self.initial_syncs_sent.fetch_add(1, Ordering::AcqRel);
+                                                                                                       if self.message_handler.route_handler.should_request_full_sync(&peer.their_node_id.unwrap()) {
                                                                                                                features.set_initial_routing_sync();
                                                                                                        }
 
@@ -1097,31 +1075,34 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                        let peers = &mut peers.peers;
 
                        peers.retain(|descriptor, peer| {
-                               if peer.awaiting_pong == true {
+                               if peer.awaiting_pong {
                                        peers_needing_send.remove(descriptor);
                                        match peer.their_node_id {
                                                Some(node_id) => {
                                                        node_id_to_descriptor.remove(&node_id);
                                                        self.message_handler.chan_handler.peer_disconnected(&node_id, true);
-                                               },
+                                               }
                                                None => {}
                                        }
+                                       return false;
+                               }
+
+                               if !peer.channel_encryptor.is_ready_for_encryption() {
+                                       // The peer needs to complete its handshake before we can exchange messages
+                                       return true;
                                }
 
                                let ping = msgs::Ping {
                                        ponglen: 0,
                                        byteslen: 64,
                                };
-                               peer.pending_outbound_buffer.push_back(encode_msg!(&ping));
+                               peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encode_msg!(&ping)));
+
                                let mut descriptor_clone = descriptor.clone();
                                self.do_attempt_write_data(&mut descriptor_clone, peer);
 
-                               if peer.awaiting_pong {
-                                       false // Drop the peer
-                               } else {
-                                       peer.awaiting_pong = true;
-                                       true
-                               }
+                               peer.awaiting_pong = true;
+                               true
                        });
                }
        }
@@ -1140,15 +1121,29 @@ mod tests {
 
        use rand::{thread_rng, Rng};
 
-       use std::sync::{Arc};
+       use std;
+       use std::sync::{Arc, Mutex};
 
-       #[derive(PartialEq, Eq, Clone, Hash)]
+       #[derive(Clone)]
        struct FileDescriptor {
                fd: u16,
+               outbound_data: Arc<Mutex<Vec<u8>>>,
+       }
+       impl PartialEq for FileDescriptor {
+               fn eq(&self, other: &Self) -> bool {
+                       self.fd == other.fd
+               }
+       }
+       impl Eq for FileDescriptor { }
+       impl std::hash::Hash for FileDescriptor {
+               fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+                       self.fd.hash(hasher)
+               }
        }
 
        impl SocketDescriptor for FileDescriptor {
                fn send_data(&mut self, data: &[u8], _resume_read: bool) -> usize {
+                       self.outbound_data.lock().unwrap().extend_from_slice(data);
                        data.len()
                }
 
@@ -1189,10 +1184,15 @@ mod tests {
 
        fn establish_connection<'a>(peer_a: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler>, peer_b: &PeerManager<FileDescriptor, &'a test_utils::TestChannelMessageHandler>) {
                let secp_ctx = Secp256k1::new();
-               let their_id = PublicKey::from_secret_key(&secp_ctx, &peer_b.our_node_secret);
-               let fd = FileDescriptor { fd: 1};
-               peer_a.new_inbound_connection(fd.clone()).unwrap();
-               peer_a.peers.lock().unwrap().node_id_to_descriptor.insert(their_id, fd.clone());
+               let a_id = PublicKey::from_secret_key(&secp_ctx, &peer_a.our_node_secret);
+               //let b_id = PublicKey::from_secret_key(&secp_ctx, &peer_b.our_node_secret);
+               let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_b = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let initial_data = peer_b.new_outbound_connection(a_id, fd_b.clone()).unwrap();
+               peer_a.new_inbound_connection(fd_a.clone()).unwrap();
+               assert_eq!(peer_a.read_event(&mut fd_a, initial_data).unwrap(), false);
+               assert_eq!(peer_b.read_event(&mut fd_b, fd_a.outbound_data.lock().unwrap().split_off(0)).unwrap(), false);
+               assert_eq!(peer_a.read_event(&mut fd_a, fd_b.outbound_data.lock().unwrap().split_off(0)).unwrap(), false);
        }
 
        #[test]