Merge pull request #2544 from optout21/splicing-msgs0
[rust-lightning] / lightning-net-tokio / src / lib.rs
index c67aa668d11da5af727ebde5a87ada216a4fa2c3..d4f75dd6cd8b073d0e29de1ecf76685b8122ab6e 100644 (file)
@@ -422,7 +422,11 @@ const SOCK_WAKER_VTABLE: task::RawWakerVTable =
        task::RawWakerVTable::new(clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker);
 
 fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker {
-       write_avail_to_waker(orig_ptr as *const mpsc::Sender<()>)
+       let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) };
+       let res = write_avail_to_waker(&new_waker);
+       // Don't decrement the refcount when dropping new_waker by turning it back `into_raw`.
+       let _ = Arc::into_raw(new_waker);
+       res
 }
 // When waking, an error should be fine. Most likely we got two send_datas in a row, both of which
 // failed to fully write, but we only need to call write_buffer_space_avail() once. Otherwise, the
@@ -435,16 +439,15 @@ fn wake_socket_waker(orig_ptr: *const ()) {
 }
 fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
        let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
-       let sender = unsafe { (*sender_ptr).clone() };
+       let sender = unsafe { &*sender_ptr };
        let _ = sender.try_send(());
 }
 fn drop_socket_waker(orig_ptr: *const ()) {
-       let _orig_box = unsafe { Box::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
-       // _orig_box is now dropped
+       let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
+       // _orig_arc is now dropped
 }
-fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
-       let new_box = Box::leak(Box::new(unsafe { (*sender).clone() }));
-       let new_ptr = new_box as *const mpsc::Sender<()>;
+fn write_avail_to_waker(sender: &Arc<mpsc::Sender<()>>) -> task::RawWaker {
+       let new_ptr = Arc::into_raw(Arc::clone(&sender));
        task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE)
 }
 
@@ -452,12 +455,20 @@ fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
 /// type in the template of PeerHandler.
 pub struct SocketDescriptor {
        conn: Arc<Mutex<Connection>>,
+       // We store a copy of the mpsc::Sender to wake the read task in an Arc here. While we can
+       // simply clone the sender and store a copy in each waker, that would require allocating for
+       // each waker. Instead, we can simply `Arc::clone`, creating a new reference and store the
+       // pointer in the waker.
+       write_avail_sender: Arc<mpsc::Sender<()>>,
        id: u64,
 }
 impl SocketDescriptor {
        fn new(conn: Arc<Mutex<Connection>>) -> Self {
-               let id = conn.lock().unwrap().id;
-               Self { conn, id }
+               let (id, write_avail_sender) = {
+                       let us = conn.lock().unwrap();
+                       (us.id, Arc::new(us.write_avail.clone()))
+               };
+               Self { conn, id, write_avail_sender }
        }
 }
 impl peer_handler::SocketDescriptor for SocketDescriptor {
@@ -480,7 +491,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
                        let _ = us.read_waker.try_send(());
                }
                if data.is_empty() { return 0; }
-               let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&us.write_avail)) };
+               let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) };
                let mut ctx = task::Context::from_waker(&waker);
                let mut written_len = 0;
                loop {
@@ -522,6 +533,7 @@ impl Clone for SocketDescriptor {
                Self {
                        conn: Arc::clone(&self.conn),
                        id: self.id,
+                       write_avail_sender: Arc::clone(&self.write_avail_sender),
                }
        }
 }