//! The call site should thus look something like this:
//! ```
//! use tokio::sync::mpsc;
-//! use tokio::net::TcpStream;
+//! use std::net::TcpStream;
//! use bitcoin::secp256k1::key::PublicKey;
//! use lightning::util::events::EventsProvider;
//! use std::net::SocketAddr;
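//!
//! // A rough sketch of the remainder (construction of `peer_manager` and
//! // `channel_manager` is elided; `their_node_id` and `addr` are assumed to
//! // be in scope) showing an outbound connection and the event-drain loop:
//! //
//! // let (event_notify, mut event_receiver) = mpsc::channel(2);
//! // let stream = TcpStream::connect(addr).unwrap();
//! // tokio::spawn(lightning_net_tokio::setup_outbound(peer_manager, event_notify, their_node_id, stream));
//! // while event_receiver.recv().await.is_some() {
//! //     for _event in channel_manager.get_and_clear_pending_events().drain(..) {
//! //         // Handle the event
//! //     }
//! // }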
use std::{task, thread};
use std::net::SocketAddr;
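+// The public API now takes std::net::TcpStreams (aliased below); tokio's TcpStream remains in use internally for the actual I/O.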
+use std::net::TcpStream as StdTcpStream;
use std::sync::{Arc, Mutex, MutexGuard};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
}
}
- fn new(event_notify: mpsc::Sender<()>, stream: TcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
+ fn new(event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> (io::ReadHalf<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
// We only ever need a channel of depth 1 here: if we returned a non-full write to the
// PeerManager, we will eventually get notified that there is room in the socket to write
// new bytes, which will generate an event. That event will be popped off the queue before
// we shove a value into the channel, which itself only happens after we've reset the
// read_paused bool to false.
let (read_waker, read_receiver) = mpsc::channel(1);
- let (reader, writer) = io::split(stream);
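+ // tokio's TcpStream::from_std requires the socket to already be in nonblocking mode.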
+ stream.set_nonblocking(true).unwrap();
+ let (reader, writer) = io::split(TcpStream::from_std(stream).unwrap());
(reader, write_receiver, read_receiver,
Arc::new(Mutex::new(Self {
/// not need to poll the provided future in order to make progress.
///
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_inbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
/// not need to poll the provided future in order to make progress.
///
/// See the module-level documentation for how to handle the event_notify mpsc::Sender.
-pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) -> impl std::future::Future<Output=()> where
+pub fn setup_outbound<CMH, RMH, L>(peer_manager: Arc<peer_handler::PeerManager<SocketDescriptor, Arc<CMH>, Arc<RMH>, Arc<L>>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: StdTcpStream) -> impl std::future::Future<Output=()> where
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
CMH: ChannelMessageHandler + 'static,
RMH: RoutingMessageHandler + 'static,
L: Logger + 'static + ?Sized {
- if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), TcpStream::connect(&addr)).await {
+ if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) }).await {
Some(setup_outbound(peer_manager, event_notify, their_node_id, stream))
} else { None }
}
}
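// A minimal usage sketch (not part of this module; `peer_manager` and
// `event_notify` are assumed to exist with the types above): accept inbound
// connections, convert each tokio stream to std, and hand it to setup_inbound,
// mirroring how connect_outbound feeds setup_outbound with a std stream.
//
//     let listener = tokio::net::TcpListener::bind("0.0.0.0:9735").await.unwrap();
//     loop {
//         let (tokio_stream, _addr) = listener.accept().await.unwrap();
//         tokio::spawn(setup_inbound(Arc::clone(&peer_manager), event_notify.clone(),
//             tokio_stream.into_std().unwrap()));
//     }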
fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
- let mut sender = unsafe { (*sender_ptr).clone() };
+ let sender = unsafe { (*sender_ptr).clone() };
let _ = sender.try_send(());
}
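// These functions are intended to slot into a std::task::RawWakerVTable so a
// hand-built Waker can signal write-availability over the mpsc channel. A
// sketch of the assembly (the constant and helper names here are assumed):
//
//     const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(
//         clone_socket_waker, wake_socket_waker, wake_socket_waker_by_ref, drop_socket_waker);
//
//     fn write_avail_to_waker(sender: *const mpsc::Sender<()>) -> task::RawWaker {
//         task::RawWaker::new(sender as *const (), &SOCK_WAKER_VTABLE)
//     }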
fn drop_socket_waker(orig_ptr: *const ()) {
use tokio::sync::mpsc;
use std::mem;
+ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
expected_pubkey: PublicKey,
pubkey_connected: mpsc::Sender<()>,
pubkey_disconnected: mpsc::Sender<()>,
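+ // Set when peer_disconnected fires for the expected peer; lets the test assert
+ // "no disconnect yet" without mpsc::Receiver::try_recv (removed in tokio 1.0).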
+ disconnected_flag: AtomicBool,
msg_events: Mutex<Vec<MessageSendEvent>>,
}
impl ChannelMessageHandler for MsgHandler
fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {}
fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
if *their_node_id == self.expected_pubkey {
+ self.disconnected_flag.store(true, Ordering::SeqCst);
self.pubkey_disconnected.clone().try_send(()).unwrap();
}
}
expected_pubkey: b_pub,
pubkey_connected: a_connected_sender,
pubkey_disconnected: a_disconnected_sender,
+ disconnected_flag: AtomicBool::new(false),
msg_events: Mutex::new(Vec::new()),
});
let a_manager = Arc::new(PeerManager::new(MessageHandler {
expected_pubkey: a_pub,
pubkey_connected: b_connected_sender,
pubkey_disconnected: b_disconnected_sender,
+ disconnected_flag: AtomicBool::new(false),
msg_events: Mutex::new(Vec::new()),
});
let b_manager = Arc::new(PeerManager::new(MessageHandler {
} else { panic!("Failed to bind to v4 localhost on common ports"); };
let (sender, _receiver) = mpsc::channel(2);
- let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, tokio::net::TcpStream::from_std(conn_a).unwrap());
- let fut_b = super::setup_inbound(b_manager, sender, tokio::net::TcpStream::from_std(conn_b).unwrap());
+ let fut_a = super::setup_outbound(Arc::clone(&a_manager), sender.clone(), b_pub, conn_a);
+ let fut_b = super::setup_inbound(b_manager, sender, conn_b);
tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();
a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
node_id: b_pub, action: ErrorAction::DisconnectPeer { msg: None }
});
- assert!(a_disconnected.try_recv().is_err());
- assert!(b_disconnected.try_recv().is_err());
+ assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
+ assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));
a_manager.process_events();
tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
+ assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
+ assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));
fut_a.await;
fut_b.await;
}
- #[tokio::test(threaded_scheduler)]
+ #[tokio::test(flavor = "multi_thread")]
async fn basic_threaded_connection_test() {
do_basic_connection_test().await;
}