Avoid a `tokio::mpsc::Sender` clone for each P2P send operation
authorMatt Corallo <git@bluematt.me>
Sat, 4 Nov 2023 21:21:58 +0000 (21:21 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 9 Nov 2023 22:28:08 +0000 (22:28 +0000)
Whenever we go to send bytes to a peer, we need to construct a
waker for tokio to call back into if we need to finish sending
later. That waker needs some reference to the peer's read task to
wake it up, hidden behind a single `*const ()`. To do this, we'd
previously simply stored a `Box<tokio::mpsc::Sender>` in that
pointer, which requires a `clone` for each waker construction. This
leads to substantial malloc traffic.

Instead, here, we replace this box with an `Arc`, leaving a single
`tokio::mpsc::Sender` floating around and simply change the
refcounts whenever we construct a new waker, which we can do
without allocations.

lightning-net-tokio/src/lib.rs

index bac18b2b398cac33ebbb9523fc0a2811da914b75..4483ae74256d6c0fdddd64bfe7e01dfcc372cd35 100644 (file)
@@ -422,7 +422,11 @@ const SOCK_WAKER_VTABLE: task::RawWakerVTable =
        task::RawWakerVTable::new(clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker);
 
 fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker {
-       write_avail_to_waker(orig_ptr as *const mpsc::Sender<()>)
+       let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) };
+       let res = write_avail_to_waker(&new_waker);
+       // Don't decrement the refcount when dropping new_waker by turning it back `into_raw`.
+       let _ = Arc::into_raw(new_waker);
+       res
 }
 // When waking, an error should be fine. Most likely we got two send_datas in a row, both of which
 // failed to fully write, but we only need to call write_buffer_space_avail() once. Otherwise, the
@@ -435,16 +439,15 @@ fn wake_socket_waker(orig_ptr: *const ()) {
 }
 fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
        let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
-       let sender = unsafe { (*sender_ptr).clone() };
+       let sender = unsafe { &*sender_ptr };
        let _ = sender.try_send(());
 }
 fn drop_socket_waker(orig_ptr: *const ()) {
-       let _orig_box = unsafe { Box::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
-       // _orig_box is now dropped
+       let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
+       // _orig_arc is now dropped
 }
-fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
-       let new_box = Box::leak(Box::new(unsafe { (*sender).clone() }));
-       let new_ptr = new_box as *const mpsc::Sender<()>;
+fn write_avail_to_waker(sender: &Arc<mpsc::Sender<()>>) -> task::RawWaker {
+       let new_ptr = Arc::into_raw(Arc::clone(&sender));
        task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE)
 }
 
@@ -452,12 +455,20 @@ fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
 /// type in the template of PeerHandler.
 pub struct SocketDescriptor {
        conn: Arc<Mutex<Connection>>,
+       // We store a copy of the mpsc::Sender to wake the read task in an Arc here. While we can
+       // simply clone the sender and store a copy in each waker, that would require allocating for
+       // each waker. Instead, we can simply `Arc::clone`, creating a new reference and store the
+       // pointer in the waker.
+       write_avail_sender: Arc<mpsc::Sender<()>>,
        id: u64,
 }
 impl SocketDescriptor {
        fn new(conn: Arc<Mutex<Connection>>) -> Self {
-               let id = conn.lock().unwrap().id;
-               Self { conn, id }
+               let (id, write_avail_sender) = {
+                       let us = conn.lock().unwrap();
+                       (us.id, Arc::new(us.write_avail.clone()))
+               };
+               Self { conn, id, write_avail_sender }
        }
 }
 impl peer_handler::SocketDescriptor for SocketDescriptor {
@@ -480,7 +491,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor {
                        let _ = us.read_waker.try_send(());
                }
                if data.is_empty() { return 0; }
-               let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&us.write_avail)) };
+               let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) };
                let mut ctx = task::Context::from_waker(&waker);
                let mut written_len = 0;
                loop {
@@ -522,6 +533,7 @@ impl Clone for SocketDescriptor {
                Self {
                        conn: Arc::clone(&self.conn),
                        id: self.id,
+                       write_avail_sender: Arc::clone(&self.write_avail_sender),
                }
        }
 }