Merge pull request #1730 from TheBlueMatt/2022-09-111-bindings-and-backports
[rust-lightning] / lightning / src / util / test_utils.rs
index b4d26904706240de647eec14ef97ef5b175ff8cf..f0e264b4dd19ba9ad129167bc35f8349e18aaf54 100644 (file)
@@ -17,7 +17,7 @@ use chain::channelmonitor;
 use chain::channelmonitor::MonitorEvent;
 use chain::transaction::OutPoint;
 use chain::keysinterface;
-use ln::features::{ChannelFeatures, InitFeatures};
+use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
 use ln::{msgs, wire};
 use ln::script::ShutdownScript;
 use routing::scoring::FixedPenaltyScorer;
@@ -287,9 +287,9 @@ impl TestChannelMessageHandler {
 
 impl Drop for TestChannelMessageHandler {
        fn drop(&mut self) {
-               let l = self.expected_recv_msgs.lock().unwrap();
                #[cfg(feature = "std")]
                {
+                       let l = self.expected_recv_msgs.lock().unwrap();
                        if !std::thread::panicking() {
                                assert!(l.is_none() || l.as_ref().unwrap().is_empty());
                        }
@@ -357,6 +357,12 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
        fn handle_error(&self, _their_node_id: &PublicKey, msg: &msgs::ErrorMessage) {
                self.received_msg(wire::Message::Error(msg.clone()));
        }
+       fn provided_node_features(&self) -> NodeFeatures {
+               NodeFeatures::known_channel_features()
+       }
+       fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures {
+               InitFeatures::known_channel_features()
+       }
 }
 
 impl events::MessageSendEventsProvider for TestChannelMessageHandler {
@@ -464,14 +470,12 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
                        return ();
                }
 
-               let should_request_full_sync = self.request_full_sync.load(Ordering::Acquire);
-
                #[allow(unused_mut, unused_assignments)]
                let mut gossip_start_time = 0;
                #[cfg(feature = "std")]
                {
                        gossip_start_time = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs();
-                       if should_request_full_sync {
+                       if self.request_full_sync.load(Ordering::Acquire) {
                                gossip_start_time -= 60 * 60 * 24 * 7 * 2; // 2 weeks ago
                        } else {
                                gossip_start_time -= 60 * 60; // an hour ago
@@ -504,6 +508,18 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
        fn handle_query_short_channel_ids(&self, _their_node_id: &PublicKey, _msg: msgs::QueryShortChannelIds) -> Result<(), msgs::LightningError> {
                Ok(())
        }
+
+       fn provided_node_features(&self) -> NodeFeatures {
+               let mut features = NodeFeatures::empty();
+               features.set_gossip_queries_optional();
+               features
+       }
+
+       fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures {
+               let mut features = InitFeatures::empty();
+               features.set_gossip_queries_optional();
+               features
+       }
 }
 
 impl events::MessageSendEventsProvider for TestRoutingMessageHandler {
@@ -517,10 +533,7 @@ impl events::MessageSendEventsProvider for TestRoutingMessageHandler {
 
 pub struct TestLogger {
        level: Level,
-       #[cfg(feature = "std")]
-       id: String,
-       #[cfg(not(feature = "std"))]
-       _id: String,
+       pub(crate) id: String,
        pub lines: Mutex<HashMap<(String, String), usize>>,
 }
 
@@ -531,10 +544,7 @@ impl TestLogger {
        pub fn with_id(id: String) -> TestLogger {
                TestLogger {
                        level: Level::Trace,
-                       #[cfg(feature = "std")]
                        id,
-                       #[cfg(not(feature = "std"))]
-                       _id: id,
                        lines: Mutex::new(HashMap::new())
                }
        }
@@ -558,10 +568,10 @@ impl TestLogger {
                assert_eq!(l, count)
        }
 
-    /// Search for the number of occurrences of logged lines which
-    /// 1. belong to the specified module and
-    /// 2. match the given regex pattern.
-    /// Assert that the number of occurrences equals the given `count`
+       /// Search for the number of occurrences of logged lines which
+       /// 1. belong to the specified module and
+       /// 2. match the given regex pattern.
+       /// Assert that the number of occurrences equals the given `count`
        pub fn assert_log_regex(&self, module: String, pattern: regex::Regex, count: usize) {
                let log_entries = self.lines.lock().unwrap();
                let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| {