Merge pull request #1666 from TheBlueMatt/2022-08-fix-script-check
[rust-lightning] / lightning-net-tokio / src / lib.rs
index 22df5c3e8bd4d4e8a88d863d751e95725dfdd211..ac9d4bb3bd5899a08e0dfaadebd811fb779e2c04 100644 (file)
@@ -23,7 +23,7 @@
 //! # Example
 //! ```
 //! use std::net::TcpStream;
-//! use bitcoin::secp256k1::key::PublicKey;
+//! use bitcoin::secp256k1::PublicKey;
 //! use lightning::util::events::{Event, EventHandler, EventsProvider};
 //! use std::net::SocketAddr;
 //! use std::sync::Arc;
 //! }
 //! ```
 
+// Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings.
 #![deny(broken_intra_doc_links)]
-#![deny(missing_docs)]
+#![deny(private_intra_doc_links)]
 
+#![deny(missing_docs)]
 #![cfg_attr(docsrs, feature(doc_auto_cfg))]
 
-use bitcoin::secp256k1::key::PublicKey;
+use bitcoin::secp256k1::PublicKey;
 
 use tokio::net::TcpStream;
 use tokio::{io, time};
@@ -84,6 +86,7 @@ use lightning::ln::peer_handler::CustomMessageHandler;
 use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler, NetAddress};
 use lightning::util::logger::Logger;
 
+use std::ops::Deref;
 use std::task;
 use std::net::SocketAddr;
 use std::net::TcpStream as StdTcpStream;
@@ -120,11 +123,38 @@ struct Connection {
        id: u64,
 }
 impl Connection {
-       async fn schedule_read<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) where
-                       CMH: ChannelMessageHandler + 'static,
-                       RMH: RoutingMessageHandler + 'static,
-                       L: Logger + 'static + ?Sized,
-                       UMH: CustomMessageHandler + 'static {
+       async fn poll_event_process<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, L, UMH>>, mut event_receiver: mpsc::Receiver<()>) where
+                       CMH: Deref + 'static + Send + Sync,
+                       RMH: Deref + 'static + Send + Sync,
+                       L: Deref + 'static + Send + Sync,
+                       UMH: Deref + 'static + Send + Sync,
+                       CMH::Target: ChannelMessageHandler + Send + Sync,
+                       RMH::Target: RoutingMessageHandler + Send + Sync,
+                       L::Target: Logger + Send + Sync,
+                       UMH::Target: CustomMessageHandler + Send + Sync,
+    {
+               loop {
+                       if event_receiver.recv().await.is_none() {
+                               return;
+                       }
+                       peer_manager.process_events();
+               }
+       }
+
+       async fn schedule_read<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, L, UMH>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) where
+                       CMH: Deref + 'static + Send + Sync,
+                       RMH: Deref + 'static + Send + Sync,
+                       L: Deref + 'static + Send + Sync,
+                       UMH: Deref + 'static + Send + Sync,
+                       CMH::Target: ChannelMessageHandler + 'static + Send + Sync,
+                       RMH::Target: RoutingMessageHandler + 'static + Send + Sync,
+                       L::Target: Logger + 'static + Send + Sync,
+                       UMH::Target: CustomMessageHandler + 'static + Send + Sync,
+        {
+               // Create a waker to wake up poll_event_process, above
+               let (event_waker, event_receiver) = mpsc::channel(1);
+               tokio::spawn(Self::poll_event_process(Arc::clone(&peer_manager), event_receiver));
+
                // 8KB is nice and big but also should never cause any issues with stack overflowing.
                let mut buf = [0; 8192];
 
@@ -175,7 +205,14 @@ impl Connection {
                                        Err(_) => break Disconnect::PeerDisconnected,
                                },
                        }
-                       peer_manager.process_events();
+                       let _ = event_waker.try_send(());
+
+                       // At this point we've processed a message or two, and reset the ping timer for this
+                       // peer, at least in the "are we still receiving messages" context, if we don't give up
+                       // our timeslice to another task we may just spin on this peer, starving other peers
+                       // and eventually disconnecting them for ping timeouts. Instead, we explicitly yield
+                       // here.
+                       tokio::task::yield_now().await;
                };
                let writer_option = us.lock().unwrap().writer.take();
                if let Some(mut writer) = writer_option {
@@ -231,11 +268,16 @@ fn get_addr_from_stream(stream: &StdTcpStream) -> Option<NetAddress> {
 /// The returned future will complete when the peer is disconnected and associated handling
 /// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
 /// not need to poll the provided future in order to make progress.
-pub fn setup_inbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
-               CMH: ChannelMessageHandler + 'static + Send + Sync,
-               RMH: RoutingMessageHandler + 'static + Send + Sync,
-               L: Logger + 'static + ?Sized + Send + Sync,
-               UMH: CustomMessageHandler + 'static + Send + Sync {
+pub fn setup_inbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, L, UMH>>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
+               CMH: Deref + 'static + Send + Sync,
+               RMH: Deref + 'static + Send + Sync,
+               L: Deref + 'static + Send + Sync,
+               UMH: Deref + 'static + Send + Sync,
+               CMH::Target: ChannelMessageHandler + Send + Sync,
+               RMH::Target: RoutingMessageHandler + Send + Sync,
+               L::Target: Logger + Send + Sync,
+               UMH::Target: CustomMessageHandler + Send + Sync,
+{
        let remote_addr = get_addr_from_stream(&stream);
        let (reader, write_receiver, read_receiver, us) = Connection::new(stream);
        #[cfg(debug_assertions)]
@@ -273,11 +315,16 @@ pub fn setup_inbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManag
 /// The returned future will complete when the peer is disconnected and associated handling
 /// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
 /// not need to poll the provided future in order to make progress.
-pub fn setup_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
-               CMH: ChannelMessageHandler + 'static + Send + Sync,
-               RMH: RoutingMessageHandler + 'static + Send + Sync,
-               L: Logger + 'static + ?Sized + Send + Sync,
-               UMH: CustomMessageHandler + 'static + Send + Sync {
+pub fn setup_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, L, UMH>>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
+               CMH: Deref + 'static + Send + Sync,
+               RMH: Deref + 'static + Send + Sync,
+               L: Deref + 'static + Send + Sync,
+               UMH: Deref + 'static + Send + Sync,
+               CMH::Target: ChannelMessageHandler + Send + Sync,
+               RMH::Target: RoutingMessageHandler + Send + Sync,
+               L::Target: Logger + Send + Sync,
+               UMH::Target: CustomMessageHandler + Send + Sync,
+{
        let remote_addr = get_addr_from_stream(&stream);
        let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream);
        #[cfg(debug_assertions)]
@@ -344,11 +391,16 @@ pub fn setup_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerMana
 /// disconnected and associated handling futures are freed, though, because all processing in said
 /// futures are spawned with tokio::spawn, you do not need to poll the second future in order to
 /// make progress.
-pub async fn connect_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>, Arc<UMH>>>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> where
-               CMH: ChannelMessageHandler + 'static + Send + Sync,
-               RMH: RoutingMessageHandler + 'static + Send + Sync,
-               L: Logger + 'static + ?Sized + Send + Sync,
-               UMH: CustomMessageHandler + 'static + Send + Sync {
+pub async fn connect_outbound<CMH, RMH, L, UMH>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, CMH, RMH, L, UMH>>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> where
+               CMH: Deref + 'static + Send + Sync,
+               RMH: Deref + 'static + Send + Sync,
+               L: Deref + 'static + Send + Sync,
+               UMH: Deref + 'static + Send + Sync,
+               CMH::Target: ChannelMessageHandler + Send + Sync,
+               RMH::Target: RoutingMessageHandler + Send + Sync,
+               L::Target: Logger + Send + Sync,
+               UMH::Target: CustomMessageHandler + Send + Sync,
+{
        if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
                Some(setup_outbound(peer_manager, their_node_id, stream))
        } else { None }
@@ -443,6 +495,9 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
                                        // pause read given we're now waiting on the remote end to ACK (and in
                                        // accordance with the send_data() docs).
                                        us.read_paused = true;
+                                       // Further, to avoid any current pending read causing a `read_event` call, wake
+                                       // up the read_waker and restart its loop.
+                                       let _ = us.read_waker.try_send(());
                                        return written_len;
                                },
                        }
@@ -509,8 +564,8 @@ mod tests {
                fn handle_node_announcement(&self, _msg: &NodeAnnouncement) -> Result<bool, LightningError> { Ok(false) }
                fn handle_channel_announcement(&self, _msg: &ChannelAnnouncement) -> Result<bool, LightningError> { Ok(false) }
                fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result<bool, LightningError> { Ok(false) }
-               fn get_next_channel_announcements(&self, _starting_point: u64, _batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> { Vec::new() }
-               fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec<NodeAnnouncement> { Vec::new() }
+               fn get_next_channel_announcement(&self, _starting_point: u64) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> { None }
+               fn get_next_node_announcement(&self, _starting_point: Option<&PublicKey>) -> Option<NodeAnnouncement> { None }
                fn peer_connected(&self, _their_node_id: &PublicKey, _init_msg: &Init) { }
                fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: ReplyChannelRange) -> Result<(), LightningError> { Ok(()) }
                fn handle_reply_short_channel_ids_end(&self, _their_node_id: &PublicKey, _msg: ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) }
@@ -522,7 +577,7 @@ mod tests {
                fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &AcceptChannel) {}
                fn handle_funding_created(&self, _their_node_id: &PublicKey, _msg: &FundingCreated) {}
                fn handle_funding_signed(&self, _their_node_id: &PublicKey, _msg: &FundingSigned) {}
-               fn handle_funding_locked(&self, _their_node_id: &PublicKey, _msg: &FundingLocked) {}
+               fn handle_channel_ready(&self, _their_node_id: &PublicKey, _msg: &ChannelReady) {}
                fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, _msg: &Shutdown) {}
                fn handle_closing_signed(&self, _their_node_id: &PublicKey, _msg: &ClosingSigned) {}
                fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, _msg: &UpdateAddHTLC) {}