Fix (and test) net-tokio outbound conns without a threaded env
authorMatt Corallo <git@bluematt.me>
Fri, 28 Feb 2020 16:55:34 +0000 (11:55 -0500)
committerMatt Corallo <git@bluematt.me>
Wed, 11 Mar 2020 16:19:39 +0000 (12:19 -0400)
lightning-net-tokio/src/lib.rs

index e635d93d383162560ad28060de65dd5a361dd902..94c3eee151164142668507a01d7081bac5c4fa04 100644 (file)
@@ -271,18 +271,35 @@ pub fn setup_inbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<pee
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
 pub fn setup_outbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> {
-       let (reader, write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
+       let (reader, mut write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
        #[cfg(debug_assertions)]
        let last_us = Arc::clone(&us);
 
        let handle_opt = if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone())) {
                Some(tokio::spawn(async move {
-                       if SocketDescriptor::new(us.clone()).send_data(&initial_send, true) != initial_send.len() {
-                               // We should essentially always have enough room in a TCP socket buffer to send the
-                               // initial 10s of bytes, if not, just give up as hopeless.
-                               eprintln!("Failed to write first full message to socket!");
-                               peer_manager.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
-                       } else {
+                       // We should essentially always have enough room in a TCP socket buffer to send the
+                       // initial 10s of bytes. However, tokio running in single-threaded mode will always
+                       // fail writes and wake us back up later to write. Thus, we handle a single
+                       // std::task::Poll::Pending but still expect to write the full set of bytes at once
+                       // and use a relatively tight timeout.
+                       if let Ok(Ok(())) = tokio::time::timeout(Duration::from_millis(100), async {
+                               loop {
+                                       match SocketDescriptor::new(us.clone()).send_data(&initial_send, true) {
+                                               v if v == initial_send.len() => break Ok(()),
+                                               0 => {
+                                                       write_receiver.recv().await;
+                                                       // In theory we could check for if we've been instructed to disconnect
+                                                       // the peer here, but its OK to just skip it - we'll check for it in
+                                                       // schedule_read prior to any relevant calls into RL.
+                                               },
+                                               _ => {
+                                                       eprintln!("Failed to write first full message to socket!");
+                                                       peer_manager.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
+                                                       break Err(());
+                                               }
+                                       }
+                               }
+                       }).await {
                                Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver).await;
                        }
                }))
@@ -531,8 +548,7 @@ mod tests {
                }
        }
 
-       #[tokio::test(threaded_scheduler)]
-       async fn basic_connection_test() {
+       async fn do_basic_connection_test() {
                let secp_ctx = Secp256k1::new();
                let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
                let b_key = SecretKey::from_slice(&[1; 32]).unwrap();
@@ -597,4 +613,13 @@ mod tests {
                fut_a.await;
                fut_b.await;
        }
+
+       #[tokio::test(threaded_scheduler)]
+       async fn basic_threaded_connection_test() {
+               do_basic_connection_test().await;
+       }
+       #[tokio::test]
+       async fn basic_unthreaded_connection_test() {
+               do_basic_connection_test().await;
+       }
 }