X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=331ef3b1172c9492d34c39626b556e31bc2e1a64;hb=a9d49aee5f1e24281cedd8ac8177d700359b86d0;hp=67f6fd6b2888dfecd4f6acb3928302e2ce41ea79;hpb=3a643df99797ee2dd5cc19a6f9d090212b1c7963;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 67f6fd6b..331ef3b1 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -16,7 +16,7 @@ use crate::chain::chainmonitor::MonitorUpdateId; use crate::chain::channelmonitor; use crate::chain::channelmonitor::MonitorEvent; use crate::chain::transaction::OutPoint; -use crate::chain::keysinterface; +use crate::sign; use crate::events; use crate::ln::channelmanager; use crate::ln::features::{ChannelFeatures, InitFeatures, NodeFeatures}; @@ -54,7 +54,7 @@ use crate::sync::{Mutex, Arc}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use core::mem; use bitcoin::bech32::u5; -use crate::chain::keysinterface::{InMemorySigner, Recipient, EntropySource, NodeSigner, SignerProvider}; +use crate::sign::{InMemorySigner, Recipient, EntropySource, NodeSigner, SignerProvider}; #[cfg(feature = "std")] use std::time::{SystemTime, UNIX_EPOCH}; @@ -126,10 +126,10 @@ impl<'a> Router for TestRouter<'a> { // Since the path is reversed, the last element in our iteration is the first // hop. if idx == path.hops.len() - 1 { - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage); + scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage, &()); } else { let curr_hop_path_idx = path.hops.len() - 1 - idx; - scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage); + scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path.hops[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage, &()); } } } @@ -140,7 +140,7 @@ impl<'a> Router for TestRouter<'a> { let scorer = self.scorer.lock().unwrap(); find_route( payer, params, &self.network_graph, first_hops, &logger, - &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), + &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(), &[42; 32] ) } @@ -180,8 +180,8 @@ impl SignerProvider for OnlyReadsKeysInterface { )) } - fn get_destination_script(&self) -> Script { unreachable!(); } - fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { unreachable!(); } + fn get_destination_script(&self) -> Result { Err(()) } + fn get_shutdown_scriptpubkey(&self) -> Result { Err(()) } } pub struct TestChainMonitor<'a> { @@ -289,7 +289,7 @@ impl TestPersister { self.update_rets.lock().unwrap().push_back(next_ret); } } -impl chainmonitor::Persist for TestPersister { +impl chainmonitor::Persist for TestPersister { fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor, _id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus { if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() { return update_ret @@ -341,17 +341,20 @@ impl TestBroadcaster { } impl chaininterface::BroadcasterInterface for TestBroadcaster { - fn broadcast_transaction(&self, tx: &Transaction) { - let lock_time = tx.lock_time.0; - assert!(lock_time < 1_500_000_000); - if bitcoin::LockTime::from(tx.lock_time).is_block_height() && lock_time > self.blocks.lock().unwrap().last().unwrap().1 { - for inp in tx.input.iter() { - if inp.sequence != Sequence::MAX { - panic!("We should never broadcast a transaction before its locktime ({})!", tx.lock_time); + fn broadcast_transactions(&self, txs: &[&Transaction]) { + for tx in txs { + let lock_time = tx.lock_time.0; + assert!(lock_time < 1_500_000_000); + if bitcoin::LockTime::from(tx.lock_time).is_block_height() && lock_time > self.blocks.lock().unwrap().last().unwrap().1 { + for inp in tx.input.iter() { + if inp.sequence != Sequence::MAX { + panic!("We should never broadcast a transaction before its locktime ({})!", tx.lock_time); + } } } } - self.txn_broadcasted.lock().unwrap().push(tx.clone()); + let owned_txs: Vec = txs.iter().map(|tx| (*tx).clone()).collect(); + self.txn_broadcasted.lock().unwrap().extend(owned_txs); } } @@ -359,6 +362,7 @@ pub struct TestChannelMessageHandler { pub pending_events: Mutex>, expected_recv_msgs: Mutex>>>, connected_peers: Mutex>, + pub message_fetch_counter: AtomicUsize, } impl TestChannelMessageHandler { @@ -367,6 +371,7 @@ impl TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), expected_recv_msgs: Mutex::new(None), connected_peers: Mutex::new(HashSet::new()), + message_fetch_counter: AtomicUsize::new(0), } } @@ -469,10 +474,55 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures { channelmanager::provided_init_features(&UserConfig::default()) } + + fn handle_open_channel_v2(&self, _their_node_id: &PublicKey, msg: &msgs::OpenChannelV2) { + self.received_msg(wire::Message::OpenChannelV2(msg.clone())); + } + + fn handle_accept_channel_v2(&self, _their_node_id: &PublicKey, msg: &msgs::AcceptChannelV2) { + self.received_msg(wire::Message::AcceptChannelV2(msg.clone())); + } + + fn handle_tx_add_input(&self, _their_node_id: &PublicKey, msg: &msgs::TxAddInput) { + self.received_msg(wire::Message::TxAddInput(msg.clone())); + } + + fn handle_tx_add_output(&self, _their_node_id: &PublicKey, msg: &msgs::TxAddOutput) { + self.received_msg(wire::Message::TxAddOutput(msg.clone())); + } + + fn handle_tx_remove_input(&self, _their_node_id: &PublicKey, msg: &msgs::TxRemoveInput) { + self.received_msg(wire::Message::TxRemoveInput(msg.clone())); + } + + fn handle_tx_remove_output(&self, _their_node_id: &PublicKey, msg: &msgs::TxRemoveOutput) { + self.received_msg(wire::Message::TxRemoveOutput(msg.clone())); + } + + fn handle_tx_complete(&self, _their_node_id: &PublicKey, msg: &msgs::TxComplete) { + self.received_msg(wire::Message::TxComplete(msg.clone())); + } + + fn handle_tx_signatures(&self, _their_node_id: &PublicKey, msg: &msgs::TxSignatures) { + self.received_msg(wire::Message::TxSignatures(msg.clone())); + } + + fn handle_tx_init_rbf(&self, _their_node_id: &PublicKey, msg: &msgs::TxInitRbf) { + self.received_msg(wire::Message::TxInitRbf(msg.clone())); + } + + fn handle_tx_ack_rbf(&self, _their_node_id: &PublicKey, msg: &msgs::TxAckRbf) { + self.received_msg(wire::Message::TxAckRbf(msg.clone())); + } + + fn handle_tx_abort(&self, _their_node_id: &PublicKey, msg: &msgs::TxAbort) { + self.received_msg(wire::Message::TxAbort(msg.clone())); + } } impl events::MessageSendEventsProvider for TestChannelMessageHandler { fn get_and_clear_pending_msg_events(&self) -> Vec { + self.message_fetch_counter.fetch_add(1, Ordering::AcqRel); let mut pending_events = self.pending_events.lock().unwrap(); let mut ret = Vec::new(); mem::swap(&mut ret, &mut *pending_events); @@ -694,7 +744,7 @@ impl Logger for TestLogger { fn log(&self, record: &Record) { *self.lines.lock().unwrap().entry((record.module_path.to_string(), format!("{}", record.args))).or_insert(0) += 1; if record.level >= self.level { - #[cfg(feature = "std")] + #[cfg(all(not(ldk_bench), feature = "std"))] println!("{:<5} {} [{} : {}, {}] {}", record.level.to_string(), self.id, record.module_path, record.file, record.line, record.args); } } @@ -711,7 +761,7 @@ impl TestNodeSigner { } impl NodeSigner for TestNodeSigner { - fn get_inbound_payment_key_material(&self) -> crate::chain::keysinterface::KeyMaterial { + fn get_inbound_payment_key_material(&self) -> crate::sign::KeyMaterial { unreachable!() } @@ -744,7 +794,7 @@ impl NodeSigner for TestNodeSigner { } pub struct TestKeysInterface { - pub backing: keysinterface::PhantomKeysManager, + pub backing: sign::PhantomKeysManager, pub override_random_bytes: Mutex>, pub disable_revocation_policy_check: bool, enforcement_states: Mutex>>>, @@ -770,7 +820,7 @@ impl NodeSigner for TestKeysInterface { self.backing.ecdh(recipient, other_key, tweak) } - fn get_inbound_payment_key_material(&self) -> keysinterface::KeyMaterial { + fn get_inbound_payment_key_material(&self) -> sign::KeyMaterial { self.backing.get_inbound_payment_key_material() } @@ -809,14 +859,14 @@ impl SignerProvider for TestKeysInterface { )) } - fn get_destination_script(&self) -> Script { self.backing.get_destination_script() } + fn get_destination_script(&self) -> Result { self.backing.get_destination_script() } - fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { + fn get_shutdown_scriptpubkey(&self) -> Result { match &mut *self.expectations.lock().unwrap() { None => self.backing.get_shutdown_scriptpubkey(), Some(expectations) => match expectations.pop_front() { None => panic!("Unexpected get_shutdown_scriptpubkey"), - Some(expectation) => expectation.returns, + Some(expectation) => Ok(expectation.returns), }, } } @@ -826,7 +876,7 @@ impl TestKeysInterface { pub fn new(seed: &[u8; 32], network: Network) -> Self { let now = Duration::from_secs(genesis_block(network).header.time as u64); Self { - backing: keysinterface::PhantomKeysManager::new(seed, now.as_secs(), now.subsec_nanos(), seed), + backing: sign::PhantomKeysManager::new(seed, now.as_secs(), now.subsec_nanos(), seed), override_random_bytes: Mutex::new(None), disable_revocation_policy_check: false, enforcement_states: Mutex::new(HashMap::new()), @@ -834,7 +884,7 @@ impl TestKeysInterface { } } - /// Sets an expectation that [`keysinterface::SignerProvider::get_shutdown_scriptpubkey`] is + /// Sets an expectation that [`sign::SignerProvider::get_shutdown_scriptpubkey`] is /// called. pub fn expect(&self, expectation: OnGetShutdownScriptpubkey) -> &Self { self.expectations.lock().unwrap() @@ -882,7 +932,7 @@ impl Drop for TestKeysInterface { } } -/// An expectation that [`keysinterface::SignerProvider::get_shutdown_scriptpubkey`] was called and +/// An expectation that [`sign::SignerProvider::get_shutdown_scriptpubkey`] was called and /// returns a [`ShutdownScript`]. pub struct OnGetShutdownScriptpubkey { /// A shutdown script used to close a channel. @@ -968,8 +1018,9 @@ impl crate::util::ser::Writeable for TestScorer { } impl Score for TestScorer { + type ScoreParams = (); fn channel_penalty_msat( - &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage + &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage, _score_params: &Self::ScoreParams ) -> u64 { if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() { match scorer_expectations.pop_front() {