Merge pull request #2529 from TheBlueMatt/2023-08-shutdown-remove-early-sign
[rust-lightning] / lightning-net-tokio / src / lib.rs
index 6e2ea3f14c14f54b890f8ff504e4f3fe6dd0a304..d4f75dd6cd8b073d0e29de1ecf76685b8122ab6e 100644 (file)
 
 use bitcoin::secp256k1::PublicKey;
 
-use tokio::net::{tcp, TcpStream};
-use tokio::{io, time};
+use tokio::net::TcpStream;
+use tokio::time;
 use tokio::sync::mpsc;
-use tokio::io::AsyncWrite;
 
 use lightning::ln::peer_handler;
 use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;
 use lightning::ln::peer_handler::APeerManager;
-use lightning::ln::msgs::NetAddress;
+use lightning::ln::msgs::SocketAddress;
 
 use std::ops::Deref;
 use std::task::{self, Poll};
@@ -231,7 +230,7 @@ impl Connection {
                                                        // readable() is allowed to spuriously wake, so we have to handle
                                                        // WouldBlock here.
                                                },
-                                               Err(e) => break Disconnect::PeerDisconnected,
+                                               Err(_) => break Disconnect::PeerDisconnected,
                                        }
                                },
                        }
@@ -274,13 +273,13 @@ impl Connection {
        }
 }
 
-fn get_addr_from_stream(stream: &StdTcpStream) -> Option<NetAddress> {
+fn get_addr_from_stream(stream: &StdTcpStream) -> Option<SocketAddress> {
        match stream.peer_addr() {
-               Ok(SocketAddr::V4(sockaddr)) => Some(NetAddress::IPv4 {
+               Ok(SocketAddr::V4(sockaddr)) => Some(SocketAddress::TcpIpV4 {
                        addr: sockaddr.ip().octets(),
                        port: sockaddr.port(),
                }),
-               Ok(SocketAddr::V6(sockaddr)) => Some(NetAddress::IPv6 {
+               Ok(SocketAddr::V6(sockaddr)) => Some(SocketAddress::TcpIpV6 {
                        addr: sockaddr.ip().octets(),
                        port: sockaddr.port(),
                }),
@@ -423,7 +422,11 @@ const SOCK_WAKER_VTABLE: task::RawWakerVTable =
        task::RawWakerVTable::new(clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker);
 
 fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker {
-       write_avail_to_waker(orig_ptr as *const mpsc::Sender<()>)
+       let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) };
+       let res = write_avail_to_waker(&new_waker);
+       // Don't decrement the refcount when dropping new_waker by turning it back `into_raw`.
+       let _ = Arc::into_raw(new_waker);
+       res
 }
 // When waking, an error should be fine. Most likely we got two send_datas in a row, both of which
 // failed to fully write, but we only need to call write_buffer_space_avail() once. Otherwise, the
@@ -436,16 +439,15 @@ fn wake_socket_waker(orig_ptr: *const ()) {
 }
 fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
        let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
-       let sender = unsafe { (*sender_ptr).clone() };
+       let sender = unsafe { &*sender_ptr };
        let _ = sender.try_send(());
 }
 fn drop_socket_waker(orig_ptr: *const ()) {
-       let _orig_box = unsafe { Box::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
-       // _orig_box is now dropped
+       let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
+       // _orig_arc is now dropped
 }
-fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
-       let new_box = Box::leak(Box::new(unsafe { (*sender).clone() }));
-       let new_ptr = new_box as *const mpsc::Sender<()>;
+fn write_avail_to_waker(sender: &Arc<mpsc::Sender<()>>) -> task::RawWaker {
+       let new_ptr = Arc::into_raw(Arc::clone(&sender));
        task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE)
 }
 
@@ -453,12 +455,20 @@ fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
 /// type in the template of PeerHandler.
 pub struct SocketDescriptor {
        conn: Arc<Mutex<Connection>>,
+       // We store a copy of the mpsc::Sender to wake the read task in an Arc here. While we can
+       // simply clone the sender and store a copy in each waker, that would require allocating for
+       // each waker. Instead, we can simply `Arc::clone`, creating a new reference and store the
+       // pointer in the waker.
+       write_avail_sender: Arc<mpsc::Sender<()>>,
        id: u64,
 }
 impl SocketDescriptor {
        fn new(conn: Arc<Mutex<Connection>>) -> Self {
-               let id = conn.lock().unwrap().id;
-               Self { conn, id }
+               let (id, write_avail_sender) = {
+                       let us = conn.lock().unwrap();
+                       (us.id, Arc::new(us.write_avail.clone()))
+               };
+               Self { conn, id, write_avail_sender }
        }
 }
 impl peer_handler::SocketDescriptor for SocketDescriptor {
@@ -481,7 +491,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
                        let _ = us.read_waker.try_send(());
                }
                if data.is_empty() { return 0; }
-               let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&us.write_avail)) };
+               let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) };
                let mut ctx = task::Context::from_waker(&waker);
                let mut written_len = 0;
                loop {
@@ -493,10 +503,10 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
                                                        written_len += res;
                                                        if written_len == data.len() { return written_len; }
                                                },
-                                               Err(e) => return written_len,
+                                               Err(_) => return written_len,
                                        }
                                },
-                               task::Poll::Ready(Err(e)) => return written_len,
+                               task::Poll::Ready(Err(_)) => return written_len,
                                task::Poll::Pending => {
                                        // We're queued up for a write event now, but we need to make sure we also
                                        // pause read given we're now waiting on the remote end to ACK (and in
@@ -523,6 +533,7 @@ impl Clone for SocketDescriptor {
                Self {
                        conn: Arc::clone(&self.conn),
                        id: self.id,
+                       write_avail_sender: Arc::clone(&self.write_avail_sender),
                }
        }
 }
@@ -606,6 +617,10 @@ mod tests {
                fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &ChannelUpdate) {}
                fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, _msg: &OpenChannelV2) {}
                fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, _msg: &AcceptChannelV2) {}
+               fn handle_stfu(&self, _their_node_id: &PublicKey, _msg: &Stfu) {}
+               fn handle_splice(&self, _their_node_id: &PublicKey, _msg: &Splice) {}
+               fn handle_splice_ack(&self, _their_node_id: &PublicKey, _msg: &SpliceAck) {}
+               fn handle_splice_locked(&self, _their_node_id: &PublicKey, _msg: &SpliceLocked) {}
                fn handle_tx_add_input(&self, _their_node_id: &PublicKey, _msg: &TxAddInput) {}
                fn handle_tx_add_output(&self, _their_node_id: &PublicKey, _msg: &TxAddOutput) {}
                fn handle_tx_remove_input(&self, _their_node_id: &PublicKey, _msg: &TxRemoveInput) {}
@@ -631,7 +646,7 @@ mod tests {
                fn handle_error(&self, _their_node_id: &PublicKey, _msg: &ErrorMessage) {}
                fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }
                fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { InitFeatures::empty() }
-               fn get_genesis_hashes(&self) -> Option<Vec<ChainHash>> {
+               fn get_chain_hashes(&self) -> Option<Vec<ChainHash>> {
                        Some(vec![ChainHash::using_genesis_block(Network::Testnet)])
                }
        }