Merge pull request #2396 from tnull/2023-07-fix-github-actions
[rust-lightning] / lightning / src / util / test_utils.rs
index 7d5b8c0bef05f81e4963b10937659ee51f8b7bc6..212f1f4b60a392f6ad605fe599f200b3c543b778 100644 (file)
@@ -32,6 +32,7 @@ use crate::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
 use crate::util::logger::{Logger, Level, Record};
 use crate::util::ser::{Readable, ReadableArgs, Writer, Writeable};
 
+use bitcoin::blockdata::constants::ChainHash;
 use bitcoin::blockdata::constants::genesis_block;
 use bitcoin::blockdata::transaction::{Transaction, TxOut};
 use bitcoin::blockdata::script::{Builder, Script};
@@ -44,11 +45,13 @@ use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, ecdsa::Signature, Scal
 use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::secp256k1::ecdsa::RecoverableSignature;
 
+#[cfg(any(test, feature = "_test_utils"))]
 use regex;
 
 use crate::io;
 use crate::prelude::*;
 use core::cell::RefCell;
+use core::ops::DerefMut;
 use core::time::Duration;
 use crate::sync::{Mutex, Arc};
 use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -111,8 +114,8 @@ impl<'a> Router for TestRouter<'a> {
                if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
                        assert_eq!(find_route_query, *params);
                        if let Ok(ref route) = find_route_res {
-                               let locked_scorer = self.scorer.lock().unwrap();
-                               let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
+                               let mut binding = self.scorer.lock().unwrap();
+                               let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
                                for path in &route.paths {
                                        let mut aggregate_msat = 0u64;
                                        for (idx, hop) in path.hops.iter().rev().enumerate() {
@@ -137,10 +140,9 @@ impl<'a> Router for TestRouter<'a> {
                        return find_route_res;
                }
                let logger = TestLogger::new();
-               let scorer = self.scorer.lock().unwrap();
                find_route(
                        payer, params, &self.network_graph, first_hops, &logger,
-                       &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(),
+                       &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(),
                        &[42; 32]
                )
        }
@@ -362,14 +364,18 @@ pub struct TestChannelMessageHandler {
        pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
        expected_recv_msgs: Mutex<Option<Vec<wire::Message<()>>>>,
        connected_peers: Mutex<HashSet<PublicKey>>,
+       pub message_fetch_counter: AtomicUsize,
+       genesis_hash: ChainHash,
 }
 
 impl TestChannelMessageHandler {
-       pub fn new() -> Self {
+       pub fn new(genesis_hash: ChainHash) -> Self {
                TestChannelMessageHandler {
                        pending_events: Mutex::new(Vec::new()),
                        expected_recv_msgs: Mutex::new(None),
                        connected_peers: Mutex::new(HashSet::new()),
+                       message_fetch_counter: AtomicUsize::new(0),
+                       genesis_hash,
                }
        }
 
@@ -473,6 +479,10 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
                channelmanager::provided_init_features(&UserConfig::default())
        }
 
+       fn get_genesis_hashes(&self) -> Option<Vec<ChainHash>> {
+               Some(vec![self.genesis_hash])
+       }
+
        fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, msg: &msgs::OpenChannelV2) {
                self.received_msg(wire::Message::OpenChannelV2(msg.clone()));
        }
@@ -520,6 +530,7 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
 
 impl events::MessageSendEventsProvider for TestChannelMessageHandler {
        fn get_and_clear_pending_msg_events(&self) -> Vec<events::MessageSendEvent> {
+               self.message_fetch_counter.fetch_add(1, Ordering::AcqRel);
                let mut pending_events = self.pending_events.lock().unwrap();
                let mut ret = Vec::new();
                mem::swap(&mut ret, &mut *pending_events);
@@ -728,6 +739,7 @@ impl TestLogger {
        /// 1. belong to the specified module and
        /// 2. match the given regex pattern.
        /// Assert that the number of occurrences equals the given `count`
+       #[cfg(any(test, feature = "_test_utils"))]
        pub fn assert_log_regex(&self, module: &str, pattern: regex::Regex, count: usize) {
                let log_entries = self.lines.lock().unwrap();
                let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| {