From cc7f859c01570c300bb53b1ebf924d6c6c307141 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 12 Apr 2022 19:16:38 +0000 Subject: [PATCH] Add support for testing recvd messages in TestChannelMessageHandler --- lightning/src/util/test_utils.rs | 109 +++++++++++++++++++++++++------ 1 file changed, 89 insertions(+), 20 deletions(-) diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index f6872430..0aa39e8a 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -18,7 +18,7 @@ use chain::channelmonitor::MonitorEvent; use chain::transaction::OutPoint; use chain::keysinterface; use ln::features::{ChannelFeatures, InitFeatures}; -use ln::msgs; +use ln::{msgs, wire}; use ln::msgs::OptionalField; use ln::script::ShutdownScript; use routing::scoring::FixedPenaltyScorer; @@ -249,37 +249,106 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster { pub struct TestChannelMessageHandler { pub pending_events: Mutex>, + expected_recv_msgs: Mutex>>>, } impl TestChannelMessageHandler { pub fn new() -> Self { TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), + expected_recv_msgs: Mutex::new(None), + } + } + + #[cfg(test)] + pub(crate) fn expect_receive_msg(&self, ev: wire::Message<()>) { + let mut expected_msgs = self.expected_recv_msgs.lock().unwrap(); + if expected_msgs.is_none() { *expected_msgs = Some(Vec::new()); } + expected_msgs.as_mut().unwrap().push(ev); + } + + fn received_msg(&self, ev: wire::Message<()>) { + let mut msgs = self.expected_recv_msgs.lock().unwrap(); + if msgs.is_none() { return; } + assert!(!msgs.as_ref().unwrap().is_empty(), "Received message when we weren't expecting one"); + #[cfg(test)] + assert_eq!(msgs.as_ref().unwrap()[0], ev); + msgs.as_mut().unwrap().remove(0); + } +} + +impl Drop for TestChannelMessageHandler { + fn drop(&mut self) { + let l = self.expected_recv_msgs.lock().unwrap(); + #[cfg(feature = "std")] + { + if !std::thread::panicking() { + assert!(l.is_none() || l.as_ref().unwrap().is_empty()); + } } } } impl msgs::ChannelMessageHandler for TestChannelMessageHandler { - fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::OpenChannel) {} - fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::AcceptChannel) {} - fn handle_funding_created(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingCreated) {} - fn handle_funding_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingSigned) {} - fn handle_funding_locked(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingLocked) {} - fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, _msg: &msgs::Shutdown) {} - fn handle_closing_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::ClosingSigned) {} - fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateAddHTLC) {} - fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFulfillHTLC) {} - fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailHTLC) {} - fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailMalformedHTLC) {} - fn handle_commitment_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::CommitmentSigned) {} - fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, _msg: &msgs::RevokeAndACK) {} - fn handle_update_fee(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFee) {} - fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) {} - fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &msgs::AnnouncementSignatures) {} - fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelReestablish) {} + fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::OpenChannel) { + self.received_msg(wire::Message::OpenChannel(msg.clone())); + } + fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::AcceptChannel) { + self.received_msg(wire::Message::AcceptChannel(msg.clone())); + } + fn handle_funding_created(&self, _their_node_id: &PublicKey, msg: &msgs::FundingCreated) { + self.received_msg(wire::Message::FundingCreated(msg.clone())); + } + fn handle_funding_signed(&self, _their_node_id: &PublicKey, msg: &msgs::FundingSigned) { + self.received_msg(wire::Message::FundingSigned(msg.clone())); + } + fn handle_funding_locked(&self, _their_node_id: &PublicKey, msg: &msgs::FundingLocked) { + self.received_msg(wire::Message::FundingLocked(msg.clone())); + } + fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, msg: &msgs::Shutdown) { + self.received_msg(wire::Message::Shutdown(msg.clone())); + } + fn handle_closing_signed(&self, _their_node_id: &PublicKey, msg: &msgs::ClosingSigned) { + self.received_msg(wire::Message::ClosingSigned(msg.clone())); + } + fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateAddHTLC) { + self.received_msg(wire::Message::UpdateAddHTLC(msg.clone())); + } + fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) { + self.received_msg(wire::Message::UpdateFulfillHTLC(msg.clone())); + } + fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) { + self.received_msg(wire::Message::UpdateFailHTLC(msg.clone())); + } + fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) { + self.received_msg(wire::Message::UpdateFailMalformedHTLC(msg.clone())); + } + fn handle_commitment_signed(&self, _their_node_id: &PublicKey, msg: &msgs::CommitmentSigned) { + self.received_msg(wire::Message::CommitmentSigned(msg.clone())); + } + fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, msg: &msgs::RevokeAndACK) { + self.received_msg(wire::Message::RevokeAndACK(msg.clone())); + } + fn handle_update_fee(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFee) { + self.received_msg(wire::Message::UpdateFee(msg.clone())); + } + fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) { + // Don't call `received_msg` here as `TestRoutingMessageHandler` generates these sometimes + } + fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) { + self.received_msg(wire::Message::AnnouncementSignatures(msg.clone())); + } + fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReestablish) { + self.received_msg(wire::Message::ChannelReestablish(msg.clone())); + } fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {} - fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) {} - fn handle_error(&self, _their_node_id: &PublicKey, _msg: &msgs::ErrorMessage) {} + fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) { + // Don't bother with `received_msg` for Init as its auto-generated and we don't want to + // bother re-generating the expected Init message in all tests. + } + fn handle_error(&self, _their_node_id: &PublicKey, msg: &msgs::ErrorMessage) { + self.received_msg(wire::Message::Error(msg.clone())); + } } impl events::MessageSendEventsProvider for TestChannelMessageHandler { -- 2.30.2