Update tokio to 1.0
[rust-lightning] / lightning-net-tokio / src / lib.rs
index 8e5885ca9bf2d7b67b017d8575a4059e9ecc5ce2..b8b33318f99a0fcaa41ce7c4b803eb9a5daff277 100644 (file)
@@ -24,7 +24,7 @@
 //! The call site should, thus, look something like this:
 //! ```
 //! use tokio::sync::mpsc;
-//! use tokio::net::TcpStream;
+//! use std::net::TcpStream;
 //! use bitcoin::secp256k1::key::PublicKey;
 //! use lightning::util::events::EventsProvider;
 //! use std::net::SocketAddr;
@@ -86,6 +86,7 @@ use lightning::util::logger::Logger;
 
 use std::{task, thread};
 use std::net::SocketAddr;
+use std::net::TcpStream as StdTcpStream;
 use std::sync::{Arc, Mutex, MutexGuard};
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::time::Duration;
@@ -218,7 +219,7 @@ impl Connection {
                }
        }
 
-       fn new(event_notify: mpsc::Sender<()>, stream: TcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
+       fn new(event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
                // We only ever need a channel of depth 1 here: if we returned a non-full write to the
                // PeerManager, we will eventually get notified that there is room in the socket to write
                // new bytes, which will generate an event. That event will be popped off the queue before
@@ -229,7 +230,8 @@ impl Connection {
                // we shove a value into the channel which comes after we've reset the read_paused bool to
                // false.
                let (read_waker, read_receiver) = mpsc::channel(1);
-               let (reader, writer) = io::split(stream);
+               stream.set_nonblocking(true).unwrap();
+               let (reader, writer) = io::split(TcpStream::from_std(stream).unwrap());
 
                (reader, write_receiver, read_receiver,
                Arc::new(Mutex::new(Self {
@@ -248,7 +250,7 @@ impl Connection {
 /// not need to poll the provided future in order to make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
                CMH: ChannelMessageHandler + 'static,
                RMH: RoutingMessageHandler + 'static,
                L: Logger + 'static + ?Sized {
@@ -290,7 +292,7 @@ pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<So
 /// not need to poll the provided future in order to make progress.
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
                CMH: ChannelMessageHandler + 'static,
                RMH: RoutingMessageHandler + 'static,
                L: Logger + 'static + ?Sized {
@@ -366,7 +368,7 @@ pub async fn connect_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerM
                CMH: ChannelMessageHandler + 'static,
                RMH: RoutingMessageHandler + 'static,
                L: Logger + 'static + ?Sized {
-       if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), TcpStream::connect(&addr)).await {
+       if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
                Some(setup_outbound(peer_manager, event_notify, their_node_id, stream))
        } else { None }
 }
@@ -388,7 +390,7 @@ fn wake_socket_waker(orig_ptr: *const ()) {
 }
 fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
        let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
-       let mut sender = unsafe { (*sender_ptr).clone() };
+       let sender = unsafe { (*sender_ptr).clone() };
        let _ = sender.try_send(());
 }
 fn drop_socket_waker(orig_ptr: *const ()) {
@@ -512,6 +514,7 @@ mod tests {
        use tokio::sync::mpsc;
 
        use std::mem;
+       use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::{Arc, Mutex};
        use std::time::Duration;
 
@@ -526,6 +529,7 @@ mod tests {
                expected_pubkey: PublicKey,
                pubkey_connected: mpsc::Sender<()>,
                pubkey_disconnected: mpsc::Sender<()>,
+               disconnected_flag: AtomicBool,
                msg_events: Mutex<Vec<MessageSendEvent>>,
        }
        impl RoutingMessageHandler for MsgHandler {
@@ -559,6 +563,7 @@ mod tests {
                fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {}
                fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
                        if *their_node_id == self.expected_pubkey {
+                               self.disconnected_flag.store(true, Ordering::SeqCst);
                                self.pubkey_disconnected.clone().try_send(()).unwrap();
                        }
                }
@@ -591,6 +596,7 @@ mod tests {
                        expected_pubkey: b_pub,
                        pubkey_connected: a_connected_sender,
                        pubkey_disconnected: a_disconnected_sender,
+                       disconnected_flag: AtomicBool::new(false),
                        msg_events: Mutex::new(Vec::new()),
                });
                let a_manager = Arc::new(PeerManager::new(MessageHandler {
@@ -604,6 +610,7 @@ mod tests {
                        expected_pubkey: a_pub,
                        pubkey_connected: b_connected_sender,
                        pubkey_disconnected: b_disconnected_sender,
+                       disconnected_flag: AtomicBool::new(false),
                        msg_events: Mutex::new(Vec::new()),
                });
                let b_manager = Arc::new(PeerManager::new(MessageHandler {
@@ -624,8 +631,8 @@ mod tests {
                } else { panic!("Failed to bind to v4 localhost on common ports"); };
 
                let (sender, _receiver) = mpsc::channel(2);
-               let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, tokio::net::TcpStream::from_std(conn_a).unwrap());
-               let fut_b = super::setup_inbound(b_manager, sender, tokio::net::TcpStream::from_std(conn_b).unwrap());
+               let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, conn_a);
+               let fut_b = super::setup_inbound(b_manager, sender, conn_b);
 
                tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
                tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();
@@ -633,18 +640,20 @@ mod tests {
                a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
                        node_id: b_pub, action: ErrorAction::DisconnectPeer { msg: None }
                });
-               assert!(a_disconnected.try_recv().is_err());
-               assert!(b_disconnected.try_recv().is_err());
+               assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
+               assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));
 
                a_manager.process_events();
                tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
                tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
+               assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
+               assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));
 
                fut_a.await;
                fut_b.await;
        }
 
-       #[tokio::test(threaded_scheduler)]
+       #[tokio::test(flavor = "multi_thread")]
        async fn basic_threaded_connection_test() {
                do_basic_connection_test().await;
        }