This requires ensuring TcpStreams are set to nonblocking mode, as tokio
no longer handles this for us, so we adapt the public API to accept std
TcpStreams directly rather than requiring an extra conversion hop.
Luckily, converting them is cheap.
[dependencies]
bitcoin = "0.24"
lightning = { version = "0.0.12", path = "../lightning" }
-tokio = { version = ">=0.2.12", features = [ "io-util", "macros", "rt-core", "sync", "tcp", "time" ] }
+tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "sync", "net", "time" ] }
-tokio = { version = ">=0.2.12", features = [ "io-util", "macros", "rt-core", "rt-threaded", "sync", "tcp", "time" ] }
+tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "rt-multi-thread", "sync", "net", "time" ] }
//! The call site should, thus, look something like this:
//! ```
//! use tokio::sync::mpsc;
-//! use tokio::net::TcpStream;
+//! use std::net::TcpStream;
//! use bitcoin::secp256k1::key::PublicKey;
//! use lightning::util::events::EventsProvider;
//! use std::net::SocketAddr;
use std::{task, thread};
use std::net::SocketAddr;
+use std::net::TcpStream as StdTcpStream;
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
- fn new(event_notify: mpsc::Sender<()>, stream: TcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
+ fn new(event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
// We only ever need a channel of depth 1 here: if we returned a non-full write to the
// PeerManager, we will eventually get notified that there is room in the socket to write
// new bytes, which will generate an event. That event will be popped off the queue before
// we shove a value into the channel which comes after we've reset the read_paused bool to
// false.
let (read_waker, read_receiver) = mpsc::channel(1);
- let (reader, writer) = io::split(stream);
+ stream.set_nonblocking(true).unwrap();
+ let (reader, writer) = io::split(TcpStream::from_std(stream).unwrap());
(reader, write_receiver, read_receiver,
Arc::new(Mutex::new(Self {
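The depth-1 channel thus behaves like an edge-triggered wake flag: redundant notifications simply coalesce. A standalone sketch of the pattern, assuming tokio 1.x (names are illustrative, not from this crate):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
	// Depth 1: at most one wakeup token can ever be queued.
	let (waker, mut wake_rx) = mpsc::channel::<()>(1);

	// A second notification fails with TrySendError::Full, which is safe to
	// ignore: the queued token already guarantees the consumer will wake.
	assert!(waker.try_send(()).is_ok());
	assert!(waker.try_send(()).is_err());

	// The consumer wakes exactly once, then re-checks shared state
	// (e.g. the read_paused bool) before going back to sleep.
	wake_rx.recv().await.unwrap();
}
```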
/// not need to poll the provided future in order to make progress.
///
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
/// not need to poll the provided future in order to make progress.
///
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
- if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), TcpStream::connect(&addr)).await {
+ if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
Some(setup_outbound(peer_manager, event_notify, their_node_id, stream))
} else { None }
}
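Outbound connects keep using tokio's async `TcpStream::connect`, then immediately convert the connected stream back to std for the new std-based API; `into_std` is essentially free since it just extracts the underlying socket. A minimal sketch of that dance, assuming tokio 1.x (the helper is hypothetical):

```rust
use std::time::Duration;
use tokio::{net::TcpStream, time};

// Connect asynchronously with a 10s cap, then hand back a (still nonblocking)
// std stream suitable for the std-based setup_* API.
async fn connect_std(addr: std::net::SocketAddr) -> Option<std::net::TcpStream> {
	match time::timeout(Duration::from_secs(10), TcpStream::connect(addr)).await {
		Ok(Ok(stream)) => stream.into_std().ok(),
		_ => None,
	}
}
```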
}
fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
- let mut sender = unsafe { (*sender_ptr).clone() };
+ let sender = unsafe { (*sender_ptr).clone() };
let _ = sender.try_send(());
}
fn drop_socket_waker(orig_ptr: *const ()) {
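The dropped `mut` above reflects an API change: in tokio 1.x, `mpsc::Sender::try_send` takes `&self` rather than `&mut self`, so the cloned sender no longer needs a mutable binding. A tiny illustration (not crate code):

```rust
fn demo() {
	// Sender::try_send takes &self on tokio 1.x, so no `mut tx` is required.
	let (tx, _rx) = tokio::sync::mpsc::channel::<()>(1);
	let _ = tx.try_send(());
}
```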
use tokio::sync::mpsc;
use std::mem;
+ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
expected_pubkey: PublicKey,
pubkey_connected: mpsc::Sender<()>,
pubkey_disconnected: mpsc::Sender<()>,
+ disconnected_flag: AtomicBool,
msg_events: Mutex<Vec<MessageSendEvent>>,
}
impl RoutingMessageHandler for MsgHandler {
fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {}
fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
if *their_node_id == self.expected_pubkey {
+ self.disconnected_flag.store(true, Ordering::SeqCst);
self.pubkey_disconnected.clone().try_send(()).unwrap();
}
}
expected_pubkey: b_pub,
pubkey_connected: a_connected_sender,
pubkey_disconnected: a_disconnected_sender,
+ disconnected_flag: AtomicBool::new(false),
msg_events: Mutex::new(Vec::new()),
});
let a_manager = Arc::new(PeerManager::new(MessageHandler {
expected_pubkey: a_pub,
pubkey_connected: b_connected_sender,
pubkey_disconnected: b_disconnected_sender,
+ disconnected_flag: AtomicBool::new(false),
msg_events: Mutex::new(Vec::new()),
});
let b_manager = Arc::new(PeerManager::new(MessageHandler {
} else { panic!("Failed to bind to v4 localhost on common ports"); };
let (sender, _receiver) = mpsc::channel(2);
- let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, tokio::net::TcpStream::from_std(conn_a).unwrap());
- let fut_b = super::setup_inbound(b_manager, sender, tokio::net::TcpStream::from_std(conn_b).unwrap());
+ let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, conn_a);
+ let fut_b = super::setup_inbound(b_manager, sender, conn_b);
tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();
a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
node_id: b_pub, action: ErrorAction::DisconnectPeer { msg: None }
});
- assert!(a_disconnected.try_recv().is_err());
- assert!(b_disconnected.try_recv().is_err());
+ assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
+ assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));
a_manager.process_events();
tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
+ assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
+ assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));
fut_a.await;
fut_b.await;
}
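The test moves its "not yet disconnected" checks from `Receiver::try_recv` to an `AtomicBool`, presumably because tokio 1.0 removed `try_recv` from `mpsc::Receiver` (it was reintroduced in a later 1.x release). A flag gives an equivalent non-blocking check; a minimal sketch (names are illustrative):

```rust
use std::sync::atomic::{AtomicBool, Ordering};

fn demo() {
	let disconnected = AtomicBool::new(false);
	assert!(!disconnected.load(Ordering::SeqCst)); // nothing has fired yet
	disconnected.store(true, Ordering::SeqCst);    // set from the event handler
	assert!(disconnected.load(Ordering::SeqCst));  // observable without a channel
}
```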
- #[tokio::test(threaded_scheduler)]
+ #[tokio::test(flavor = "multi_thread")]
async fn basic_threaded_connection_test() {
do_basic_connection_test().await;
}