Add a parallel async event handler to `OnionMessenger`
[rust-lightning] / lightning / src / onion_message / messenger.rs
index 1d7a730fa3625126097fd6d25de150e5ba46c742..7bea65ab5af53eb3d9f7e7014454393d892dae86 100644 (file)
@@ -15,8 +15,8 @@ use bitcoin::hashes::hmac::{Hmac, HmacEngine};
 use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey};
 
-use crate::blinded_path::{BlindedPath, IntroductionNode, NodeIdLookUp};
-use crate::blinded_path::message::{advance_path_by_one, ForwardTlvs, NextHop, ReceiveTlvs};
+use crate::blinded_path::{BlindedPath, IntroductionNode, NextMessageHop, NodeIdLookUp};
+use crate::blinded_path::message::{advance_path_by_one, ForwardTlvs, ReceiveTlvs};
 use crate::blinded_path::utils;
 use crate::events::{Event, EventHandler, EventsProvider};
 use crate::sign::{EntropySource, NodeSigner, Recipient};
@@ -47,6 +47,70 @@ use {
 
 pub(super) const MAX_TIMER_TICKS: usize = 2;
 
+/// A trivial trait which describes any [`OnionMessenger`].
+///
+/// This is not exported to bindings users as general cover traits aren't useful in other
+/// languages.
+pub trait AOnionMessenger {
+       /// A type implementing [`EntropySource`]
+       type EntropySource: EntropySource + ?Sized;
+       /// A type that may be dereferenced to [`Self::EntropySource`]
+       type ES: Deref<Target = Self::EntropySource>;
+       /// A type implementing [`NodeSigner`]
+       type NodeSigner: NodeSigner + ?Sized;
+       /// A type that may be dereferenced to [`Self::NodeSigner`]
+       type NS: Deref<Target = Self::NodeSigner>;
+       /// A type implementing [`Logger`]
+       type Logger: Logger + ?Sized;
+       /// A type that may be dereferenced to [`Self::Logger`]
+       type L: Deref<Target = Self::Logger>;
+       /// A type implementing [`NodeIdLookUp`]
+       type NodeIdLookUp: NodeIdLookUp + ?Sized;
+       /// A type that may be dereferenced to [`Self::NodeIdLookUp`]
+       type NL: Deref<Target = Self::NodeIdLookUp>;
+       /// A type implementing [`MessageRouter`]
+       type MessageRouter: MessageRouter + ?Sized;
+       /// A type that may be dereferenced to [`Self::MessageRouter`]
+       type MR: Deref<Target = Self::MessageRouter>;
+       /// A type implementing [`OffersMessageHandler`]
+       type OffersMessageHandler: OffersMessageHandler + ?Sized;
+       /// A type that may be dereferenced to [`Self::OffersMessageHandler`]
+       type OMH: Deref<Target = Self::OffersMessageHandler>;
+       /// A type implementing [`CustomOnionMessageHandler`]
+       type CustomOnionMessageHandler: CustomOnionMessageHandler + ?Sized;
+       /// A type that may be dereferenced to [`Self::CustomOnionMessageHandler`]
+       type CMH: Deref<Target = Self::CustomOnionMessageHandler>;
+       /// Returns a reference to the actual [`OnionMessenger`] object.
+       fn get_om(&self) -> &OnionMessenger<Self::ES, Self::NS, Self::L, Self::NL, Self::MR, Self::OMH, Self::CMH>;
+}
+
+impl<ES: Deref, NS: Deref, L: Deref, NL: Deref, MR: Deref, OMH: Deref, CMH: Deref> AOnionMessenger
+for OnionMessenger<ES, NS, L, NL, MR, OMH, CMH> where
+       ES::Target: EntropySource,
+       NS::Target: NodeSigner,
+       L::Target: Logger,
+       NL::Target: NodeIdLookUp,
+       MR::Target: MessageRouter,
+       OMH::Target: OffersMessageHandler,
+       CMH::Target: CustomOnionMessageHandler,
+{
+       type EntropySource = ES::Target;
+       type ES = ES;
+       type NodeSigner = NS::Target;
+       type NS = NS;
+       type Logger = L::Target;
+       type L = L;
+       type NodeIdLookUp = NL::Target;
+       type NL = NL;
+       type MessageRouter = MR::Target;
+       type MR = MR;
+       type OffersMessageHandler = OMH::Target;
+       type OMH = OMH;
+       type CustomOnionMessageHandler = CMH::Target;
+       type CMH = CMH;
+       fn get_om(&self) -> &OnionMessenger<ES, NS, L, NL, MR, OMH, CMH> { self }
+}
+
 /// A sender, receiver and forwarder of [`OnionMessage`]s.
 ///
 /// # Handling Messages
@@ -135,6 +199,7 @@ pub(super) const MAX_TIMER_TICKS: usize = 2;
 ///            # let your_custom_message_type = 42;
 ///            your_custom_message_type
 ///    }
+///    fn msg_type(&self) -> &'static str { "YourCustomMessageType" }
 /// }
 /// // Send a custom onion message to a node id.
 /// let destination = Destination::Node(destination_node_id);
@@ -175,6 +240,13 @@ where
        message_router: MR,
        offers_handler: OMH,
        custom_handler: CMH,
+       intercept_messages_for_offline_peers: bool,
+       pending_events: Mutex<PendingEvents>,
+}
+
+struct PendingEvents {
+       intercepted_msgs: Vec<Event>,
+       peer_connecteds: Vec<Event>,
 }
 
 /// [`OnionMessage`]s buffered to be sent.
@@ -246,6 +318,50 @@ impl OnionMessageRecipient {
        }
 }
 
+
+/// The `Responder` struct creates an appropriate [`ResponseInstruction`]
+/// for responding to a message.
+pub struct Responder {
+       /// The path along which a response can be sent.
+       reply_path: BlindedPath,
+       path_id: Option<[u8; 32]>
+}
+
+impl Responder {
+       /// Creates a new [`Responder`] instance with the provided reply path.
+       fn new(reply_path: BlindedPath, path_id: Option<[u8; 32]>) -> Self {
+               Responder {
+                       reply_path,
+                       path_id,
+               }
+       }
+
+       /// Creates the appropriate [`ResponseInstruction`] for a given response.
+       pub fn respond<T: OnionMessageContents>(self, response: T) -> ResponseInstruction<T> {
+               ResponseInstruction::WithoutReplyPath(OnionMessageResponse {
+                       message: response,
+                       reply_path: self.reply_path,
+                       path_id: self.path_id,
+               })
+       }
+}
+
+/// This struct contains the information needed to reply to a received message.
+pub struct OnionMessageResponse<T: OnionMessageContents> {
+       message: T,
+       reply_path: BlindedPath,
+       path_id: Option<[u8; 32]>,
+}
+
+/// `ResponseInstruction` represents instructions for responding to received messages.
+pub enum ResponseInstruction<T: OnionMessageContents> {
+       /// Indicates that a response should be sent without including a reply path
+       /// for the recipient to respond back.
+       WithoutReplyPath(OnionMessageResponse<T>),
+       /// Indicates that there's no response to send back.
+       NoResponse,
+}
+
 /// An [`OnionMessage`] for [`OnionMessenger`] to send.
 ///
 /// These are obtained when released from [`OnionMessenger`]'s handlers after which they are
@@ -546,7 +662,7 @@ pub trait CustomOnionMessageHandler {
        /// Called with the custom message that was received, returning a response to send, if any.
        ///
        /// The returned [`Self::CustomMessage`], if any, is enqueued to be sent by [`OnionMessenger`].
-       fn handle_custom_message(&self, msg: Self::CustomMessage) -> Option<Self::CustomMessage>;
+       fn handle_custom_message(&self, message: Self::CustomMessage, responder: Option<Responder>) -> ResponseInstruction<Self::CustomMessage>;
 
        /// Read a custom message of type `message_type` from `buffer`, returning `Ok(None)` if the
        /// message type is unknown.
@@ -569,10 +685,10 @@ pub trait CustomOnionMessageHandler {
 
 /// A processed incoming onion message, containing either a Forward (another onion message)
 /// or a Receive payload with decrypted contents.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub enum PeeledOnion<T: OnionMessageContents> {
        /// Forwarded onion, with the next node id and a new onion
-       Forward(NextHop, OnionMessage),
+       Forward(NextMessageHop, OnionMessage),
        /// Received onion message, with decrypted contents, path_id, and reply path
        Receive(ParsedOnionMessageContents<T>, Option<[u8; 32]>, Option<BlindedPath>)
 }
@@ -796,6 +912,48 @@ where
        pub fn new(
                entropy_source: ES, node_signer: NS, logger: L, node_id_lookup: NL, message_router: MR,
                offers_handler: OMH, custom_handler: CMH
+       ) -> Self {
+               Self::new_inner(
+                       entropy_source, node_signer, logger, node_id_lookup, message_router,
+                       offers_handler, custom_handler, false
+               )
+       }
+
+       /// Similar to [`Self::new`], but rather than dropping onion messages that are
+       /// intended to be forwarded to offline peers, we will intercept them for
+       /// later forwarding.
+       ///
+       /// Interception flow:
+       /// 1. If an onion message for an offline peer is received, `OnionMessenger` will
+       ///    generate an [`Event::OnionMessageIntercepted`]. Event handlers can
+       ///    then choose to persist this onion message for later forwarding, or drop
+       ///    it.
+       /// 2. When the offline peer later comes back online, `OnionMessenger` will
+       ///    generate an [`Event::OnionMessagePeerConnected`]. Event handlers will
+       ///    then fetch all previously intercepted onion messages for this peer.
+       /// 3. Once the stored onion messages are fetched, they can finally be
+       ///    forwarded to the now-online peer via [`Self::forward_onion_message`].
+       ///
+       /// # Note
+       ///
+       /// LDK will not rate limit how many [`Event::OnionMessageIntercepted`]s
+       /// are generated, so it is the caller's responsibility to limit how many
+       /// onion messages are persisted and only persist onion messages for relevant
+       /// peers.
+       pub fn new_with_offline_peer_interception(
+               entropy_source: ES, node_signer: NS, logger: L, node_id_lookup: NL,
+               message_router: MR, offers_handler: OMH, custom_handler: CMH
+       ) -> Self {
+               Self::new_inner(
+                       entropy_source, node_signer, logger, node_id_lookup, message_router,
+                       offers_handler, custom_handler, true
+               )
+       }
+
+       fn new_inner(
+               entropy_source: ES, node_signer: NS, logger: L, node_id_lookup: NL,
+               message_router: MR, offers_handler: OMH, custom_handler: CMH,
+               intercept_messages_for_offline_peers: bool
        ) -> Self {
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&entropy_source.get_secure_random_bytes());
@@ -809,6 +967,11 @@ where
                        message_router,
                        offers_handler,
                        custom_handler,
+                       intercept_messages_for_offline_peers,
+                       pending_events: Mutex::new(PendingEvents {
+                               intercepted_msgs: Vec::new(),
+                               peer_connecteds: Vec::new(),
+                       }),
                }
        }
 
@@ -832,13 +995,12 @@ where
                &self, contents: T, destination: Destination, reply_path: Option<BlindedPath>,
                log_suffix: fmt::Arguments
        ) -> Result<SendSuccess, SendError> {
-               let mut logger = WithContext::from(&self.logger, None, None);
-               let result = self.find_path(destination)
-                       .and_then(|path| {
-                               let first_hop = path.intermediate_nodes.get(0).map(|p| *p);
-                               logger = WithContext::from(&self.logger, first_hop, None);
-                               self.enqueue_onion_message(path, contents, reply_path, log_suffix)
-                       });
+               let mut logger = WithContext::from(&self.logger, None, None, None);
+               let result = self.find_path(destination).and_then(|path| {
+                       let first_hop = path.intermediate_nodes.get(0).map(|p| *p);
+                       logger = WithContext::from(&self.logger, first_hop, None, None);
+                       self.enqueue_onion_message(path, contents, reply_path, log_suffix)
+               });
 
                match result.as_ref() {
                        Err(SendError::GetNodeIdFailed) => {
@@ -917,6 +1079,27 @@ where
                }
        }
 
+       /// Forwards an [`OnionMessage`] to `peer_node_id`. Useful if we initialized
+       /// the [`OnionMessenger`] with [`Self::new_with_offline_peer_interception`]
+       /// and want to forward a previously intercepted onion message to a peer that
+       /// has just come online.
+       pub fn forward_onion_message(
+               &self, message: OnionMessage, peer_node_id: &PublicKey
+       ) -> Result<(), SendError> {
+               let mut message_recipients = self.message_recipients.lock().unwrap();
+               if outbound_buffer_full(&peer_node_id, &message_recipients) {
+                       return Err(SendError::BufferFull);
+               }
+
+               match message_recipients.entry(*peer_node_id) {
+                       hash_map::Entry::Occupied(mut e) if e.get().is_connected() => {
+                               e.get_mut().enqueue_message(message);
+                               Ok(())
+                       },
+                       _ => Err(SendError::InvalidFirstHop(*peer_node_id))
+               }
+       }
+
        #[cfg(any(test, feature = "_test_utils"))]
        pub fn send_onion_message_using_path<T: OnionMessageContents>(
                &self, path: OnionMessagePath, contents: T, reply_path: Option<BlindedPath>
@@ -933,19 +1116,18 @@ where
        }
 
        fn handle_onion_message_response<T: OnionMessageContents>(
-               &self, response: Option<T>, reply_path: Option<BlindedPath>, log_suffix: fmt::Arguments
+               &self, response: ResponseInstruction<T>
        ) {
-               if let Some(response) = response {
-                       match reply_path {
-                               Some(reply_path) => {
-                                       let _ = self.find_path_and_enqueue_onion_message(
-                                               response, Destination::BlindedPath(reply_path), None, log_suffix
-                                       );
-                               },
-                               None => {
-                                       log_trace!(self.logger, "Missing reply path {}", log_suffix);
-                               },
-                       }
+               if let ResponseInstruction::WithoutReplyPath(response) = response {
+                       let message_type = response.message.msg_type();
+                       let _ = self.find_path_and_enqueue_onion_message(
+                               response.message, Destination::BlindedPath(response.reply_path), None,
+                               format_args!(
+                                       "when responding with {} to an onion message with path_id {:02x?}",
+                                       message_type,
+                                       response.path_id
+                               )
+                       );
                }
        }
 
@@ -960,6 +1142,63 @@ where
                }
                msgs
        }
+
+       fn enqueue_intercepted_event(&self, event: Event) {
+               const MAX_EVENTS_BUFFER_SIZE: usize = (1 << 10) * 256;
+               let mut pending_events = self.pending_events.lock().unwrap();
+               let total_buffered_bytes: usize =
+                       pending_events.intercepted_msgs.iter().map(|ev| ev.serialized_length()).sum();
+               if total_buffered_bytes >= MAX_EVENTS_BUFFER_SIZE {
+                       log_trace!(self.logger, "Dropping event {:?}: buffer full", event);
+                       return
+               }
+               pending_events.intercepted_msgs.push(event);
+       }
+
+       /// Processes any events asynchronously using the given handler.
+       ///
+       /// Note that the event handler is called in the order each event was generated, however
+       /// futures are polled in parallel for some events to allow for parallelism where events do not
+       /// have an ordering requirement.
+       ///
+       /// See the trait-level documentation of [`EventsProvider`] for requirements.
+       pub async fn process_pending_events_async<Future: core::future::Future<Output = ()> + core::marker::Unpin, H: Fn(Event) -> Future>(
+               &self, handler: H
+       ) {
+               let mut intercepted_msgs = Vec::new();
+               let mut peer_connecteds = Vec::new();
+               {
+                       let mut pending_events = self.pending_events.lock().unwrap();
+                       core::mem::swap(&mut pending_events.intercepted_msgs, &mut intercepted_msgs);
+                       core::mem::swap(&mut pending_events.peer_connecteds, &mut peer_connecteds);
+               }
+
+               let mut futures = Vec::with_capacity(intercepted_msgs.len());
+               for (node_id, recipient) in self.message_recipients.lock().unwrap().iter_mut() {
+                       if let OnionMessageRecipient::PendingConnection(_, addresses, _) = recipient {
+                               if let Some(addresses) = addresses.take() {
+                                       futures.push(Some(handler(Event::ConnectionNeeded { node_id: *node_id, addresses })));
+                               }
+                       }
+               }
+
+               for ev in intercepted_msgs {
+                       if let Event::OnionMessageIntercepted { .. } = ev {} else { debug_assert!(false); }
+                       futures.push(Some(handler(ev)));
+               }
+               // Let the `OnionMessageIntercepted` events finish before moving on to peer_connecteds
+               crate::util::async_poll::MultiFuturePoller(futures).await;
+
+               if peer_connecteds.len() <= 1 {
+                       for event in peer_connecteds { handler(event).await; }
+               } else {
+                       let mut futures = Vec::new();
+                       for event in peer_connecteds {
+                               futures.push(Some(handler(event)));
+                       }
+                       crate::util::async_poll::MultiFuturePoller(futures).await;
+               }
+       }
 }
 
 fn outbound_buffer_full(peer_node_id: &PublicKey, buffer: &HashMap<PublicKey, OnionMessageRecipient>) -> bool {
@@ -1004,6 +1243,24 @@ where
                                }
                        }
                }
+               let mut events = Vec::new();
+               {
+                       let mut pending_events = self.pending_events.lock().unwrap();
+                       #[cfg(debug_assertions)] {
+                               for ev in pending_events.intercepted_msgs.iter() {
+                                       if let Event::OnionMessageIntercepted { .. } = ev {} else { panic!(); }
+                               }
+                               for ev in pending_events.peer_connecteds.iter() {
+                                       if let Event::OnionMessagePeerConnected { .. } = ev {} else { panic!(); }
+                               }
+                       }
+                       core::mem::swap(&mut pending_events.intercepted_msgs, &mut events);
+                       events.append(&mut pending_events.peer_connecteds);
+                       pending_events.peer_connecteds.shrink_to(10); // Limit total heap usage
+               }
+               for ev in events {
+                       handler.handle_event(ev);
+               }
        }
 }
 
@@ -1019,7 +1276,7 @@ where
        CMH::Target: CustomOnionMessageHandler,
 {
        fn handle_onion_message(&self, peer_node_id: &PublicKey, msg: &OnionMessage) {
-               let logger = WithContext::from(&self.logger, Some(*peer_node_id), None);
+               let logger = WithContext::from(&self.logger, Some(*peer_node_id), None, None);
                match self.peel_onion_message(msg) {
                        Ok(PeeledOnion::Receive(message, path_id, reply_path)) => {
                                log_trace!(
@@ -1029,29 +1286,25 @@ where
 
                                match message {
                                        ParsedOnionMessageContents::Offers(msg) => {
-                                               let response = self.offers_handler.handle_message(msg);
-                                               self.handle_onion_message_response(
-                                                       response, reply_path, format_args!(
-                                                               "when responding to Offers onion message with path_id {:02x?}",
-                                                               path_id
-                                                       )
+                                               let responder = reply_path.map(
+                                                       |reply_path| Responder::new(reply_path, path_id)
                                                );
+                                               let response_instructions = self.offers_handler.handle_message(msg, responder);
+                                               self.handle_onion_message_response(response_instructions);
                                        },
                                        ParsedOnionMessageContents::Custom(msg) => {
-                                               let response = self.custom_handler.handle_custom_message(msg);
-                                               self.handle_onion_message_response(
-                                                       response, reply_path, format_args!(
-                                                               "when responding to Custom onion message with path_id {:02x?}",
-                                                               path_id
-                                                       )
+                                               let responder = reply_path.map(
+                                                       |reply_path| Responder::new(reply_path, path_id)
                                                );
+                                               let response_instructions = self.custom_handler.handle_custom_message(msg, responder);
+                                               self.handle_onion_message_response(response_instructions);
                                        },
                                }
                        },
                        Ok(PeeledOnion::Forward(next_hop, onion_message)) => {
                                let next_node_id = match next_hop {
-                                       NextHop::NodeId(pubkey) => pubkey,
-                                       NextHop::ShortChannelId(scid) => match self.node_id_lookup.next_node_id(scid) {
+                                       NextMessageHop::NodeId(pubkey) => pubkey,
+                                       NextMessageHop::ShortChannelId(scid) => match self.node_id_lookup.next_node_id(scid) {
                                                Some(pubkey) => pubkey,
                                                None => {
                                                        log_trace!(self.logger, "Dropping forwarded onion messager: unable to resolve next hop using SCID {}", scid);
@@ -1081,6 +1334,13 @@ where
                                                e.get_mut().enqueue_message(onion_message);
                                                log_trace!(logger, "Forwarding an onion message to peer {}", next_node_id);
                                        },
+                                       _ if self.intercept_messages_for_offline_peers => {
+                                               self.enqueue_intercepted_event(
+                                                       Event::OnionMessageIntercepted {
+                                                               peer_node_id: next_node_id, message: onion_message
+                                                       }
+                                               );
+                                       },
                                        _ => {
                                                log_trace!(
                                                        logger,
@@ -1102,6 +1362,11 @@ where
                                .entry(*their_node_id)
                                .or_insert_with(|| OnionMessageRecipient::ConnectedPeer(VecDeque::new()))
                                .mark_connected();
+                       if self.intercept_messages_for_offline_peers {
+                               self.pending_events.lock().unwrap().peer_connecteds.push(
+                                       Event::OnionMessagePeerConnected { peer_node_id: *their_node_id }
+                               );
+                       }
                } else {
                        self.message_recipients.lock().unwrap().remove(their_node_id);
                }
@@ -1255,7 +1520,7 @@ fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + sec
                                if let Some(ss) = prev_control_tlvs_ss.take() {
                                        payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(
                                                ForwardTlvs {
-                                                       next_hop: NextHop::NodeId(unblinded_pk_opt.unwrap()),
+                                                       next_hop: NextMessageHop::NodeId(unblinded_pk_opt.unwrap()),
                                                        next_blinding_override: None,
                                                }
                                        )), ss));
@@ -1265,7 +1530,7 @@ fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + sec
                        } else if let Some((intro_node_id, blinding_pt)) = intro_node_id_blinding_pt.take() {
                                if let Some(control_tlvs_ss) = prev_control_tlvs_ss.take() {
                                        payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs {
-                                               next_hop: NextHop::NodeId(intro_node_id),
+                                               next_hop: NextMessageHop::NodeId(intro_node_id),
                                                next_blinding_override: Some(blinding_pt),
                                        })), control_tlvs_ss));
                                }