X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=212f1f4b60a392f6ad605fe599f200b3c543b778;hb=b66e3c53768f6bc7bc43064f1818051d22477c63;hp=4b3cc4113bc35214dc62e3a6d12b6b9290f88eca;hpb=29b392a96de6c4d597a2d546d06e02824b83b6dd;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 4b3cc411..212f1f4b 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -32,6 +32,7 @@ use crate::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState}; use crate::util::logger::{Logger, Level, Record}; use crate::util::ser::{Readable, ReadableArgs, Writer, Writeable}; +use bitcoin::blockdata::constants::ChainHash; use bitcoin::blockdata::constants::genesis_block; use bitcoin::blockdata::transaction::{Transaction, TxOut}; use bitcoin::blockdata::script::{Builder, Script}; @@ -44,11 +45,13 @@ use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, ecdsa::Signature, Scal use bitcoin::secp256k1::ecdh::SharedSecret; use bitcoin::secp256k1::ecdsa::RecoverableSignature; +#[cfg(any(test, feature = "_test_utils"))] use regex; use crate::io; use crate::prelude::*; use core::cell::RefCell; +use core::ops::DerefMut; use core::time::Duration; use crate::sync::{Mutex, Arc}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; @@ -111,8 +114,8 @@ impl<'a> Router for TestRouter<'a> { if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() { assert_eq!(find_route_query, *params); if let Ok(ref route) = find_route_res { - let locked_scorer = self.scorer.lock().unwrap(); - let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs); + let mut binding = self.scorer.lock().unwrap(); + let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs); for path in &route.paths { let mut aggregate_msat = 0u64; for (idx, hop) in path.hops.iter().rev().enumerate() { @@ -126,10 +129,10 @@ impl<'a> Router for TestRouter<'a> { // Since the path is reversed, the last element in our iteration is the first // hop. if idx == path.hops.len() - 1 { - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage); + scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage, &()); } else { let curr_hop_path_idx = path.hops.len() - 1 - idx; - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage); + scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage, &()); } } } @@ -137,10 +140,9 @@ impl<'a> Router for TestRouter<'a> { return find_route_res; } let logger = TestLogger::new(); - let scorer = self.scorer.lock().unwrap(); find_route( payer, params, &self.network_graph, first_hops, &logger, - &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), + &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(), &[42; 32] ) } @@ -362,14 +364,18 @@ pub struct TestChannelMessageHandler { pub pending_events: Mutex>, expected_recv_msgs: Mutex>>>, connected_peers: Mutex>, + pub message_fetch_counter: AtomicUsize, + genesis_hash: ChainHash, } impl TestChannelMessageHandler { - pub fn new() -> Self { + pub fn new(genesis_hash: ChainHash) -> Self { TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), expected_recv_msgs: Mutex::new(None), connected_peers: Mutex::new(HashSet::new()), + message_fetch_counter: AtomicUsize::new(0), + genesis_hash, } } @@ -472,10 +478,59 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures { channelmanager::provided_init_features(&UserConfig::default()) } + + fn get_genesis_hashes(&self) -> Option> { + Some(vec![self.genesis_hash]) + } + + fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, msg: &msgs::OpenChannelV2) { + self.received_msg(wire::Message::OpenChannelV2(msg.clone())); + } + + fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, msg: &msgs::AcceptChannelV2) { + self.received_msg(wire::Message::AcceptChannelV2(msg.clone())); + } + + fn handle_tx_add_input(&self, _their_node_id: &PublicKey, msg: &msgs::TxAddInput) { + self.received_msg(wire::Message::TxAddInput(msg.clone())); + } + + fn handle_tx_add_output(&self, _their_node_id: &PublicKey, msg: &msgs::TxAddOutput) { + self.received_msg(wire::Message::TxAddOutput(msg.clone())); + } + + fn handle_tx_remove_input(&self, _their_node_id: &PublicKey, msg: &msgs::TxRemoveInput) { + self.received_msg(wire::Message::TxRemoveInput(msg.clone())); + } + + fn handle_tx_remove_output(&self, _their_node_id: &PublicKey, msg: &msgs::TxRemoveOutput) { + self.received_msg(wire::Message::TxRemoveOutput(msg.clone())); + } + + fn handle_tx_complete(&self, _their_node_id: &PublicKey, msg: &msgs::TxComplete) { + self.received_msg(wire::Message::TxComplete(msg.clone())); + } + + fn handle_tx_signatures(&self, _their_node_id: &PublicKey, msg: &msgs::TxSignatures) { + self.received_msg(wire::Message::TxSignatures(msg.clone())); + } + + fn handle_tx_init_rbf(&self, _their_node_id: &PublicKey, msg: &msgs::TxInitRbf) { + self.received_msg(wire::Message::TxInitRbf(msg.clone())); + } + + fn handle_tx_ack_rbf(&self, _their_node_id: &PublicKey, msg: &msgs::TxAckRbf) { + self.received_msg(wire::Message::TxAckRbf(msg.clone())); + } + + fn handle_tx_abort(&self, _their_node_id: &PublicKey, msg: &msgs::TxAbort) { + self.received_msg(wire::Message::TxAbort(msg.clone())); + } } impl events::MessageSendEventsProvider for TestChannelMessageHandler { fn get_and_clear_pending_msg_events(&self) -> Vec { + self.message_fetch_counter.fetch_add(1, Ordering::AcqRel); let mut pending_events = self.pending_events.lock().unwrap(); let mut ret = Vec::new(); mem::swap(&mut ret, &mut *pending_events); @@ -684,6 +739,7 @@ impl TestLogger { /// 1. belong to the specified module and /// 2. match the given regex pattern. /// Assert that the number of occurrences equals the given `count` + #[cfg(any(test, feature = "_test_utils"))] pub fn assert_log_regex(&self, module: &str, pattern: regex::Regex, count: usize) { let log_entries = self.lines.lock().unwrap(); let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| { @@ -697,7 +753,7 @@ impl Logger for TestLogger { fn log(&self, record: &Record) { *self.lines.lock().unwrap().entry((record.module_path.to_string(), format!("{}", record.args))).or_insert(0) += 1; if record.level >= self.level { - #[cfg(feature = "std")] + #[cfg(all(not(ldk_bench), feature = "std"))] println!("{:<5} {} [{} : {}, {}] {}", record.level.to_string(), self.id, record.module_path, record.file, record.line, record.args); } } @@ -971,8 +1027,9 @@ impl crate::util::ser::Writeable for TestScorer { } impl Score for TestScorer { + type ScoreParams = (); fn channel_penalty_msat( - &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage + &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage, _score_params: &Self::ScoreParams ) -> u64 { if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() { match scorer_expectations.pop_front() {