Update tokio to 1.0
authorMatt Corallo <git@bluematt.me>
Tue, 26 Jan 2021 20:38:19 +0000 (15:38 -0500)
committerMatt Corallo <git@bluematt.me>
Tue, 26 Jan 2021 23:37:04 +0000 (18:37 -0500)
This requires ensuring TcpStreams are set in nonblocking mode as
tokio doesn't handle this for us anymore, so we adapt the public
API to just accept std TcpStreams instead of an extra conversion
hop. Luckily converting them is cheap.

lightning-net-tokio/Cargo.toml
lightning-net-tokio/src/lib.rs

index 50634bd32df2ce4be3263a232feea74cb7f1b36a..9165388066da3082f6e427bb6664c3dcb08bd2ff 100644 (file)
@@ -12,7 +12,7 @@ For Rust-Lightning clients which wish to make direct connections to Lightning P2
 [dependencies]
 bitcoin = "0.24"
 lightning = { version = "0.0.12", path = "../lightning" }
-tokio = { version = ">=0.2.12", features = [ "io-util", "macros", "rt-core", "sync", "tcp", "time" ] }
+tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "sync", "net", "time" ] }
 
 [dev-dependencies]
-tokio = { version = ">=0.2.12", features = [ "io-util", "macros", "rt-core", "rt-threaded", "sync", "tcp", "time" ] }
+tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "rt-multi-thread", "sync", "net", "time" ] }
index 8e5885ca9bf2d7b67b017d8575a4059e9ecc5ce2..b8b33318f99a0fcaa41ce7c4b803eb9a5daff277 100644 (file)
@@ -24,7 +24,7 @@
 //! The call site should, thus, look something like this:
 //! ```
 //! use tokio::sync::mpsc;
-//! use tokio::net::TcpStream;
+//! use std::net::TcpStream;
 //! use bitcoin::secp256k1::key::PublicKey;
 //! use lightning::util::events::EventsProvider;
 //! use std::net::SocketAddr;
@@ -86,6 +86,7 @@ use lightning::util::logger::Logger;
 
 use std::{task, thread};
 use std::net::SocketAddr;
+use std::net::TcpStream as StdTcpStream;
 use std::sync::{Arc, Mutex, MutexGuard};
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::time::Duration;
@@ -218,7 +219,7 @@ impl Connection {
                }
        }
 
-       fn new(event_notify: mpsc::Sender<()>, stream: TcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
+       fn new(event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
                // We only ever need a channel of depth 1 here: if we returned a non-full write to the
                // PeerManager, we will eventually get notified that there is room in the socket to write
                // new bytes, which will generate an event. That event will be popped off the queue before
@@ -229,7 +230,8 @@ impl Connection {
                // we shove a value into the channel which comes after we've reset the read_paused bool to
                // false.
                let (read_waker, read_receiver) = mpsc::channel(1);
-               let (reader, writer) = io::split(stream);
+               stream.set_nonblocking(true).unwrap();
+               let (reader, writer) = io::split(TcpStream::from_std(stream).unwrap());
 
                (reader, write_receiver, read_receiver,
                Arc::new(Mutex::new(Self {
@@ -248,7 +250,7 @@ impl Connection {
 /// not need to poll the provided future in order to make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
                CMH: ChannelMessageHandler + 'static,
                RMH: RoutingMessageHandler + 'static,
                L: Logger + 'static + ?Sized {
@@ -290,7 +292,7 @@ pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<So
 /// not need to poll the provided future in order to make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
                CMH: ChannelMessageHandler + 'static,
                RMH: RoutingMessageHandler + 'static,
                L: Logger + 'static + ?Sized {
@@ -366,7 +368,7 @@ pub async fn connect_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerM
                CMH: ChannelMessageHandler + 'static,
                RMH: RoutingMessageHandler + 'static,
                L: Logger + 'static + ?Sized {
-       if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), TcpStream::connect(&addr)).await {
+       if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
                Some(setup_outbound(peer_manager, event_notify, their_node_id, stream))
        } else { None }
 }
@@ -388,7 +390,7 @@ fn wake_socket_waker(orig_ptr: *const ()) {
 }
 fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
        let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
-       let mut sender = unsafe { (*sender_ptr).clone() };
+       let sender = unsafe { (*sender_ptr).clone() };
        let _ = sender.try_send(());
 }
 fn drop_socket_waker(orig_ptr: *const ()) {
@@ -512,6 +514,7 @@ mod tests {
        use tokio::sync::mpsc;
 
        use std::mem;
+       use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::{Arc, Mutex};
        use std::time::Duration;
 
@@ -526,6 +529,7 @@ mod tests {
                expected_pubkey: PublicKey,
                pubkey_connected: mpsc::Sender<()>,
                pubkey_disconnected: mpsc::Sender<()>,
+               disconnected_flag: AtomicBool,
                msg_events: Mutex<Vec<MessageSendEvent>>,
        }
        impl RoutingMessageHandler for MsgHandler {
@@ -559,6 +563,7 @@ mod tests {
                fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {}
                fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
                        if *their_node_id == self.expected_pubkey {
+                               self.disconnected_flag.store(true, Ordering::SeqCst);
                                self.pubkey_disconnected.clone().try_send(()).unwrap();
                        }
                }
@@ -591,6 +596,7 @@ mod tests {
                        expected_pubkey: b_pub,
                        pubkey_connected: a_connected_sender,
                        pubkey_disconnected: a_disconnected_sender,
+                       disconnected_flag: AtomicBool::new(false),
                        msg_events: Mutex::new(Vec::new()),
                });
                let a_manager = Arc::new(PeerManager::new(MessageHandler {
@@ -604,6 +610,7 @@ mod tests {
                        expected_pubkey: a_pub,
                        pubkey_connected: b_connected_sender,
                        pubkey_disconnected: b_disconnected_sender,
+                       disconnected_flag: AtomicBool::new(false),
                        msg_events: Mutex::new(Vec::new()),
                });
                let b_manager = Arc::new(PeerManager::new(MessageHandler {
@@ -624,8 +631,8 @@ mod tests {
                } else { panic!("Failed to bind to v4 localhost on common ports"); };
 
                let (sender, _receiver) = mpsc::channel(2);
-               let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, tokio::net::TcpStream::from_std(conn_a).unwrap());
-               let fut_b = super::setup_inbound(b_manager, sender, tokio::net::TcpStream::from_std(conn_b).unwrap());
+               let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, conn_a);
+               let fut_b = super::setup_inbound(b_manager, sender, conn_b);
 
                tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
                tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();
@@ -633,18 +640,20 @@ mod tests {
                a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
                        node_id: b_pub, action: ErrorAction::DisconnectPeer { msg: None }
                });
-               assert!(a_disconnected.try_recv().is_err());
-               assert!(b_disconnected.try_recv().is_err());
+               assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
+               assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));
 
                a_manager.process_events();
                tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
                tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
+               assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
+               assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));
 
                fut_a.await;
                fut_b.await;
        }
 
-       #[tokio::test(threaded_scheduler)]
+       #[tokio::test(flavor = "multi_thread")]
        async fn basic_threaded_connection_test() {
                do_basic_connection_test().await;
        }