Use a message buffer abstraction in OnionMessenger
authorJeffrey Czyz <jkczyz@gmail.com>
Mon, 6 Nov 2023 22:53:07 +0000 (16:53 -0600)
committerJeffrey Czyz <jkczyz@gmail.com>
Wed, 6 Dec 2023 04:39:16 +0000 (22:39 -0600)
Onion messages are buffered for sending to the next node. Since the
network has limited adoption, connecting directly to a peer may be
necessary. Add an OnionMessageBuffer abstraction that can differentiate
between connected peers and those are pending a connection. This allows
for buffering messages before a connection is established and applying
different buffer policies for peers yet to be connected.

fuzz/src/onion_message.rs
lightning/src/onion_message/messenger.rs

index de7b8b6b4c612a412902720df613f768649205d4..2882dcfb5089ec7bcb863de6cb9b3bb8e773ae05 100644 (file)
@@ -269,7 +269,7 @@ mod tests {
                                                "Received an onion message with path_id None and a reply_path: Custom(TestCustomMessage)"
                                                .to_string())), Some(&1));
                        assert_eq!(log_entries.get(&("lightning::onion_message::messenger".to_string(),
-                                               "Sending onion message: TestCustomMessage".to_string())), Some(&1));
+                                               "Sending onion message when responding to Custom onion message with path_id None: TestCustomMessage".to_string())), Some(&1));
                }
 
                let two_unblinded_hops_om = "\
index c7f01ae597862713ab1e3bde1d1f4c8b8f4b4a0d..2c42566e733bd205157d7aeb2b882b0501512837 100644 (file)
@@ -150,13 +150,70 @@ where
        entropy_source: ES,
        node_signer: NS,
        logger: L,
-       pending_messages: Mutex<HashMap<PublicKey, VecDeque<OnionMessage>>>,
+       message_buffers: Mutex<HashMap<PublicKey, OnionMessageBuffer>>,
        secp_ctx: Secp256k1<secp256k1::All>,
        message_router: MR,
        offers_handler: OMH,
        custom_handler: CMH,
 }
 
+/// [`OnionMessage`]s buffered to be sent.
+enum OnionMessageBuffer {
+       /// Messages for a node connected as a peer.
+       ConnectedPeer(VecDeque<OnionMessage>),
+
+       /// Messages for a node that is not yet connected.
+       PendingConnection(VecDeque<OnionMessage>),
+}
+
+impl OnionMessageBuffer {
+       fn pending_messages(&self) -> &VecDeque<OnionMessage> {
+               match self {
+                       OnionMessageBuffer::ConnectedPeer(pending_messages) => pending_messages,
+                       OnionMessageBuffer::PendingConnection(pending_messages) => pending_messages,
+               }
+       }
+
+       fn enqueue_message(&mut self, message: OnionMessage) {
+               let pending_messages = match self {
+                       OnionMessageBuffer::ConnectedPeer(pending_messages) => pending_messages,
+                       OnionMessageBuffer::PendingConnection(pending_messages) => pending_messages,
+               };
+
+               pending_messages.push_back(message);
+       }
+
+       fn dequeue_message(&mut self) -> Option<OnionMessage> {
+               let pending_messages = match self {
+                       OnionMessageBuffer::ConnectedPeer(pending_messages) => pending_messages,
+                       OnionMessageBuffer::PendingConnection(pending_messages) => {
+                               debug_assert!(false);
+                               pending_messages
+                       },
+               };
+
+               pending_messages.pop_front()
+       }
+
+       #[cfg(test)]
+       fn release_pending_messages(&mut self) -> VecDeque<OnionMessage> {
+               let pending_messages = match self {
+                       OnionMessageBuffer::ConnectedPeer(pending_messages) => pending_messages,
+                       OnionMessageBuffer::PendingConnection(pending_messages) => pending_messages,
+               };
+
+               core::mem::take(pending_messages)
+       }
+
+       fn mark_connected(&mut self) {
+               if let OnionMessageBuffer::PendingConnection(pending_messages) = self {
+                       let mut new_pending_messages = VecDeque::new();
+                       core::mem::swap(pending_messages, &mut new_pending_messages);
+                       *self = OnionMessageBuffer::ConnectedPeer(new_pending_messages);
+               }
+       }
+}
+
 /// An [`OnionMessage`] for [`OnionMessenger`] to send.
 ///
 /// These are obtained when released from [`OnionMessenger`]'s handlers after which they are
@@ -502,7 +559,7 @@ where
                OnionMessenger {
                        entropy_source,
                        node_signer,
-                       pending_messages: Mutex::new(HashMap::new()),
+                       message_buffers: Mutex::new(HashMap::new()),
                        secp_ctx,
                        logger,
                        message_router,
@@ -518,21 +575,23 @@ where
        pub fn send_onion_message<T: OnionMessageContents>(
                &self, path: OnionMessagePath, contents: T, reply_path: Option<BlindedPath>
        ) -> Result<(), SendError> {
-
                log_trace!(self.logger, "Sending onion message: {:?}", contents);
-               
-               let (first_node_id, onion_msg) = create_onion_message(
+
+               let (first_node_id, onion_message) = create_onion_message(
                        &self.entropy_source, &self.node_signer, &self.secp_ctx, path, contents, reply_path
                )?;
 
-               let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
-               if outbound_buffer_full(&first_node_id, &pending_per_peer_msgs) { return Err(SendError::BufferFull) }
-               match pending_per_peer_msgs.entry(first_node_id) {
+               let mut message_buffers = self.message_buffers.lock().unwrap();
+               if outbound_buffer_full(&first_node_id, &message_buffers) {
+                       return Err(SendError::BufferFull);
+               }
+
+               match message_buffers.entry(first_node_id) {
                        hash_map::Entry::Vacant(_) => Err(SendError::InvalidFirstHop),
                        hash_map::Entry::Occupied(mut e) => {
-                               e.get_mut().push_back(onion_msg);
+                               e.get_mut().enqueue_message(onion_message);
                                Ok(())
-                       }
+                       },
                }
        }
 
@@ -565,7 +624,7 @@ where
                        }
                };
 
-               let peers = self.pending_messages.lock().unwrap().keys().copied().collect();
+               let peers = self.message_buffers.lock().unwrap().keys().copied().collect();
                let path = match self.message_router.find_path(sender, peers, destination) {
                        Ok(path) => path,
                        Err(()) => {
@@ -578,30 +637,29 @@ where
 
                if let Err(e) = self.send_onion_message(path, contents, reply_path) {
                        log_trace!(self.logger, "Failed sending onion message {}: {:?}", log_suffix, e);
-                       return;
                }
        }
 
        #[cfg(test)]
        pub(super) fn release_pending_msgs(&self) -> HashMap<PublicKey, VecDeque<OnionMessage>> {
-               let mut pending_msgs = self.pending_messages.lock().unwrap();
+               let mut message_buffers = self.message_buffers.lock().unwrap();
                let mut msgs = HashMap::new();
                // We don't want to disconnect the peers by removing them entirely from the original map, so we
-               // swap the pending message buffers individually.
-               for (peer_node_id, pending_messages) in &mut *pending_msgs {
-                       msgs.insert(*peer_node_id, core::mem::take(pending_messages));
+               // release the pending message buffers individually.
+               for (peer_node_id, buffer) in &mut *message_buffers {
+                       msgs.insert(*peer_node_id, buffer.release_pending_messages());
                }
                msgs
        }
 }
 
-fn outbound_buffer_full(peer_node_id: &PublicKey, buffer: &HashMap<PublicKey, VecDeque<OnionMessage>>) -> bool {
+fn outbound_buffer_full(peer_node_id: &PublicKey, buffer: &HashMap<PublicKey, OnionMessageBuffer>) -> bool {
        const MAX_TOTAL_BUFFER_SIZE: usize = (1 << 20) * 128;
        const MAX_PER_PEER_BUFFER_SIZE: usize = (1 << 10) * 256;
        let mut total_buffered_bytes = 0;
        let mut peer_buffered_bytes = 0;
        for (pk, peer_buf) in buffer {
-               for om in peer_buf {
+               for om in peer_buf.pending_messages() {
                        let om_len = om.serialized_length();
                        if pk == peer_node_id {
                                peer_buffered_bytes += om_len;
@@ -660,24 +718,28 @@ where
                                }
                        },
                        Ok(PeeledOnion::Forward(next_node_id, onion_message)) => {
-                               let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
-                               if outbound_buffer_full(&next_node_id, &pending_per_peer_msgs) {
+                               let mut message_buffers = self.message_buffers.lock().unwrap();
+                               if outbound_buffer_full(&next_node_id, &message_buffers) {
                                        log_trace!(self.logger, "Dropping forwarded onion message to peer {:?}: outbound buffer full", next_node_id);
                                        return
                                }
 
                                #[cfg(fuzzing)]
-                               pending_per_peer_msgs.entry(next_node_id).or_insert_with(VecDeque::new);
-
-                               match pending_per_peer_msgs.entry(next_node_id) {
-                                       hash_map::Entry::Vacant(_) => {
+                               message_buffers
+                                       .entry(next_node_id)
+                                       .or_insert_with(|| OnionMessageBuffer::ConnectedPeer(VecDeque::new()));
+
+                               match message_buffers.entry(next_node_id) {
+                                       hash_map::Entry::Occupied(mut e) if matches!(
+                                               e.get(), OnionMessageBuffer::ConnectedPeer(..)
+                                       ) => {
+                                               e.get_mut().enqueue_message(onion_message);
+                                               log_trace!(self.logger, "Forwarding an onion message to peer {}", next_node_id);
+                                       },
+                                       _ => {
                                                log_trace!(self.logger, "Dropping forwarded onion message to disconnected peer {:?}", next_node_id);
                                                return
                                        },
-                                       hash_map::Entry::Occupied(mut e) => {
-                                               e.get_mut().push_back(onion_message);
-                                               log_trace!(self.logger, "Forwarding an onion message to peer {}", next_node_id);
-                                       }
                                }
                        },
                        Err(e) => {
@@ -688,15 +750,22 @@ where
 
        fn peer_connected(&self, their_node_id: &PublicKey, init: &msgs::Init, _inbound: bool) -> Result<(), ()> {
                if init.features.supports_onion_messages() {
-                       let mut peers = self.pending_messages.lock().unwrap();
-                       peers.insert(their_node_id.clone(), VecDeque::new());
+                       self.message_buffers.lock().unwrap()
+                               .entry(*their_node_id)
+                               .or_insert_with(|| OnionMessageBuffer::ConnectedPeer(VecDeque::new()))
+                               .mark_connected();
+               } else {
+                       self.message_buffers.lock().unwrap().remove(their_node_id);
                }
+
                Ok(())
        }
 
        fn peer_disconnected(&self, their_node_id: &PublicKey) {
-               let mut pending_msgs = self.pending_messages.lock().unwrap();
-               pending_msgs.remove(their_node_id);
+               match self.message_buffers.lock().unwrap().remove(their_node_id) {
+                       Some(OnionMessageBuffer::ConnectedPeer(..)) => {},
+                       _ => debug_assert!(false),
+               }
        }
 
        fn provided_node_features(&self) -> NodeFeatures {
@@ -737,11 +806,9 @@ where
                        );
                }
 
-               let mut pending_msgs = self.pending_messages.lock().unwrap();
-               if let Some(msgs) = pending_msgs.get_mut(&peer_node_id) {
-                       return msgs.pop_front()
-               }
-               None
+               self.message_buffers.lock().unwrap()
+                       .get_mut(&peer_node_id)
+                       .and_then(|buffer| buffer.dequeue_message())
        }
 }