X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning-net-tokio%2Fsrc%2Flib.rs;h=71d63ecadcfa96e3b7fbd4a2c7c1076af2470e25;hb=2f734f97550014e5424e55523ed46d76b94b737d;hp=6e2ea3f14c14f54b890f8ff504e4f3fe6dd0a304;hpb=073f0780f6b2d72de57e5bb5a7b690c0206fa40c;p=rust-lightning diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index 6e2ea3f1..71d63eca 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -22,24 +22,22 @@ //! //! [`PeerManager`]: lightning::ln::peer_handler::PeerManager -// Prefix these with `rustdoc::` when we update our MSRV to be >= 1.52 to remove warnings. -#![deny(broken_intra_doc_links)] -#![deny(private_intra_doc_links)] +#![deny(rustdoc::broken_intra_doc_links)] +#![deny(rustdoc::private_intra_doc_links)] #![deny(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] use bitcoin::secp256k1::PublicKey; -use tokio::net::{tcp, TcpStream}; -use tokio::{io, time}; +use tokio::net::TcpStream; +use tokio::time; use tokio::sync::mpsc; -use tokio::io::AsyncWrite; use lightning::ln::peer_handler; use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait; use lightning::ln::peer_handler::APeerManager; -use lightning::ln::msgs::NetAddress; +use lightning::ln::msgs::SocketAddress; use std::ops::Deref; use std::task::{self, Poll}; @@ -210,7 +208,12 @@ impl Connection { break Disconnect::CloseConnection; } }, - SelectorOutput::B(_) => {}, + SelectorOutput::B(some) => { + // The mpsc Receiver should only return `None` if the write side has been + // dropped, but that shouldn't be possible since its referenced by the Self in + // `us`. + debug_assert!(some.is_some()); + }, SelectorOutput::C(res) => { if res.is_err() { break Disconnect::PeerDisconnected; } match reader.try_read(&mut buf) { @@ -231,7 +234,7 @@ impl Connection { // readable() is allowed to spuriously wake, so we have to handle // WouldBlock here. }, - Err(e) => break Disconnect::PeerDisconnected, + Err(_) => break Disconnect::PeerDisconnected, } }, } @@ -274,13 +277,13 @@ impl Connection { } } -fn get_addr_from_stream(stream: &StdTcpStream) -> Option { +fn get_addr_from_stream(stream: &StdTcpStream) -> Option { match stream.peer_addr() { - Ok(SocketAddr::V4(sockaddr)) => Some(NetAddress::IPv4 { + Ok(SocketAddr::V4(sockaddr)) => Some(SocketAddress::TcpIpV4 { addr: sockaddr.ip().octets(), port: sockaddr.port(), }), - Ok(SocketAddr::V6(sockaddr)) => Some(NetAddress::IPv6 { + Ok(SocketAddr::V6(sockaddr)) => Some(SocketAddress::TcpIpV6 { addr: sockaddr.ip().octets(), port: sockaddr.port(), }), @@ -423,7 +426,11 @@ const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker); fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker { - write_avail_to_waker(orig_ptr as *const mpsc::Sender<()>) + let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) }; + let res = write_avail_to_waker(&new_waker); + // Don't decrement the refcount when dropping new_waker by turning it back `into_raw`. + let _ = Arc::into_raw(new_waker); + res } // When waking, an error should be fine. Most likely we got two send_datas in a row, both of which // failed to fully write, but we only need to call write_buffer_space_avail() once. Otherwise, the @@ -436,16 +443,15 @@ fn wake_socket_waker(orig_ptr: *const ()) { } fn wake_socket_waker_by_ref(orig_ptr: *const ()) { let sender_ptr = orig_ptr as *const mpsc::Sender<()>; - let sender = unsafe { (*sender_ptr).clone() }; + let sender = unsafe { &*sender_ptr }; let _ = sender.try_send(()); } fn drop_socket_waker(orig_ptr: *const ()) { - let _orig_box = unsafe { Box::from_raw(orig_ptr as *mut mpsc::Sender<()>) }; - // _orig_box is now dropped + let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) }; + // _orig_arc is now dropped } -fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker { - let new_box = Box::leak(Box::new(unsafe { (*sender).clone() })); - let new_ptr = new_box as *const mpsc::Sender<()>; +fn write_avail_to_waker(sender: &Arc>) -> task::RawWaker { + let new_ptr = Arc::into_raw(Arc::clone(&sender)); task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE) } @@ -453,12 +459,20 @@ fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker { /// type in the template of PeerHandler. pub struct SocketDescriptor { conn: Arc>, + // We store a copy of the mpsc::Sender to wake the read task in an Arc here. While we can + // simply clone the sender and store a copy in each waker, that would require allocating for + // each waker. Instead, we can simply `Arc::clone`, creating a new reference and store the + // pointer in the waker. + write_avail_sender: Arc>, id: u64, } impl SocketDescriptor { fn new(conn: Arc>) -> Self { - let id = conn.lock().unwrap().id; - Self { conn, id } + let (id, write_avail_sender) = { + let us = conn.lock().unwrap(); + (us.id, Arc::new(us.write_avail.clone())) + }; + Self { conn, id, write_avail_sender } } } impl peer_handler::SocketDescriptor for SocketDescriptor { @@ -481,7 +495,7 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { let _ = us.read_waker.try_send(()); } if data.is_empty() { return 0; } - let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&us.write_avail)) }; + let waker = unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) }; let mut ctx = task::Context::from_waker(&waker); let mut written_len = 0; loop { @@ -493,10 +507,13 @@ impl peer_handler::SocketDescriptor for SocketDescriptor { written_len += res; if written_len == data.len() { return written_len; } }, - Err(e) => return written_len, + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + continue; + } + Err(_) => return written_len, } }, - task::Poll::Ready(Err(e)) => return written_len, + task::Poll::Ready(Err(_)) => return written_len, task::Poll::Pending => { // We're queued up for a write event now, but we need to make sure we also // pause read given we're now waiting on the remote end to ACK (and in @@ -523,6 +540,7 @@ impl Clone for SocketDescriptor { Self { conn: Arc::clone(&self.conn), id: self.id, + write_avail_sender: Arc::clone(&self.write_avail_sender), } } } @@ -543,7 +561,6 @@ mod tests { use lightning::ln::features::*; use lightning::ln::msgs::*; use lightning::ln::peer_handler::{MessageHandler, PeerManager}; - use lightning::ln::features::NodeFeatures; use lightning::routing::gossip::NodeId; use lightning::events::*; use lightning::util::test_utils::TestNodeSigner; @@ -560,7 +577,7 @@ mod tests { pub struct TestLogger(); impl lightning::util::logger::Logger for TestLogger { - fn log(&self, record: &lightning::util::logger::Record) { + fn log(&self, record: lightning::util::logger::Record) { println!("{:<5} [{} : {}, {}] {}", record.level.to_string(), record.module_path, record.file, record.line, record.args); } } @@ -606,6 +623,10 @@ mod tests { fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &ChannelUpdate) {} fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, _msg: &OpenChannelV2) {} fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, _msg: &AcceptChannelV2) {} + fn handle_stfu(&self, _their_node_id: &PublicKey, _msg: &Stfu) {} + fn handle_splice(&self, _their_node_id: &PublicKey, _msg: &Splice) {} + fn handle_splice_ack(&self, _their_node_id: &PublicKey, _msg: &SpliceAck) {} + fn handle_splice_locked(&self, _their_node_id: &PublicKey, _msg: &SpliceLocked) {} fn handle_tx_add_input(&self, _their_node_id: &PublicKey, _msg: &TxAddInput) {} fn handle_tx_add_output(&self, _their_node_id: &PublicKey, _msg: &TxAddOutput) {} fn handle_tx_remove_input(&self, _their_node_id: &PublicKey, _msg: &TxRemoveInput) {} @@ -631,7 +652,7 @@ mod tests { fn handle_error(&self, _their_node_id: &PublicKey, _msg: &ErrorMessage) {} fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() } fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures { InitFeatures::empty() } - fn get_genesis_hashes(&self) -> Option> { + fn get_chain_hashes(&self) -> Option> { Some(vec![ChainHash::using_genesis_block(Network::Testnet)]) } }