Merge pull request #2249 from TheBlueMatt/2023-04-less-pm-bounds
[rust-lightning] / lightning / src / ln / peer_handler.rs
index 040ccc655f4b93187ea52eefce46ecbe166aa8d9..25ac234a847c5959f5414c74755d975b56a173f3 100644 (file)
@@ -259,10 +259,11 @@ impl Deref for ErroringMessageHandler {
 }
 
 /// Provides references to trait impls which handle different types of messages.
-pub struct MessageHandler<CM: Deref, RM: Deref, OM: Deref> where
-               CM::Target: ChannelMessageHandler,
-               RM::Target: RoutingMessageHandler,
-               OM::Target: OnionMessageHandler,
+pub struct MessageHandler<CM: Deref, RM: Deref, OM: Deref, CustomM: Deref> where
+       CM::Target: ChannelMessageHandler,
+       RM::Target: RoutingMessageHandler,
+       OM::Target: OnionMessageHandler,
+       CustomM::Target: CustomMessageHandler,
 {
        /// A message handler which handles messages specific to channels. Usually this is just a
        /// [`ChannelManager`] object or an [`ErroringMessageHandler`].
@@ -275,9 +276,15 @@ pub struct MessageHandler<CM: Deref, RM: Deref, OM: Deref> where
        /// [`P2PGossipSync`]: crate::routing::gossip::P2PGossipSync
        pub route_handler: RM,
 
-       /// A message handler which handles onion messages. For now, this can only be an
-       /// [`IgnoringMessageHandler`].
+       /// A message handler which handles onion messages. This should generally be an
+       /// [`OnionMessenger`], but can also be an [`IgnoringMessageHandler`].
+       ///
+       /// [`OnionMessenger`]: crate::onion_message::OnionMessenger
        pub onion_message_handler: OM,
+
+       /// A message handler which handles custom messages. The only LDK-provided implementation is
+       /// [`IgnoringMessageHandler`].
+       pub custom_message_handler: CustomM,
 }
 
 /// Provides an object which can be used to send data to and which uniquely identifies a connection
@@ -535,6 +542,54 @@ pub type SimpleArcPeerManager<SD, M, T, F, C, L> = PeerManager<SD, Arc<SimpleArc
 /// This is not exported to bindings users as general type aliases don't make sense in bindings.
 pub type SimpleRefPeerManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j, 'k, 'l, 'm, SD, M, T, F, C, L> = PeerManager<SD, SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'm, M, T, F, L>, &'f P2PGossipSync<&'g NetworkGraph<&'f L>, &'h C, &'f L>, &'i SimpleRefOnionMessenger<'j, 'k, L>, &'f L, IgnoringMessageHandler, &'c KeysManager>;
 
+
+/// A generic trait which is implemented for all [`PeerManager`]s. This makes bounding functions or
+/// structs on any [`PeerManager`] much simpler as only this trait is needed as a bound, rather
+/// than the full set of bounds on [`PeerManager`] itself.
+#[allow(missing_docs)]
+pub trait APeerManager {
+       type Descriptor: SocketDescriptor;
+       type CMT: ChannelMessageHandler + ?Sized;
+       type CM: Deref<Target=Self::CMT>;
+       type RMT: RoutingMessageHandler + ?Sized;
+       type RM: Deref<Target=Self::RMT>;
+       type OMT: OnionMessageHandler + ?Sized;
+       type OM: Deref<Target=Self::OMT>;
+       type LT: Logger + ?Sized;
+       type L: Deref<Target=Self::LT>;
+       type CMHT: CustomMessageHandler + ?Sized;
+       type CMH: Deref<Target=Self::CMHT>;
+       type NST: NodeSigner + ?Sized;
+       type NS: Deref<Target=Self::NST>;
+       /// Gets a reference to the underlying [`PeerManager`].
+       fn as_ref(&self) -> &PeerManager<Self::Descriptor, Self::CM, Self::RM, Self::OM, Self::L, Self::CMH, Self::NS>;
+}
+
+impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CMH: Deref, NS: Deref>
+APeerManager for PeerManager<Descriptor, CM, RM, OM, L, CMH, NS> where
+       CM::Target: ChannelMessageHandler,
+       RM::Target: RoutingMessageHandler,
+       OM::Target: OnionMessageHandler,
+       L::Target: Logger,
+       CMH::Target: CustomMessageHandler,
+       NS::Target: NodeSigner,
+{
+       type Descriptor = Descriptor;
+       type CMT = <CM as Deref>::Target;
+       type CM = CM;
+       type RMT = <RM as Deref>::Target;
+       type RM = RM;
+       type OMT = <OM as Deref>::Target;
+       type OM = OM;
+       type LT = <L as Deref>::Target;
+       type L = L;
+       type CMHT = <CMH as Deref>::Target;
+       type CMH = CMH;
+       type NST = <NS as Deref>::Target;
+       type NS = NS;
+       fn as_ref(&self) -> &PeerManager<Descriptor, CM, RM, OM, L, CMH, NS> { self }
+}
+
 /// A PeerManager manages a set of peers, described by their [`SocketDescriptor`] and marshalls
 /// socket events into messages which it passes on to its [`MessageHandler`].
 ///
@@ -561,7 +616,7 @@ pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: D
                L::Target: Logger,
                CMH::Target: CustomMessageHandler,
                NS::Target: NodeSigner {
-       message_handler: MessageHandler<CM, RM, OM>,
+       message_handler: MessageHandler<CM, RM, OM, CMH>,
        /// Connection state for each connected peer - we have an outer read-write lock which is taken
        /// as read while we're doing processing for a peer and taken write when a peer is being added
        /// or removed.
@@ -591,7 +646,6 @@ pub struct PeerManager<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: D
        last_node_announcement_serial: AtomicU32,
 
        ephemeral_key_midstate: Sha256Engine,
-       custom_message_handler: CMH,
 
        peer_counter: AtomicCounter,
 
@@ -652,7 +706,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, OM: Deref, L: Deref, NS: Deref> Pe
                        chan_handler: channel_message_handler,
                        route_handler: IgnoringMessageHandler{},
                        onion_message_handler,
-               }, current_time, ephemeral_random_data, logger, IgnoringMessageHandler{}, node_signer)
+                       custom_message_handler: IgnoringMessageHandler{},
+               }, current_time, ephemeral_random_data, logger, node_signer)
        }
 }
 
@@ -679,7 +734,8 @@ impl<Descriptor: SocketDescriptor, RM: Deref, L: Deref, NS: Deref> PeerManager<D
                        chan_handler: ErroringMessageHandler::new(),
                        route_handler: routing_message_handler,
                        onion_message_handler: IgnoringMessageHandler{},
-               }, current_time, ephemeral_random_data, logger, IgnoringMessageHandler{}, node_signer)
+                       custom_message_handler: IgnoringMessageHandler{},
+               }, current_time, ephemeral_random_data, logger, node_signer)
        }
 }
 
@@ -741,7 +797,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
        /// incremented irregularly internally. In general it is best to simply use the current UNIX
        /// timestamp, however if it is not available a persistent counter that increases once per
        /// minute should suffice.
-       pub fn new(message_handler: MessageHandler<CM, RM, OM>, current_time: u32, ephemeral_random_data: &[u8; 32], logger: L, custom_message_handler: CMH, node_signer: NS) -> Self {
+       pub fn new(message_handler: MessageHandler<CM, RM, OM, CMH>, current_time: u32, ephemeral_random_data: &[u8; 32], logger: L, node_signer: NS) -> Self {
                let mut ephemeral_key_midstate = Sha256::engine();
                ephemeral_key_midstate.input(ephemeral_random_data);
 
@@ -761,7 +817,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        gossip_processing_backlog_lifted: AtomicBool::new(false),
                        last_node_announcement_serial: AtomicU32::new(current_time),
                        logger,
-                       custom_message_handler,
                        node_signer,
                        secp_ctx,
                }
@@ -1232,7 +1287,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                        peer.pending_read_is_header = true;
 
                                                                        let mut reader = io::Cursor::new(&msg_data[..]);
-                                                                       let message_result = wire::read(&mut reader, &*self.custom_message_handler);
+                                                                       let message_result = wire::read(&mut reader, &*self.message_handler.custom_message_handler);
                                                                        let message = match message_result {
                                                                                Ok(x) => x,
                                                                                Err(e) => {
@@ -1543,7 +1598,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                log_trace!(self.logger, "Received unknown odd message of type {}, ignoring", type_id);
                        },
                        wire::Message::Custom(custom) => {
-                               self.custom_message_handler.handle_custom_message(custom, &their_node_id)?;
+                               self.message_handler.custom_message_handler.handle_custom_message(custom, &their_node_id)?;
                        },
                };
                Ok(should_forward)
@@ -1896,7 +1951,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                }
                        }
 
-                       for (node_id, msg) in self.custom_message_handler.get_and_clear_pending_msg() {
+                       for (node_id, msg) in self.message_handler.custom_message_handler.get_and_clear_pending_msg() {
                                if peers_to_disconnect.get(&node_id).is_some() { continue; }
                                self.enqueue_message(&mut *get_peer_for_forwarding!(&node_id), &msg);
                        }
@@ -2264,8 +2319,11 @@ mod tests {
                let mut peers = Vec::new();
                for i in 0..peer_count {
                        let ephemeral_bytes = [i as u8; 32];
-                       let msg_handler = MessageHandler { chan_handler: &cfgs[i].chan_handler, route_handler: &cfgs[i].routing_handler, onion_message_handler: IgnoringMessageHandler {} };
-                       let peer = PeerManager::new(msg_handler, 0, &ephemeral_bytes, &cfgs[i].logger, IgnoringMessageHandler {}, &cfgs[i].node_signer);
+                       let msg_handler = MessageHandler {
+                               chan_handler: &cfgs[i].chan_handler, route_handler: &cfgs[i].routing_handler,
+                               onion_message_handler: IgnoringMessageHandler {}, custom_message_handler: IgnoringMessageHandler {}
+                       };
+                       let peer = PeerManager::new(msg_handler, 0, &ephemeral_bytes, &cfgs[i].logger, &cfgs[i].node_signer);
                        peers.push(peer);
                }