PeerManager Logger Arc --> Deref
[rust-lightning] lightning-net-tokio/src/lib.rs
index 08aa12571985a985f896929e2034a248ff77ca1a..15226642a4105aee94c511334952408e42df5964 100644
@@ -16,7 +16,7 @@
 //! ```
 //! use tokio::sync::mpsc;
 //! use tokio::net::TcpStream;
-//! use secp256k1::key::PublicKey;
+//! use bitcoin::secp256k1::key::PublicKey;
 //! use lightning::util::events::EventsProvider;
 //! use std::net::SocketAddr;
 //! use std::sync::Arc;
 //! // Define concrete types for our high-level objects:
 //! type TxBroadcaster = dyn lightning::chain::chaininterface::BroadcasterInterface;
 //! type FeeEstimator = dyn lightning::chain::chaininterface::FeeEstimator;
-//! type ChannelMonitor = lightning::ln::channelmonitor::SimpleManyChannelMonitor<lightning::chain::transaction::OutPoint, lightning::chain::keysinterface::InMemoryChannelKeys, Arc<TxBroadcaster>, Arc<FeeEstimator>>;
-//! type ChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager<ChannelMonitor, TxBroadcaster, FeeEstimator>;
-//! type PeerManager = lightning::ln::peer_handler::SimpleArcPeerManager<lightning_net_tokio::SocketDescriptor, ChannelMonitor, TxBroadcaster, FeeEstimator>;
+//! type Logger = dyn lightning::util::logger::Logger;
+//! type ChainWatchInterface = dyn lightning::chain::chaininterface::ChainWatchInterface;
+//! type ChannelMonitor = lightning::ln::channelmonitor::SimpleManyChannelMonitor<lightning::chain::transaction::OutPoint, lightning::chain::keysinterface::InMemoryChannelKeys, Arc<TxBroadcaster>, Arc<FeeEstimator>, Arc<Logger>, Arc<ChainWatchInterface>>;
+//! type ChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager<ChannelMonitor, TxBroadcaster, FeeEstimator, Logger>;
+//! type PeerManager = lightning::ln::peer_handler::SimpleArcPeerManager<lightning_net_tokio::SocketDescriptor, ChannelMonitor, TxBroadcaster, FeeEstimator, Logger>;
 //!
 //! // Connect to node with pubkey their_node_id at addr:
 //! async fn connect_to_node(peer_manager: PeerManager, channel_monitor: Arc<ChannelMonitor>, channel_manager: ChannelManager, their_node_id: PublicKey, addr: SocketAddr) {
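The reworked type aliases above hinge on holding the logger as a trait object behind an `Arc`: `Logger` is now `dyn lightning::util::logger::Logger`, and the monitor/manager/peer-manager aliases all carry `Arc<Logger>` as an extra type parameter. As a rough, self-contained illustration of that shape (using a hypothetical `Log` trait rather than the real rust-lightning one), an unsized trait-object alias can be shared through `Arc` like this:

```rust
use std::sync::Arc;

// Hypothetical stand-in for lightning::util::logger::Logger.
trait Log {
    fn log_line(&self, line: &str);
}

struct StdoutLogger;
impl Log for StdoutLogger {
    fn log_line(&self, line: &str) { println!("LOG: {}", line); }
}

// Mirror of the doc aliases above: name the unsized trait-object type once...
type Logger = dyn Log;

// ...and hand it around behind an Arc, just as the aliases use Arc<Logger>.
fn takes_logger(logger: Arc<Logger>) {
    logger.log_line("connected");
}

fn main() {
    let logger: Arc<Logger> = Arc::new(StdoutLogger);
    takes_logger(logger);
}
```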
@@ -59,7 +61,7 @@
 //! }
 //! ```
 
-use secp256k1::key::PublicKey;
+use bitcoin::secp256k1::key::PublicKey;
 
 use tokio::net::TcpStream;
 use tokio::{io, time};
@@ -69,8 +71,9 @@ use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use lightning::ln::peer_handler;
 use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;
 use lightning::ln::msgs::ChannelMessageHandler;
+use lightning::util::logger::Logger;
 
-use std::task;
+use std::{task, thread};
 use std::net::SocketAddr;
 use std::sync::{Arc, Mutex, MutexGuard};
 use std::sync::atomic::{AtomicU64, Ordering};
@@ -101,6 +104,11 @@ struct Connection {
        // socket. To wake it up (without otherwise changing its state), we can push a value into this
        // Sender.
        read_waker: mpsc::Sender<()>,
+       // When we are told by rust-lightning to disconnect, we can't return to rust-lightning until we
+       // are sure we won't call any more read/write PeerManager functions with the same connection.
+       // This is set to true when we are in such a state (having already checked for a requested
+       // disconnect with the top-level mutex held) and back to false once we can safely return.
+       block_disconnect_socket: bool,
        read_paused: bool,
        rl_requested_disconnect: bool,
        id: u64,
@@ -116,7 +124,7 @@ impl Connection {
                        _ => panic!()
                }
        }
-       async fn schedule_read<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) {
+       async fn schedule_read<CMH: ChannelMessageHandler + 'static, L: Logger + 'static + ?Sized>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<L>>>, us: Arc<Mutex<Self>>, mut reader: io::ReadHalf<TcpStream>, mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>) {
                let peer_manager_ref = peer_manager.clone();
                // 8KB is nice and big but also should never cause any issues with stack overflowing.
                let mut buf = [0; 8192];
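The new `L: Logger + 'static + ?Sized` bound here (and on `setup_inbound`, `setup_outbound` and `connect_outbound` below) is what lets callers pass `Arc<dyn Logger>`: the trait object itself is unsized, so the generic parameter has to permit that, while `PeerManager` only ever sees it through the `Arc`. This is the "Arc --> Deref" direction from the commit title, where the consumer is generic over a smart pointer whose target implements the trait. A minimal sketch of that pattern, with hypothetical `Log`/`Manager` types standing in for the real ones:

```rust
use std::ops::Deref;
use std::sync::Arc;

// Hypothetical stand-ins, not the real rust-lightning types.
trait Log { fn log_line(&self, line: &str); }
struct PrintLogger;
impl Log for PrintLogger {
    fn log_line(&self, line: &str) { println!("{}", line); }
}

// Generic over any pointer type whose Deref target implements Log, instead of
// hard-coding Arc. `dyn Log` is a valid Target because Deref::Target is ?Sized.
struct Manager<L: Deref> where L::Target: Log {
    logger: L,
}

impl<L: Deref> Manager<L> where L::Target: Log {
    fn do_work(&self) {
        self.logger.log_line("doing work");
    }
}

fn main() {
    // An Arc<dyn Log> satisfies the bound, just like Arc<L> with L: Logger + ?Sized above.
    let logger: Arc<dyn Log> = Arc::new(PrintLogger);
    let manager = Manager { logger };
    manager.do_work();
}
```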
@@ -143,28 +151,35 @@ impl Connection {
                                } }
                        }
 
+                       macro_rules! prepare_read_write_call {
+                               () => { {
+                                       let mut us_lock = us.lock().unwrap();
+                                       if us_lock.rl_requested_disconnect {
+                                               shutdown_socket!("disconnect_socket() call from RL", Disconnect::CloseConnection);
+                                       }
+                                       us_lock.block_disconnect_socket = true;
+                               } }
+                       }
+
                        let read_paused = us.lock().unwrap().read_paused;
                        tokio::select! {
                                v = write_avail_receiver.recv() => {
                                        assert!(v.is_some()); // We can't have dropped the sending end, it's in the us Arc!
-                                       if us.lock().unwrap().rl_requested_disconnect {
-                                               shutdown_socket!("disconnect_socket() call from RL", Disconnect::CloseConnection);
-                                       }
+                                       prepare_read_write_call!();
                                        if let Err(e) = peer_manager.write_buffer_space_avail(&mut our_descriptor) {
                                                shutdown_socket!(e, Disconnect::CloseConnection);
                                        }
+                                       us.lock().unwrap().block_disconnect_socket = false;
                                },
                                _ = read_wake_receiver.recv() => {},
                                read = reader.read(&mut buf), if !read_paused => match read {
                                        Ok(0) => shutdown_socket!("Connection closed", Disconnect::PeerDisconnected),
                                        Ok(len) => {
-                                               if us.lock().unwrap().rl_requested_disconnect {
-                                                       shutdown_socket!("disconnect_socket() call from RL", Disconnect::CloseConnection);
-                                               }
+                                               prepare_read_write_call!();
                                                let read_res = peer_manager.read_event(&mut our_descriptor, &buf[0..len]);
+                                               let mut us_lock = us.lock().unwrap();
                                                match read_res {
                                                        Ok(pause_read) => {
-                                                               let mut us_lock = us.lock().unwrap();
                                                                if pause_read {
                                                                        us_lock.read_paused = true;
                                                                }
@@ -172,6 +187,7 @@ impl Connection {
                                                        },
                                                        Err(e) => shutdown_socket!(e, Disconnect::CloseConnection),
                                                }
+                                               us_lock.block_disconnect_socket = false;
                                        },
                                        Err(e) => shutdown_socket!(e, Disconnect::PeerDisconnected),
                                },
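Inside the `tokio::select!` loop above, the read branch is gated by `, if !read_paused`, so backpressure from `read_event` (returning `pause_read = true`) stops further reads without blocking the task: it keeps waking up for writability and for the read-waker channel, and resumes reading once `read_paused` is cleared again. A stripped-down model of that shape, using plain mpsc channels in place of the socket (all names here are hypothetical):

```rust
use tokio::sync::mpsc;

// Minimal model of the select loop: a guarded read branch plus two wakeup channels.
async fn select_loop(
    mut write_avail: mpsc::Receiver<()>,
    mut read_wake: mpsc::Receiver<()>,
    mut incoming: mpsc::Receiver<Vec<u8>>, // stands in for reader.read(&mut buf)
) {
    let mut read_paused = false;
    loop {
        tokio::select! {
            v = write_avail.recv() => {
                if v.is_none() { break; }
                // Here the real code calls write_buffer_space_avail() on the PeerManager.
            },
            _ = read_wake.recv() => {
                // A nudge to re-evaluate state; the real code re-checks read_paused at
                // the top of the loop, here we simply unpause.
                read_paused = false;
            },
            chunk = incoming.recv(), if !read_paused => match chunk {
                Some(bytes) => {
                    // Here the real code calls read_event(); a backpressure signal
                    // from it would set read_paused = true.
                    let _ = bytes;
                },
                None => break, // peer disconnected
            },
        }
    }
}
```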
@@ -203,8 +219,8 @@ impl Connection {
 
                (reader, write_receiver, read_receiver,
                Arc::new(Mutex::new(Self {
-                       writer: Some(writer), event_notify, write_avail, read_waker,
-                       read_paused: false, rl_requested_disconnect: false,
+                       writer: Some(writer), event_notify, write_avail, read_waker, read_paused: false,
+                       block_disconnect_socket: false, rl_requested_disconnect: false,
                        id: ID_COUNTER.fetch_add(1, Ordering::AcqRel)
                })))
        }
@@ -218,7 +234,7 @@ impl Connection {
 /// not need to poll the provided future in order to make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_inbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> {
+pub fn setup_inbound<CMH: ChannelMessageHandler + 'static, L: Logger + 'static + ?Sized>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> {
        let (reader, write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
        #[cfg(debug_assertions)]
        let last_us = Arc::clone(&us);
@@ -257,19 +273,36 @@ pub fn setup_inbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<pee
 /// not need to poll the provided future in order to make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_outbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> {
-       let (reader, write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
+pub fn setup_outbound<CMH: ChannelMessageHandler + 'static, L: Logger + 'static + ?Sized>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> {
+       let (reader, mut write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
        #[cfg(debug_assertions)]
        let last_us = Arc::clone(&us);
 
        let handle_opt = if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone())) {
                Some(tokio::spawn(async move {
-                       if SocketDescriptor::new(us.clone()).send_data(&initial_send, true) != initial_send.len() {
-                               // We should essentially always have enough room in a TCP socket buffer to send the
-                               // initial 10s of bytes, if not, just give up as hopeless.
-                               eprintln!("Failed to write first full message to socket!");
-                               peer_manager.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
-                       } else {
+                       // We should essentially always have enough room in a TCP socket buffer to send the
+                       // initial 10s of bytes. However, tokio running in single-threaded mode will always
+                       // fail writes and wake us back up later to write. Thus, we handle a single
+                       // std::task::Poll::Pending but still expect to write the full set of bytes at once
+                       // and use a relatively tight timeout.
+                       if let Ok(Ok(())) = tokio::time::timeout(Duration::from_millis(100), async {
+                               loop {
+                                       match SocketDescriptor::new(us.clone()).send_data(&initial_send, true) {
+                                               v if v == initial_send.len() => break Ok(()),
+                                               0 => {
+                                                       write_receiver.recv().await;
+                                                       // In theory we could check whether we've been instructed to disconnect
+                                                       // the peer here, but it's OK to just skip it - we'll check for it in
+                                                       // schedule_read prior to any relevant calls into RL.
+                                               },
+                                               _ => {
+                                                       eprintln!("Failed to write first full message to socket!");
+                                                       peer_manager.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
+                                                       break Err(());
+                                               }
+                                       }
+                               }
+                       }).await {
                                Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver).await;
                        }
                }))
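The rewritten initial-send path above turns a single `send_data` attempt into a short retry loop: a return of `0` means no buffer space was available yet, so the task waits on the write-availability channel and tries again, with a 100ms `tokio::time::timeout` bounding the whole exchange so a genuinely stuck socket still fails fast. A self-contained model of that shape (the `try_send` closure and the channel are hypothetical stand-ins for `SocketDescriptor::send_data` and `write_receiver`):

```rust
use std::time::Duration;
use tokio::sync::mpsc;

// Attempt to write `msg` in full, retrying on "no space yet" (0) when the
// writability channel fires, but giving up after 100ms overall.
async fn send_initial(
    mut try_send: impl FnMut(&[u8]) -> usize,
    write_avail: &mut mpsc::Receiver<()>,
    msg: &[u8],
) -> Result<(), ()> {
    match tokio::time::timeout(Duration::from_millis(100), async {
        loop {
            match try_send(msg) {
                n if n == msg.len() => break Ok(()), // wrote the whole message
                0 => { write_avail.recv().await; },  // not writable yet: wait, then retry
                _ => break Err(()),                  // partial write: treat as a hard failure
            }
        }
    }).await {
        Ok(res) => res,
        Err(_) => Err(()), // timed out without completing the write
    }
}
```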
@@ -309,7 +342,7 @@ pub fn setup_outbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<pe
 /// make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub async fn connect_outbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> {
+pub async fn connect_outbound<CMH: ChannelMessageHandler + 'static, L: Logger + 'static + ?Sized>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, addr: SocketAddr) -> Option<impl std::future::Future<Output=()>> {
        if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), TcpStream::connect(&addr)).await {
                Some(setup_outbound(peer_manager, event_notify, their_node_id, stream))
        } else { None }
@@ -411,15 +444,18 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
        }
 
        fn disconnect_socket(&mut self) {
-               let mut us = self.conn.lock().unwrap();
-               us.rl_requested_disconnect = true;
-               us.read_paused = true;
-               // Wake up the sending thread, assuming it is still alive
-               let _ = us.write_avail.try_send(());
-               // TODO: There's a race where we don't meet the requirements of disconnect_socket if the
-               // read task is about to call a PeerManager function (eg read_event or write_event).
-               // Ideally we need to release the us lock and block until we have confirmation from the
-               // read task that it has broken out of its main loop.
+               {
+                       let mut us = self.conn.lock().unwrap();
+                       us.rl_requested_disconnect = true;
+                       us.read_paused = true;
+                       // Wake up the sending thread, assuming it is still alive
+                       let _ = us.write_avail.try_send(());
+                       // Happy-path return:
+                       if !us.block_disconnect_socket { return; }
+               }
+               while self.conn.lock().unwrap().block_disconnect_socket {
+                       thread::yield_now();
+               }
        }
 }
 impl Clone for SocketDescriptor {
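Taken together, `block_disconnect_socket` and the loop above implement the guarantee `disconnect_socket()` has to provide: once it returns, no further `read_event`/`write_buffer_space_avail` calls will be made for this descriptor. The read task raises the flag (under the connection mutex, after checking `rl_requested_disconnect`) before each call into the `PeerManager` and lowers it afterwards, and `disconnect_socket()` spins with `thread::yield_now()` until any in-flight call has finished. A compact, std-only model of that handshake (with hypothetical field and function names):

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical model of the two flags on Connection.
struct Guard {
    in_call: bool,              // models block_disconnect_socket
    disconnect_requested: bool, // models rl_requested_disconnect
}

// One iteration of the read task around a PeerManager call.
fn read_task_iteration(guard: &Arc<Mutex<Guard>>) {
    {
        let mut g = guard.lock().unwrap();
        if g.disconnect_requested { return; } // never call back into RL after a disconnect
        g.in_call = true;                     // forbid disconnect_socket() from returning
    }
    // ... read_event() / write_buffer_space_avail() would run here, lock released ...
    guard.lock().unwrap().in_call = false;    // safe for disconnect_socket() to return now
}

// What disconnect_socket() promises: no further calls once this returns.
fn disconnect_socket(guard: &Arc<Mutex<Guard>>) {
    {
        let mut g = guard.lock().unwrap();
        g.disconnect_requested = true;
        if !g.in_call { return; }             // happy path: nothing in flight
    }
    while guard.lock().unwrap().in_call {     // otherwise wait the call out
        thread::yield_now();
    }
}

fn main() {
    let guard = Arc::new(Mutex::new(Guard { in_call: false, disconnect_requested: false }));
    read_task_iteration(&guard);
    disconnect_socket(&guard);
}
```

The busy-wait is tolerable here because the window it covers is a single PeerManager call, and the happy-path check under the mutex means most disconnects never spin at all.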
@@ -448,7 +484,7 @@ mod tests {
        use lightning::ln::msgs::*;
        use lightning::ln::peer_handler::{MessageHandler, PeerManager};
        use lightning::util::events::*;
-       use secp256k1::{Secp256k1, SecretKey, PublicKey};
+       use bitcoin::secp256k1::{Secp256k1, SecretKey, PublicKey};
 
        use tokio::sync::mpsc;
 
@@ -474,7 +510,7 @@ mod tests {
                fn handle_channel_announcement(&self, _msg: &ChannelAnnouncement) -> Result<bool, LightningError> { Ok(false) }
                fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result<bool, LightningError> { Ok(false) }
                fn handle_htlc_fail_channel_update(&self, _update: &HTLCFailChannelUpdate) { }
-               fn get_next_channel_announcements(&self, _starting_point: u64, _batch_amount: u8) -> Vec<(ChannelAnnouncement, ChannelUpdate, ChannelUpdate)> { Vec::new() }
+               fn get_next_channel_announcements(&self, _starting_point: u64, _batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> { Vec::new() }
                fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec<NodeAnnouncement> { Vec::new() }
                fn should_request_full_sync(&self, _node_id: &PublicKey) -> bool { false }
        }
@@ -515,8 +551,7 @@ mod tests {
                }
        }
 
-       #[tokio::test(threaded_scheduler)]
-       async fn basic_connection_test() {
+       async fn do_basic_connection_test() {
                let secp_ctx = Secp256k1::new();
                let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
                let b_key = SecretKey::from_slice(&[1; 32]).unwrap();
@@ -581,4 +616,13 @@ mod tests {
                fut_a.await;
                fut_b.await;
        }
+
+       #[tokio::test(threaded_scheduler)]
+       async fn basic_threaded_connection_test() {
+               do_basic_connection_test().await;
+       }
+       #[tokio::test]
+       async fn basic_unthreaded_connection_test() {
+               do_basic_connection_test().await;
+       }
 }