X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning-net-tokio%2Fsrc%2Flib.rs;h=be41a2401244f0bff743e44b9e7ec26482d95227;hb=baf089728df765961966c8ec08b425a19a8806be;hp=bac18b2b398cac33ebbb9523fc0a2811da914b75;hpb=54f96ef944423eac98d302fbc7cdcdc136d58312;p=rust-lightning diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index bac18b2b3..be41a2401 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -22,9 +22,8 @@ //! //! [`PeerManager`]: lightning::ln::peer_handler::PeerManager -// Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. -#![deny(broken_intra_doc_links)] -#![deny(private_intra_doc_links)] +#![deny(rustdoc::broken_intra_doc_links)] +#![deny(rustdoc::private_intra_doc_links)] #![deny(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] @@ -422,7 +421,11 @@ const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker); fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker { - write_avail_to_waker(orig_ptr as *const mpsc::Sender<()>) + let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) }; + let res = write_avail_to_waker(&new_waker); + // Don't decrement the refcount when dropping new_waker by turning it back `into_raw`. + let _ = Arc::into_raw(new_waker); + res } // When waking, an error should be fine. Most likely we got two send_datas in a row, both of which // failed to fully write, but we only need to call write_buffer_space_avail() once. Otherwise, the @@ -435,16 +438,15 @@ fn wake_socket_waker(orig_ptr: *const ()) { } fn wake_socket_waker_by_ref(orig_ptr: *const ()) { let sender_ptr = orig_ptr as *const mpsc::Sender<()>; - let sender = unsafe { (*sender_ptr).clone() }; + let sender = unsafe { &*sender_ptr }; let _ = sender.try_send(()); } fn drop_socket_waker(orig_ptr: *const ()) { - let _orig_box = unsafe { Box::from_raw(orig_ptr as *mut mpsc::Sender<()>) }; - // _orig_box is now dropped + let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) }; + // _orig_arc is now dropped } -fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker { - let new_box = Box::leak(Box::new(unsafe { (*sender).clone() })); - let new_ptr = new_box as *const mpsc::Sender<()>; +fn write_avail_to_waker(sender: &Arc>) -> task::RawWaker { + let new_ptr = Arc::into_raw(Arc::clone(&sender)); task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE) } @@ -452,12 +454,20 @@ fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker { /// type in the template of PeerHandler. pub struct SocketDescriptor { conn: Arc>, + // We store a copy of the mpsc::Sender to wake the read task in an Arc here. While we can + // simply clone the sender and store a copy in each waker, that would require allocating for + // each waker. Instead, we can simply `Arc::clone`, creating a new reference and store the + // pointer in the waker. + write_avail_sender: Arc>, id: u64, } impl SocketDescriptor { fn new(conn: Arc>) -> Self { - let id = conn.lock().unwrap().id; - Self { conn, id } + let (id, write_avail_sender) = { + let us = conn.lock().unwrap(); + (us.id, Arc::new(us.write_avail.clone())) + }; + Self { conn, id, write_avail_sender } } } impl peer_handler::SocketDescriptor for SocketDescriptor { @@ -480,7 +490,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { let _ = us.read_waker.try_send(()); } if data.is_empty() { return 0; } - let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&us.write_avail)) }; + let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) }; let mut ctx = task::Context::from_waker(&waker); let mut written_len = 0; loop { @@ -492,6 +502,9 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { written_len += res; if written_len == data.len() { return written_len; } }, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } Err(_) => return written_len, } }, @@ -522,6 +535,7 @@ impl Clone for SocketDescriptor { Self { conn: Arc::clone(&self.conn), id: self.id, + write_avail_sender: Arc::clone(&self.write_avail_sender), } } } @@ -559,7 +573,7 @@ mod tests { pub struct TestLogger(); impl lightning::util::logger::Logger for TestLogger { - fn log(&self, record: &lightning::util::logger::Record) { + fn log(&self, record: lightning::util::logger::Record) { println!("{:<5} [{} : {}, {}] {}", record.level.to_string(), record.module_path, record.file, record.line, record.args); } } @@ -605,6 +619,10 @@ mod tests { fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &ChannelUpdate) {} fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, _msg: &OpenChannelV2) {} fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, _msg: &AcceptChannelV2) {} + fn handle_stfu(&self, _their_node_id: &PublicKey, _msg: &Stfu) {} + fn handle_splice(&self, _their_node_id: &PublicKey, _msg: &Splice) {} + fn handle_splice_ack(&self, _their_node_id: &PublicKey, _msg: &SpliceAck) {} + fn handle_splice_locked(&self, _their_node_id: &PublicKey, _msg: &SpliceLocked) {} fn handle_tx_add_input(&self, _their_node_id: &PublicKey, _msg: &TxAddInput) {} fn handle_tx_add_output(&self, _their_node_id: &PublicKey, _msg: &TxAddOutput) {} fn handle_tx_remove_input(&self, _their_node_id: &PublicKey, _msg: &TxRemoveInput) {}