Skip forwarding gossip messages to peers if their buffer is over-full
[rust-lightning] / lightning / src / ln / peer_handler.rs
index 0497fae158b45006cc697935fe7cab87f3f0cd6a..14d2d7de622abfe2a3e5b6be77433de4fe7b0355 100644 (file)
@@ -233,6 +233,15 @@ enum InitSyncTracker{
        NodesSyncing(PublicKey),
 }
 
+/// When the outbound buffer has this many messages, we'll stop reading bytes from the peer until
+/// we have fewer than this many messages in the outbound buffer again.
+/// We also use this as the target number of outbound gossip messages to keep in the write buffer,
+/// refilled as we send bytes.
+const OUTBOUND_BUFFER_LIMIT_READ_PAUSE: usize = 10;
+/// When the outbound buffer has this many messages, we'll simply skip relaying gossip messages to
+/// the peer.
+const OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP: usize = 20;
+
 struct Peer {
        channel_encryptor: PeerChannelEncryptor,
        their_node_id: Option<PublicKey>,
@@ -531,13 +540,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                }
                        }
                }
-               const MSG_BUFF_SIZE: usize = 10;
                while !peer.awaiting_write_event {
-                       if peer.pending_outbound_buffer.len() < MSG_BUFF_SIZE {
+                       if peer.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT_READ_PAUSE {
                                match peer.sync_status {
                                        InitSyncTracker::NoSyncRequested => {},
                                        InitSyncTracker::ChannelsSyncing(c) if c < 0xffff_ffff_ffff_ffff => {
-                                               let steps = ((MSG_BUFF_SIZE - peer.pending_outbound_buffer.len() + 2) / 3) as u8;
+                                               let steps = ((OUTBOUND_BUFFER_LIMIT_READ_PAUSE - peer.pending_outbound_buffer.len() + 2) / 3) as u8;
                                                let all_messages = self.message_handler.route_handler.get_next_channel_announcements(c, steps);
                                                for &(ref announce, ref update_a_option, ref update_b_option) in all_messages.iter() {
                                                        encode_and_send_msg!(announce);
@@ -554,7 +562,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                                }
                                        },
                                        InitSyncTracker::ChannelsSyncing(c) if c == 0xffff_ffff_ffff_ffff => {
-                                               let steps = (MSG_BUFF_SIZE - peer.pending_outbound_buffer.len()) as u8;
+                                               let steps = (OUTBOUND_BUFFER_LIMIT_READ_PAUSE - peer.pending_outbound_buffer.len()) as u8;
                                                let all_messages = self.message_handler.route_handler.get_next_node_announcements(None, steps);
                                                for msg in all_messages.iter() {
                                                        encode_and_send_msg!(msg);
@@ -566,7 +574,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                        },
                                        InitSyncTracker::ChannelsSyncing(_) => unreachable!(),
                                        InitSyncTracker::NodesSyncing(key) => {
-                                               let steps = (MSG_BUFF_SIZE - peer.pending_outbound_buffer.len()) as u8;
+                                               let steps = (OUTBOUND_BUFFER_LIMIT_READ_PAUSE - peer.pending_outbound_buffer.len()) as u8;
                                                let all_messages = self.message_handler.route_handler.get_next_node_announcements(Some(&key), steps);
                                                for msg in all_messages.iter() {
                                                        encode_and_send_msg!(msg);
@@ -585,7 +593,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                        Some(buff) => buff,
                                };
 
-                               let should_be_reading = peer.pending_outbound_buffer.len() < MSG_BUFF_SIZE;
+                               let should_be_reading = peer.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT_READ_PAUSE;
                                let pending = &next_buff[peer.pending_outbound_buffer_first_msg_offset..];
                                let data_sent = descriptor.send_data(pending, should_be_reading);
                                peer.pending_outbound_buffer_first_msg_offset += data_sent;
@@ -658,6 +666,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                let pause_read = {
                        let mut peers_lock = self.peers.lock().unwrap();
                        let peers = &mut *peers_lock;
+                       let mut msgs_to_forward = Vec::new();
+                       let mut peer_node_id = None;
                        let pause_read = match peers.peers.get_mut(peer_descriptor) {
                                None => panic!("Descriptor for read_event is not already known to PeerManager"),
                                Some(peer) => {
@@ -793,13 +803,18 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                                                                        }
                                                                                };
 
-                                                                               if let Err(handling_error) = self.handle_message(&mut peers.peers_needing_send, peer, peer_descriptor.clone(), message){
-                                                                                       match handling_error {
+                                                                               match self.handle_message(&mut peers.peers_needing_send, peer, peer_descriptor.clone(), message) {
+                                                                                       Err(handling_error) => match handling_error {
                                                                                                MessageHandlingError::PeerHandleError(e) => { return Err(e) },
                                                                                                MessageHandlingError::LightningError(e) => {
                                                                                                        try_potential_handleerror!(Err(e));
                                                                                                },
-                                                                                       }
+                                                                                       },
+                                                                                       Ok(Some(msg)) => {
+                                                                                               peer_node_id = Some(peer.their_node_id.expect("After noise is complete, their_node_id is always set"));
+                                                                                               msgs_to_forward.push(msg);
+                                                                                       },
+                                                                                       Ok(None) => {},
                                                                                }
                                                                        }
                                                                }
@@ -807,10 +822,14 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                                }
                                        }
 
-                                       peer.pending_outbound_buffer.len() > 10 // pause_read
+                                       peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_READ_PAUSE // pause_read
                                }
                        };
 
+                       for msg in msgs_to_forward.drain(..) {
+                               self.forward_broadcast_msg(peers, &msg, peer_node_id.as_ref());
+                       }
+
                        pause_read
                };
 
@@ -818,7 +837,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
        }
 
        /// Process an incoming message and return a decision (ok, lightning error, peer handling error) regarding the next action with the peer
-       fn handle_message(&self, peers_needing_send: &mut HashSet<Descriptor>, peer: &mut Peer, peer_descriptor: Descriptor, message: wire::Message) -> Result<(), MessageHandlingError> {
+       /// Returns the message back if it needs to be broadcasted to all other peers.
+       fn handle_message(&self, peers_needing_send: &mut HashSet<Descriptor>, peer: &mut Peer, peer_descriptor: Descriptor, message: wire::Message) -> Result<Option<wire::Message>, MessageHandlingError> {
                log_trace!(self.logger, "Received message of type {} from {}", message.type_id(), log_pubkey!(peer.their_node_id.unwrap()));
 
                // Need an Init as first message
@@ -828,6 +848,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                        return Err(PeerHandleError{ no_connection_possible: false }.into());
                }
 
+               let mut should_forward = None;
+
                match message {
                        // Setup and Control messages:
                        wire::Message::Init(msg) => {
@@ -950,34 +972,28 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                self.message_handler.chan_handler.handle_announcement_signatures(&peer.their_node_id.unwrap(), &msg);
                        },
                        wire::Message::ChannelAnnouncement(msg) => {
-                               let should_forward = match self.message_handler.route_handler.handle_channel_announcement(&msg) {
+                               if match self.message_handler.route_handler.handle_channel_announcement(&msg) {
                                        Ok(v) => v,
                                        Err(e) => { return Err(e.into()); },
-                               };
-
-                               if should_forward {
-                                       // TODO: forward msg along to all our other peers!
+                               } {
+                                       should_forward = Some(wire::Message::ChannelAnnouncement(msg));
                                }
                        },
                        wire::Message::NodeAnnouncement(msg) => {
-                               let should_forward = match self.message_handler.route_handler.handle_node_announcement(&msg) {
+                               if match self.message_handler.route_handler.handle_node_announcement(&msg) {
                                        Ok(v) => v,
                                        Err(e) => { return Err(e.into()); },
-                               };
-
-                               if should_forward {
-                                       // TODO: forward msg along to all our other peers!
+                               } {
+                                       should_forward = Some(wire::Message::NodeAnnouncement(msg));
                                }
                        },
                        wire::Message::ChannelUpdate(msg) => {
                                self.message_handler.chan_handler.handle_channel_update(&peer.their_node_id.unwrap(), &msg);
-                               let should_forward = match self.message_handler.route_handler.handle_channel_update(&msg) {
+                               if match self.message_handler.route_handler.handle_channel_update(&msg) {
                                        Ok(v) => v,
                                        Err(e) => { return Err(e.into()); },
-                               };
-
-                               if should_forward {
-                                       // TODO: forward msg along to all our other peers!
+                               } {
+                                       should_forward = Some(wire::Message::ChannelUpdate(msg));
                                }
                        },
                        wire::Message::QueryShortChannelIds(msg) => {
@@ -1006,7 +1022,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                log_trace!(self.logger, "Received unknown odd message of type {}, ignoring", msg_type);
                        }
                };
-               Ok(())
+               Ok(should_forward)
        }
 
        fn forward_broadcast_msg(&self, peers: &mut PeerHolder<Descriptor>, msg: &wire::Message, except_node: Option<&PublicKey>) {
@@ -1019,6 +1035,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                                        !peer.should_forward_channel_announcement(msg.contents.short_channel_id) {
                                                continue
                                        }
+                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP {
+                                               continue;
+                                       }
                                        if peer.their_node_id.as_ref() == Some(&msg.contents.node_id_1) ||
                                           peer.their_node_id.as_ref() == Some(&msg.contents.node_id_2) {
                                                continue;
@@ -1038,6 +1057,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                                        !peer.should_forward_node_announcement(msg.contents.node_id) {
                                                continue
                                        }
+                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP {
+                                               continue;
+                                       }
                                        if peer.their_node_id.as_ref() == Some(&msg.contents.node_id) {
                                                continue;
                                        }
@@ -1056,6 +1078,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref> PeerManager<D
                                                        !peer.should_forward_channel_announcement(msg.contents.short_channel_id)  {
                                                continue
                                        }
+                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP {
+                                               continue;
+                                       }
                                        if except_node.is_some() && peer.their_node_id.as_ref() == except_node {
                                                continue;
                                        }