BDR: Linearizing secp256k1 deps
[rust-lightning] / lightning-net-tokio / src / lib.rs
index e635d93d383162560ad28060de65dd5a361dd902..e460df25e54d192e92548feb5b6053715798fd27 100644 (file)
@@ -16,7 +16,7 @@
 //! ```
 //! use tokio::sync::mpsc;
 //! use tokio::net::TcpStream;
-//! use secp256k1::key::PublicKey;
+//! use bitcoin::secp256k1::key::PublicKey;
 //! use lightning::util::events::EventsProvider;
 //! use std::net::SocketAddr;
 //! use std::sync::Arc;
@@ -59,7 +59,7 @@
 //! }
 //! ```
 
-use secp256k1::key::PublicKey;
+use bitcoin::secp256k1::key::PublicKey;
 
 use tokio::net::TcpStream;
 use tokio::{io, time};
@@ -271,18 +271,35 @@ pub fn setup_inbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<pee
 ///
 /// See the module-level documentation for how to handle the event_notify mpsc::Sender.
 pub fn setup_outbound<CMH: ChannelMessageHandler + 'static>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> {
-       let (reader, write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
+       let (reader, mut write_receiver, read_receiver, us) = Connection::new(event_notify, stream);
        #[cfg(debug_assertions)]
        let last_us = Arc::clone(&us);
 
        let handle_opt = if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone())) {
                Some(tokio::spawn(async move {
-                       if SocketDescriptor::new(us.clone()).send_data(&initial_send, true) != initial_send.len() {
-                               // We should essentially always have enough room in a TCP socket buffer to send the
-                               // initial 10s of bytes, if not, just give up as hopeless.
-                               eprintln!("Failed to write first full message to socket!");
-                               peer_manager.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
-                       } else {
+                       // We should essentially always have enough room in a TCP socket buffer to send the
+                       // initial 10s of bytes. However, tokio running in single-threaded mode will always
+                       // fail writes and wake us back up later to write. Thus, we handle a single
+                       // std::task::Poll::Pending but still expect to write the full set of bytes at once
+                       // and use a relatively tight timeout.
+                       if let Ok(Ok(())) = tokio::time::timeout(Duration::from_millis(100), async {
+                               loop {
+                                       match SocketDescriptor::new(us.clone()).send_data(&initial_send, true) {
+                                               v if v == initial_send.len() => break Ok(()),
+                                               0 => {
+                                                       write_receiver.recv().await;
+                                                       // In theory we could check for if we've been instructed to disconnect
+                                                       // the peer here, but its OK to just skip it - we'll check for it in
+                                                       // schedule_read prior to any relevant calls into RL.
+                                               },
+                                               _ => {
+                                                       eprintln!("Failed to write first full message to socket!");
+                                                       peer_manager.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
+                                                       break Err(());
+                                               }
+                                       }
+                               }
+                       }).await {
                                Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver).await;
                        }
                }))
@@ -464,7 +481,7 @@ mod tests {
        use lightning::ln::msgs::*;
        use lightning::ln::peer_handler::{MessageHandler, PeerManager};
        use lightning::util::events::*;
-       use secp256k1::{Secp256k1, SecretKey, PublicKey};
+       use bitcoin::secp256k1::{Secp256k1, SecretKey, PublicKey};
 
        use tokio::sync::mpsc;
 
@@ -490,7 +507,7 @@ mod tests {
                fn handle_channel_announcement(&self, _msg: &ChannelAnnouncement) -> Result<bool, LightningError> { Ok(false) }
                fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result<bool, LightningError> { Ok(false) }
                fn handle_htlc_fail_channel_update(&self, _update: &HTLCFailChannelUpdate) { }
-               fn get_next_channel_announcements(&self, _starting_point: u64, _batch_amount: u8) -> Vec<(ChannelAnnouncement, ChannelUpdate, ChannelUpdate)> { Vec::new() }
+               fn get_next_channel_announcements(&self, _starting_point: u64, _batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> { Vec::new() }
                fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec<NodeAnnouncement> { Vec::new() }
                fn should_request_full_sync(&self, _node_id: &PublicKey) -> bool { false }
        }
@@ -531,8 +548,7 @@ mod tests {
                }
        }
 
-       #[tokio::test(threaded_scheduler)]
-       async fn basic_connection_test() {
+       async fn do_basic_connection_test() {
                let secp_ctx = Secp256k1::new();
                let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
                let b_key = SecretKey::from_slice(&[1; 32]).unwrap();
@@ -597,4 +613,13 @@ mod tests {
                fut_a.await;
                fut_b.await;
        }
+
+       #[tokio::test(threaded_scheduler)]
+       async fn basic_threaded_connection_test() {
+               do_basic_connection_test().await;
+       }
+       #[tokio::test]
+       async fn basic_unthreaded_connection_test() {
+               do_basic_connection_test().await;
+       }
 }