Add support for testing recvd messages in TestChannelMessageHandler
[rust-lightning] / lightning / src / util / test_utils.rs
index 3c36cdf066a5048bbb1b6fe112c2906ac824e6cc..0aa39e8a979ee90c639a6ee97c6f6afb2d1fb7a5 100644 (file)
@@ -18,7 +18,7 @@ use chain::channelmonitor::MonitorEvent;
 use chain::transaction::OutPoint;
 use chain::keysinterface;
 use ln::features::{ChannelFeatures, InitFeatures};
-use ln::msgs;
+use ln::{msgs, wire};
 use ln::msgs::OptionalField;
 use ln::script::ShutdownScript;
 use routing::scoring::FixedPenaltyScorer;
@@ -49,6 +49,9 @@ use core::{cmp, mem};
 use bitcoin::bech32::u5;
 use chain::keysinterface::{InMemorySigner, Recipient, KeyMaterial};
 
+#[cfg(feature = "std")]
+use std::time::{SystemTime, UNIX_EPOCH};
+
 pub struct TestVecWriter(pub Vec<u8>);
 impl Writer for TestVecWriter {
        fn write_all(&mut self, buf: &[u8]) -> Result<(), io::Error> {
@@ -246,37 +249,106 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster {
 
 pub struct TestChannelMessageHandler {
        pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
+       expected_recv_msgs: Mutex<Option<Vec<wire::Message<()>>>>,
 }
 
 impl TestChannelMessageHandler {
        pub fn new() -> Self {
                TestChannelMessageHandler {
                        pending_events: Mutex::new(Vec::new()),
+                       expected_recv_msgs: Mutex::new(None),
+               }
+       }
+
+       #[cfg(test)]
+       pub(crate) fn expect_receive_msg(&self, ev: wire::Message<()>) {
+               let mut expected_msgs = self.expected_recv_msgs.lock().unwrap();
+               if expected_msgs.is_none() { *expected_msgs = Some(Vec::new()); }
+               expected_msgs.as_mut().unwrap().push(ev);
+       }
+
+       fn received_msg(&self, ev: wire::Message<()>) {
+               let mut msgs = self.expected_recv_msgs.lock().unwrap();
+               if msgs.is_none() { return; }
+               assert!(!msgs.as_ref().unwrap().is_empty(), "Received message when we weren't expecting one");
+               #[cfg(test)]
+               assert_eq!(msgs.as_ref().unwrap()[0], ev);
+               msgs.as_mut().unwrap().remove(0);
+       }
+}
+
+impl Drop for TestChannelMessageHandler {
+       fn drop(&mut self) {
+               let l = self.expected_recv_msgs.lock().unwrap();
+               #[cfg(feature = "std")]
+               {
+                       if !std::thread::panicking() {
+                               assert!(l.is_none() || l.as_ref().unwrap().is_empty());
+                       }
                }
        }
 }
 
 impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
-       fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::OpenChannel) {}
-       fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::AcceptChannel) {}
-       fn handle_funding_created(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingCreated) {}
-       fn handle_funding_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingSigned) {}
-       fn handle_funding_locked(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingLocked) {}
-       fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, _msg: &msgs::Shutdown) {}
-       fn handle_closing_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::ClosingSigned) {}
-       fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateAddHTLC) {}
-       fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFulfillHTLC) {}
-       fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailHTLC) {}
-       fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailMalformedHTLC) {}
-       fn handle_commitment_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::CommitmentSigned) {}
-       fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, _msg: &msgs::RevokeAndACK) {}
-       fn handle_update_fee(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFee) {}
-       fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) {}
-       fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &msgs::AnnouncementSignatures) {}
-       fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelReestablish) {}
+       fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::OpenChannel) {
+               self.received_msg(wire::Message::OpenChannel(msg.clone()));
+       }
+       fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::AcceptChannel) {
+               self.received_msg(wire::Message::AcceptChannel(msg.clone()));
+       }
+       fn handle_funding_created(&self, _their_node_id: &PublicKey, msg: &msgs::FundingCreated) {
+               self.received_msg(wire::Message::FundingCreated(msg.clone()));
+       }
+       fn handle_funding_signed(&self, _their_node_id: &PublicKey, msg: &msgs::FundingSigned) {
+               self.received_msg(wire::Message::FundingSigned(msg.clone()));
+       }
+       fn handle_funding_locked(&self, _their_node_id: &PublicKey, msg: &msgs::FundingLocked) {
+               self.received_msg(wire::Message::FundingLocked(msg.clone()));
+       }
+       fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, msg: &msgs::Shutdown) {
+               self.received_msg(wire::Message::Shutdown(msg.clone()));
+       }
+       fn handle_closing_signed(&self, _their_node_id: &PublicKey, msg: &msgs::ClosingSigned) {
+               self.received_msg(wire::Message::ClosingSigned(msg.clone()));
+       }
+       fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateAddHTLC) {
+               self.received_msg(wire::Message::UpdateAddHTLC(msg.clone()));
+       }
+       fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) {
+               self.received_msg(wire::Message::UpdateFulfillHTLC(msg.clone()));
+       }
+       fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) {
+               self.received_msg(wire::Message::UpdateFailHTLC(msg.clone()));
+       }
+       fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) {
+               self.received_msg(wire::Message::UpdateFailMalformedHTLC(msg.clone()));
+       }
+       fn handle_commitment_signed(&self, _their_node_id: &PublicKey, msg: &msgs::CommitmentSigned) {
+               self.received_msg(wire::Message::CommitmentSigned(msg.clone()));
+       }
+       fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, msg: &msgs::RevokeAndACK) {
+               self.received_msg(wire::Message::RevokeAndACK(msg.clone()));
+       }
+       fn handle_update_fee(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFee) {
+               self.received_msg(wire::Message::UpdateFee(msg.clone()));
+       }
+       fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) {
+               // Don't call `received_msg` here as `TestRoutingMessageHandler` generates these sometimes
+       }
+       fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) {
+               self.received_msg(wire::Message::AnnouncementSignatures(msg.clone()));
+       }
+       fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReestablish) {
+               self.received_msg(wire::Message::ChannelReestablish(msg.clone()));
+       }
        fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {}
-       fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) {}
-       fn handle_error(&self, _their_node_id: &PublicKey, _msg: &msgs::ErrorMessage) {}
+       fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) {
+               // Don't bother with `received_msg` for Init as its auto-generated and we don't want to
+               // bother re-generating the expected Init message in all tests.
+       }
+       fn handle_error(&self, _their_node_id: &PublicKey, msg: &msgs::ErrorMessage) {
+               self.received_msg(wire::Message::Error(msg.clone()));
+       }
 }
 
 impl events::MessageSendEventsProvider for TestChannelMessageHandler {
@@ -341,6 +413,7 @@ fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate {
 pub struct TestRoutingMessageHandler {
        pub chan_upds_recvd: AtomicUsize,
        pub chan_anns_recvd: AtomicUsize,
+       pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
        pub request_full_sync: AtomicBool,
 }
 
@@ -349,6 +422,7 @@ impl TestRoutingMessageHandler {
                TestRoutingMessageHandler {
                        chan_upds_recvd: AtomicUsize::new(0),
                        chan_anns_recvd: AtomicUsize::new(0),
+                       pending_events: Mutex::new(vec![]),
                        request_full_sync: AtomicBool::new(false),
                }
        }
@@ -384,7 +458,35 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
                Vec::new()
        }
 
-       fn peer_connected(&self, _their_node_id: &PublicKey, _init_msg: &msgs::Init) {}
+       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) {
+               if !init_msg.features.supports_gossip_queries() {
+                       return ();
+               }
+
+               let should_request_full_sync = self.request_full_sync.load(Ordering::Acquire);
+
+               #[allow(unused_mut, unused_assignments)]
+               let mut gossip_start_time = 0;
+               #[cfg(feature = "std")]
+               {
+                       gossip_start_time = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs();
+                       if should_request_full_sync {
+                               gossip_start_time -= 60 * 60 * 24 * 7 * 2; // 2 weeks ago
+                       } else {
+                               gossip_start_time -= 60 * 60; // an hour ago
+                       }
+               }
+
+               let mut pending_events = self.pending_events.lock().unwrap();
+               pending_events.push(events::MessageSendEvent::SendGossipTimestampFilter {
+                       node_id: their_node_id.clone(),
+                       msg: msgs::GossipTimestampFilter {
+                               chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                               first_timestamp: gossip_start_time as u32,
+                               timestamp_range: u32::max_value(),
+                       },
+               });
+       }
 
        fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), msgs::LightningError> {
                Ok(())
@@ -405,7 +507,10 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
 
 impl events::MessageSendEventsProvider for TestRoutingMessageHandler {
        fn get_and_clear_pending_msg_events(&self) -> Vec<events::MessageSendEvent> {
-               vec![]
+               let mut ret = Vec::new();
+               let mut pending_events = self.pending_events.lock().unwrap();
+               core::mem::swap(&mut ret, &mut pending_events);
+               ret
        }
 }