[net-tokio] Explicitly yield after processing messages from a peer
[rust-lightning] / lightning-net-tokio / src / lib.rs
index 25c161d2aeb07f8ba5305956e93c9299aca303f1..3cfed870b31f59ae1b36ffbf7a1987eee2b0df00 100644 (file)
@@ -34,7 +34,7 @@
 //! type Logger = dyn lightning::util::logger::Logger + Send + Sync;
 //! type ChainAccess = dyn lightning::chain::Access + Send + Sync;
 //! type ChainFilter = dyn lightning::chain::Filter + Send + Sync;
-//! type DataPersister = dyn lightning::chain::channelmonitor::Persist<lightning::chain::keysinterface::InMemorySigner> + Send + Sync;
+//! type DataPersister = dyn lightning::chain::chainmonitor::Persist<lightning::chain::keysinterface::InMemorySigner> + Send + Sync;
 //! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor<lightning::chain::keysinterface::InMemorySigner, Arc<ChainFilter>, Arc<TxBroadcaster>, Arc<FeeEstimator>, Arc<Logger>, Arc<DataPersister>>;
 //! type ChannelManager = Arc<lightning::ln::channelmanager::SimpleArcChannelManager<ChainMonitor, TxBroadcaster, FeeEstimator, Logger>>;
 //! type PeerManager = Arc<lightning::ln::peer_handler::SimpleArcPeerManager<lightning_net_tokio::SocketDescriptor, ChainMonitor, TxBroadcaster, FeeEstimator, ChainAccess, Logger>>;
 //! async fn connect_to_node(peer_manager: PeerManager, chain_monitor: Arc<ChainMonitor>, channel_manager: ChannelManager, their_node_id: PublicKey, addr: SocketAddr) {
 //!    lightning_net_tokio::connect_outbound(peer_manager, their_node_id, addr).await;
 //!    loop {
-//!            channel_manager.await_persistable_update();
-//!            channel_manager.process_pending_events(&|event| {
-//!                    // Handle the event!
-//!            });
-//!            chain_monitor.process_pending_events(&|event| {
+//!            let event_handler = |event: &Event| {
 //!                    // Handle the event!
-//!            });
+//!            };
+//!            channel_manager.await_persistable_update();
+//!            channel_manager.process_pending_events(&event_handler);
+//!            chain_monitor.process_pending_events(&event_handler);
 //!    }
 //! }
 //!
 //! async fn accept_socket(peer_manager: PeerManager, chain_monitor: Arc<ChainMonitor>, channel_manager: ChannelManager, socket: TcpStream) {
 //!    lightning_net_tokio::setup_inbound(peer_manager, socket);
 //!    loop {
-//!            channel_manager.await_persistable_update();
-//!            channel_manager.process_pending_events(&|event| {
+//!            let event_handler = |event: &Event| {
 //!                    // Handle the event!
-//!            });
-//!            chain_monitor.process_pending_events(&|event| {
-//!                    // Handle the event!
-//!            });
+//!            };
+//!            channel_manager.await_persistable_update();
+//!            channel_manager.process_pending_events(&event_handler);
+//!            chain_monitor.process_pending_events(&event_handler);
 //!    }
 //! }
 //! ```
@@ -71,6 +69,8 @@
 #![deny(broken_intra_doc_links)]
 #![deny(missing_docs)]
 
+#![cfg_attr(docsrs, feature(doc_auto_cfg))]
+
 use bitcoin::secp256k1::key::PublicKey;
 
 use tokio::net::TcpStream;
@@ -81,10 +81,11 @@ use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use lightning::ln::peer_handler;
 use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;
 use lightning::ln::peer_handler::CustomMessageHandler;
-use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
+use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler, NetAddress};
 use lightning::util::logger::Logger;
 
 use std::task;
+use std::net::IpAddr;
 use std::net::SocketAddr;
 use std::net::TcpStream as StdTcpStream;
 use std::sync::{Arc, Mutex};
@@ -120,11 +121,28 @@ struct Connection {
        id: u64,
 }
 impl Connection {
+       async fn poll_event_process<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, mut event_receiver: mpsc::Receiver<()>) where
+                       CMH: ChannelMessageHandler + 'static + Send + Sync,
+                       RMH: RoutingMessageHandler + 'static + Send + Sync,
+                       L: Logger + 'static + ?Sized + Send + Sync,
+                       UMH: CustomMessageHandler + 'static + Send + Sync {
+               loop {
+                       if event_receiver.recv().await.is_none() {
+                               return;
+                       }
+                       peer_manager.process_events();
+               }
+       }
+
        async fn schedule_read<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) where
-                       CMH: ChannelMessageHandler + 'static,
-                       RMH: RoutingMessageHandler + 'static,
-                       L: Logger + 'static + ?Sized,
-                       UMH: CustomMessageHandler + 'static {
+                       CMH: ChannelMessageHandler + 'static + Send + Sync,
+                       RMH: RoutingMessageHandler + 'static + Send + Sync,
+                       L: Logger + 'static + ?Sized + Send + Sync,
+                       UMH: CustomMessageHandler + 'static + Send + Sync {
+               // Create a waker to wake up poll_event_process, above
+               let (event_waker, event_receiver) = mpsc::channel(1);
+               tokio::spawn(Self::poll_event_process(Arc::clone(&peer_manager), event_receiver));
+
                // 8KB is nice and big but also should never cause any issues with stack overflowing.
                let mut buf = [0; 8192];
 
@@ -175,7 +193,14 @@ impl Connection {
                                        Err(_) => break Disconnect::PeerDisconnected,
                                },
                        }
-                       peer_manager.process_events();
+                       let _ = event_waker.try_send(());
+
+                       // At this point we've processed a message or two, and reset the ping timer for this
+                       // peer, at least in the "are we still receiving messages" context, if we don't give up
+                       // our timeslice to another task we may just spin on this peer, starving other peers
+                       // and eventually disconnecting them for ping timeouts. Instead, we explicitly yield
+                       // here.
+                       tokio::task::yield_now().await;
                };
                let writer_option = us.lock().unwrap().writer.take();
                if let Some(mut writer) = writer_option {
@@ -222,11 +247,21 @@ pub fn setup_inbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManag
                RMH: RoutingMessageHandler + 'static + Send + Sync,
                L: Logger + 'static + ?Sized + Send + Sync,
                UMH: CustomMessageHandler + 'static + Send + Sync {
+       let ip_addr = stream.peer_addr().unwrap();
        let (reader, write_receiver, read_receiver, us) = Connection::new(stream);
        #[cfg(debug_assertions)]
        let last_us = Arc::clone(&us);
 
-       let handle_opt = if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(us.clone())) {
+       let handle_opt = if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(us.clone()), match ip_addr.ip() {
+               IpAddr::V4(ip) => Some(NetAddress::IPv4 {
+                       addr: ip.octets(),
+                       port: ip_addr.port(),
+               }),
+               IpAddr::V6(ip) => Some(NetAddress::IPv6 {
+                       addr: ip.octets(),
+                       port: ip_addr.port(),
+               }),
+       }) {
                Some(tokio::spawn(Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver)))
        } else {
                // Note that we will skip socket_disconnected here, in accordance with the PeerManager
@@ -263,11 +298,20 @@ pub fn setup_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerMana
                RMH: RoutingMessageHandler + 'static + Send + Sync,
                L: Logger + 'static + ?Sized + Send + Sync,
                UMH: CustomMessageHandler + 'static + Send + Sync {
+       let ip_addr = stream.peer_addr().unwrap();
        let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream);
        #[cfg(debug_assertions)]
        let last_us = Arc::clone(&us);
-
-       let handle_opt = if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone())) {
+       let handle_opt = if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone()), match ip_addr.ip() {
+               IpAddr::V4(ip) => Some(NetAddress::IPv4 {
+                       addr: ip.octets(),
+                       port: ip_addr.port(),
+               }),
+               IpAddr::V6(ip) => Some(NetAddress::IPv6 {
+                       addr: ip.octets(),
+                       port: ip_addr.port(),
+               }),
+       }) {
                Some(tokio::spawn(async move {
                        // We should essentially always have enough room in a TCP socket buffer to send the
                        // initial 10s of bytes. However, tokio running in single-threaded mode will always
@@ -428,6 +472,9 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
                                        // pause read given we're now waiting on the remote end to ACK (and in
                                        // accordance with the send_data() docs).
                                        us.read_paused = true;
+                                       // Further, to avoid any current pending read causing a `read_event` call, wake
+                                       // up the read_waker and restart its loop.
+                                       let _ = us.read_waker.try_send(());
                                        return written_len;
                                },
                        }
@@ -494,10 +541,9 @@ mod tests {
                fn handle_node_announcement(&self, _msg: &NodeAnnouncement) -> Result<bool, LightningError> { Ok(false) }
                fn handle_channel_announcement(&self, _msg: &ChannelAnnouncement) -> Result<bool, LightningError> { Ok(false) }
                fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result<bool, LightningError> { Ok(false) }
-               fn handle_htlc_fail_channel_update(&self, _update: &HTLCFailChannelUpdate) { }
                fn get_next_channel_announcements(&self, _starting_point: u64, _batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> { Vec::new() }
                fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec<NodeAnnouncement> { Vec::new() }
-               fn sync_routing_table(&self, _their_node_id: &PublicKey, _init_msg: &Init) { }
+               fn peer_connected(&self, _their_node_id: &PublicKey, _init_msg: &Init) { }
                fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: ReplyChannelRange) -> Result<(), LightningError> { Ok(()) }
                fn handle_reply_short_channel_ids_end(&self, _their_node_id: &PublicKey, _msg: ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) }
                fn handle_query_channel_range(&self, _their_node_id: &PublicKey, _msg: QueryChannelRange) -> Result<(), LightningError> { Ok(()) }