X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=9b2f222c519f9c672bf6d538ef23898de2e8020f;hb=7544030bb63fee6484fc178bb2ac8f382fe3b5b1;hp=27c2d9de874ed45abca03dabe171ca238d7288ff;hpb=ca163c3fae94e64b7f70a7a549cd761cfa7e52d2;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 27c2d9de..9b2f222c 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -17,9 +17,9 @@ use chain::channelmonitor; use chain::channelmonitor::MonitorEvent; use chain::transaction::OutPoint; use chain::keysinterface; -use ln::features::{ChannelFeatures, InitFeatures}; -use ln::msgs; -use ln::msgs::OptionalField; +use ln::channelmanager; +use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures}; +use ln::{msgs, wire}; use ln::script::ShutdownScript; use routing::scoring::FixedPenaltyScorer; use util::enforcing_trait_impls::{EnforcingSigner, EnforcementState}; @@ -31,12 +31,13 @@ use bitcoin::blockdata::constants::genesis_block; use bitcoin::blockdata::transaction::{Transaction, TxOut}; use bitcoin::blockdata::script::{Builder, Script}; use bitcoin::blockdata::opcodes; -use bitcoin::blockdata::block::BlockHeader; +use bitcoin::blockdata::block::Block; use bitcoin::network::constants::Network; use bitcoin::hash_types::{BlockHash, Txid}; -use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, Signature}; -use bitcoin::secp256k1::recovery::RecoverableSignature; +use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, ecdsa::Signature, Scalar}; +use bitcoin::secp256k1::ecdh::SharedSecret; +use bitcoin::secp256k1::ecdsa::RecoverableSignature; use regex; @@ -45,10 +46,14 @@ use prelude::*; use core::time::Duration; use sync::{Mutex, Arc}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use core::{cmp, mem}; +use core::mem; use bitcoin::bech32::u5; use chain::keysinterface::{InMemorySigner, Recipient, KeyMaterial}; +#[cfg(feature = "std")] +use std::time::{SystemTime, UNIX_EPOCH}; +use bitcoin::Sequence; + pub struct TestVecWriter(pub Vec); impl Writer for TestVecWriter { fn write_all(&mut self, buf: &[u8]) -> Result<(), io::Error> { @@ -71,6 +76,7 @@ impl keysinterface::KeysInterface for OnlyReadsKeysInterface { type Signer = EnforcingSigner; fn get_node_secret(&self, _recipient: Recipient) -> Result { unreachable!(); } + fn ecdh(&self, _recipient: Recipient, _other_key: &PublicKey, _tweak: Option<&Scalar>) -> Result { unreachable!(); } fn get_inbound_payment_key_material(&self) -> KeyMaterial { unreachable!(); } fn get_destination_script(&self) -> Script { unreachable!(); } fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { unreachable!(); } @@ -113,9 +119,14 @@ impl<'a> TestChainMonitor<'a> { expect_channel_force_closed: Mutex::new(None), } } + + pub fn complete_sole_pending_chan_update(&self, channel_id: &[u8; 32]) { + let (outpoint, _, latest_update) = self.latest_monitor_update_id.lock().unwrap().get(channel_id).unwrap().clone(); + self.chain_monitor.channel_monitor_updated(outpoint, latest_update).unwrap(); + } } impl<'a> chain::Watch for TestChainMonitor<'a> { - fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> Result<(), chain::ChannelMonitorUpdateErr> { + fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> chain::ChannelMonitorUpdateStatus { // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... let mut w = TestVecWriter(Vec::new()); @@ -129,7 +140,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { self.chain_monitor.watch_channel(funding_txo, new_monitor) } - fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), chain::ChannelMonitorUpdateErr> { + fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus { // Every monitor update should survive roundtrip let mut w = TestVecWriter(Vec::new()); update.write(&mut w).unwrap(); @@ -161,16 +172,16 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { update_res } - fn release_pending_monitor_events(&self) -> Vec { + fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec, Option)> { return self.chain_monitor.release_pending_monitor_events(); } } pub struct TestPersister { - pub update_ret: Mutex>, + pub update_ret: Mutex, /// If this is set to Some(), after the next return, we'll always return this until update_ret /// is changed: - pub next_update_ret: Mutex>>, + pub next_update_ret: Mutex>, /// When we get an update_persisted_channel call with no ChannelMonitorUpdate, we insert the /// MonitorUpdateId here. pub chain_sync_monitor_persistences: Mutex>>, @@ -181,23 +192,23 @@ pub struct TestPersister { impl TestPersister { pub fn new() -> Self { Self { - update_ret: Mutex::new(Ok(())), + update_ret: Mutex::new(chain::ChannelMonitorUpdateStatus::Completed), next_update_ret: Mutex::new(None), chain_sync_monitor_persistences: Mutex::new(HashMap::new()), offchain_monitor_updates: Mutex::new(HashMap::new()), } } - pub fn set_update_ret(&self, ret: Result<(), chain::ChannelMonitorUpdateErr>) { + pub fn set_update_ret(&self, ret: chain::ChannelMonitorUpdateStatus) { *self.update_ret.lock().unwrap() = ret; } - pub fn set_next_update_ret(&self, next_ret: Option>) { + pub fn set_next_update_ret(&self, next_ret: Option) { *self.next_update_ret.lock().unwrap() = next_ret; } } impl chainmonitor::Persist for TestPersister { - fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor, _id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { + fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor, _id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus { let ret = self.update_ret.lock().unwrap().clone(); if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { *self.update_ret.lock().unwrap() = next_ret; @@ -205,7 +216,7 @@ impl chainmonitor::Persist for TestPersiste ret } - fn update_persisted_channel(&self, funding_txo: OutPoint, update: &Option, _data: &channelmonitor::ChannelMonitor, update_id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { + fn update_persisted_channel(&self, funding_txo: OutPoint, update: &Option, _data: &channelmonitor::ChannelMonitor, update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus { let ret = self.update_ret.lock().unwrap().clone(); if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { *self.update_ret.lock().unwrap() = next_ret; @@ -221,21 +232,22 @@ impl chainmonitor::Persist for TestPersiste pub struct TestBroadcaster { pub txn_broadcasted: Mutex>, - pub blocks: Arc>>, + pub blocks: Arc>>, } impl TestBroadcaster { - pub fn new(blocks: Arc>>) -> TestBroadcaster { + pub fn new(blocks: Arc>>) -> TestBroadcaster { TestBroadcaster { txn_broadcasted: Mutex::new(Vec::new()), blocks } } } impl chaininterface::BroadcasterInterface for TestBroadcaster { fn broadcast_transaction(&self, tx: &Transaction) { - assert!(tx.lock_time < 1_500_000_000); - if tx.lock_time > self.blocks.lock().unwrap().len() as u32 + 1 && tx.lock_time < 500_000_000 { + let lock_time = tx.lock_time.0; + assert!(lock_time < 1_500_000_000); + if lock_time > self.blocks.lock().unwrap().len() as u32 + 1 && lock_time < 500_000_000 { for inp in tx.input.iter() { - if inp.sequence != 0xffffffff { + if inp.sequence != Sequence::MAX { panic!("We should never broadcast a transaction before its locktime ({})!", tx.lock_time); } } @@ -246,37 +258,113 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster { pub struct TestChannelMessageHandler { pub pending_events: Mutex>, + expected_recv_msgs: Mutex>>>, } impl TestChannelMessageHandler { pub fn new() -> Self { TestChannelMessageHandler { pending_events: Mutex::new(Vec::new()), + expected_recv_msgs: Mutex::new(None), + } + } + + #[cfg(test)] + pub(crate) fn expect_receive_msg(&self, ev: wire::Message<()>) { + let mut expected_msgs = self.expected_recv_msgs.lock().unwrap(); + if expected_msgs.is_none() { *expected_msgs = Some(Vec::new()); } + expected_msgs.as_mut().unwrap().push(ev); + } + + fn received_msg(&self, _ev: wire::Message<()>) { + let mut msgs = self.expected_recv_msgs.lock().unwrap(); + if msgs.is_none() { return; } + assert!(!msgs.as_ref().unwrap().is_empty(), "Received message when we weren't expecting one"); + #[cfg(test)] + assert_eq!(msgs.as_ref().unwrap()[0], _ev); + msgs.as_mut().unwrap().remove(0); + } +} + +impl Drop for TestChannelMessageHandler { + fn drop(&mut self) { + #[cfg(feature = "std")] + { + let l = self.expected_recv_msgs.lock().unwrap(); + if !std::thread::panicking() { + assert!(l.is_none() || l.as_ref().unwrap().is_empty()); + } } } } impl msgs::ChannelMessageHandler for TestChannelMessageHandler { - fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::OpenChannel) {} - fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::AcceptChannel) {} - fn handle_funding_created(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingCreated) {} - fn handle_funding_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingSigned) {} - fn handle_funding_locked(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingLocked) {} - fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, _msg: &msgs::Shutdown) {} - fn handle_closing_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::ClosingSigned) {} - fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateAddHTLC) {} - fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFulfillHTLC) {} - fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailHTLC) {} - fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailMalformedHTLC) {} - fn handle_commitment_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::CommitmentSigned) {} - fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, _msg: &msgs::RevokeAndACK) {} - fn handle_update_fee(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFee) {} - fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) {} - fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &msgs::AnnouncementSignatures) {} - fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelReestablish) {} + fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::OpenChannel) { + self.received_msg(wire::Message::OpenChannel(msg.clone())); + } + fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::AcceptChannel) { + self.received_msg(wire::Message::AcceptChannel(msg.clone())); + } + fn handle_funding_created(&self, _their_node_id: &PublicKey, msg: &msgs::FundingCreated) { + self.received_msg(wire::Message::FundingCreated(msg.clone())); + } + fn handle_funding_signed(&self, _their_node_id: &PublicKey, msg: &msgs::FundingSigned) { + self.received_msg(wire::Message::FundingSigned(msg.clone())); + } + fn handle_channel_ready(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReady) { + self.received_msg(wire::Message::ChannelReady(msg.clone())); + } + fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, msg: &msgs::Shutdown) { + self.received_msg(wire::Message::Shutdown(msg.clone())); + } + fn handle_closing_signed(&self, _their_node_id: &PublicKey, msg: &msgs::ClosingSigned) { + self.received_msg(wire::Message::ClosingSigned(msg.clone())); + } + fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateAddHTLC) { + self.received_msg(wire::Message::UpdateAddHTLC(msg.clone())); + } + fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) { + self.received_msg(wire::Message::UpdateFulfillHTLC(msg.clone())); + } + fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) { + self.received_msg(wire::Message::UpdateFailHTLC(msg.clone())); + } + fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) { + self.received_msg(wire::Message::UpdateFailMalformedHTLC(msg.clone())); + } + fn handle_commitment_signed(&self, _their_node_id: &PublicKey, msg: &msgs::CommitmentSigned) { + self.received_msg(wire::Message::CommitmentSigned(msg.clone())); + } + fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, msg: &msgs::RevokeAndACK) { + self.received_msg(wire::Message::RevokeAndACK(msg.clone())); + } + fn handle_update_fee(&self, _their_node_id: &PublicKey, msg: &msgs::UpdateFee) { + self.received_msg(wire::Message::UpdateFee(msg.clone())); + } + fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) { + // Don't call `received_msg` here as `TestRoutingMessageHandler` generates these sometimes + } + fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) { + self.received_msg(wire::Message::AnnouncementSignatures(msg.clone())); + } + fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReestablish) { + self.received_msg(wire::Message::ChannelReestablish(msg.clone())); + } fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {} - fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) {} - fn handle_error(&self, _their_node_id: &PublicKey, _msg: &msgs::ErrorMessage) {} + fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) -> Result<(), ()> { + // Don't bother with `received_msg` for Init as its auto-generated and we don't want to + // bother re-generating the expected Init message in all tests. + Ok(()) + } + fn handle_error(&self, _their_node_id: &PublicKey, msg: &msgs::ErrorMessage) { + self.received_msg(wire::Message::Error(msg.clone())); + } + fn provided_node_features(&self) -> NodeFeatures { + channelmanager::provided_node_features() + } + fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures { + channelmanager::provided_init_features() + } } impl events::MessageSendEventsProvider for TestChannelMessageHandler { @@ -297,7 +385,7 @@ fn get_dummy_channel_announcement(short_chan_id: u64) -> msgs::ChannelAnnounceme let node_1_btckey = SecretKey::from_slice(&[40; 32]).unwrap(); let node_2_btckey = SecretKey::from_slice(&[39; 32]).unwrap(); let unsigned_ann = msgs::UnsignedChannelAnnouncement { - features: ChannelFeatures::known(), + features: ChannelFeatures::empty(), chain_hash: genesis_block(network).header.block_hash(), short_channel_id: short_chan_id, node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_privkey), @@ -330,7 +418,7 @@ fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate { flags: 0, cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: OptionalField::Absent, + htlc_maximum_msat: msgs::MAX_VALUE_MSAT, fee_base_msat: 0, fee_proportional_millionths: 0, excess_data: vec![], @@ -341,6 +429,7 @@ fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate { pub struct TestRoutingMessageHandler { pub chan_upds_recvd: AtomicUsize, pub chan_anns_recvd: AtomicUsize, + pub pending_events: Mutex>, pub request_full_sync: AtomicBool, } @@ -349,6 +438,7 @@ impl TestRoutingMessageHandler { TestRoutingMessageHandler { chan_upds_recvd: AtomicUsize::new(0), chan_anns_recvd: AtomicUsize::new(0), + pending_events: Mutex::new(vec![]), request_full_sync: AtomicBool::new(false), } } @@ -365,26 +455,46 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { self.chan_upds_recvd.fetch_add(1, Ordering::AcqRel); Err(msgs::LightningError { err: "".to_owned(), action: msgs::ErrorAction::IgnoreError }) } - fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option, Option)> { - let mut chan_anns = Vec::new(); - const TOTAL_UPDS: u64 = 50; - let end: u64 = cmp::min(starting_point + batch_amount as u64, TOTAL_UPDS); - for i in starting_point..end { - let chan_upd_1 = get_dummy_channel_update(i); - let chan_upd_2 = get_dummy_channel_update(i); - let chan_ann = get_dummy_channel_announcement(i); + fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(msgs::ChannelAnnouncement, Option, Option)> { + let chan_upd_1 = get_dummy_channel_update(starting_point); + let chan_upd_2 = get_dummy_channel_update(starting_point); + let chan_ann = get_dummy_channel_announcement(starting_point); - chan_anns.push((chan_ann, Some(chan_upd_1), Some(chan_upd_2))); - } - - chan_anns + Some((chan_ann, Some(chan_upd_1), Some(chan_upd_2))) } - fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec { - Vec::new() + fn get_next_node_announcement(&self, _starting_point: Option<&PublicKey>) -> Option { + None } - fn sync_routing_table(&self, _their_node_id: &PublicKey, _init_msg: &msgs::Init) {} + fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) -> Result<(), ()> { + if !init_msg.features.supports_gossip_queries() { + return Ok(()); + } + + #[allow(unused_mut, unused_assignments)] + let mut gossip_start_time = 0; + #[cfg(feature = "std")] + { + gossip_start_time = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs(); + if self.request_full_sync.load(Ordering::Acquire) { + gossip_start_time -= 60 * 60 * 24 * 7 * 2; // 2 weeks ago + } else { + gossip_start_time -= 60 * 60; // an hour ago + } + } + + let mut pending_events = self.pending_events.lock().unwrap(); + pending_events.push(events::MessageSendEvent::SendGossipTimestampFilter { + node_id: their_node_id.clone(), + msg: msgs::GossipTimestampFilter { + chain_hash: genesis_block(Network::Testnet).header.block_hash(), + first_timestamp: gossip_start_time as u32, + timestamp_range: u32::max_value(), + }, + }); + Ok(()) + } fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), msgs::LightningError> { Ok(()) @@ -401,20 +511,32 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { fn handle_query_short_channel_ids(&self, _their_node_id: &PublicKey, _msg: msgs::QueryShortChannelIds) -> Result<(), msgs::LightningError> { Ok(()) } + + fn provided_node_features(&self) -> NodeFeatures { + let mut features = NodeFeatures::empty(); + features.set_gossip_queries_optional(); + features + } + + fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures { + let mut features = InitFeatures::empty(); + features.set_gossip_queries_optional(); + features + } } impl events::MessageSendEventsProvider for TestRoutingMessageHandler { fn get_and_clear_pending_msg_events(&self) -> Vec { - vec![] + let mut ret = Vec::new(); + let mut pending_events = self.pending_events.lock().unwrap(); + core::mem::swap(&mut ret, &mut pending_events); + ret } } pub struct TestLogger { level: Level, - #[cfg(feature = "std")] - id: String, - #[cfg(not(feature = "std"))] - _id: String, + pub(crate) id: String, pub lines: Mutex>, } @@ -425,10 +547,7 @@ impl TestLogger { pub fn with_id(id: String) -> TestLogger { TestLogger { level: Level::Trace, - #[cfg(feature = "std")] id, - #[cfg(not(feature = "std"))] - _id: id, lines: Mutex::new(HashMap::new()) } } @@ -452,10 +571,10 @@ impl TestLogger { assert_eq!(l, count) } - /// Search for the number of occurrences of logged lines which - /// 1. belong to the specified module and - /// 2. match the given regex pattern. - /// Assert that the number of occurrences equals the given `count` + /// Search for the number of occurrences of logged lines which + /// 1. belong to the specified module and + /// 2. match the given regex pattern. + /// Assert that the number of occurrences equals the given `count` pub fn assert_log_regex(&self, module: String, pattern: regex::Regex, count: usize) { let log_entries = self.lines.lock().unwrap(); let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| { @@ -489,6 +608,9 @@ impl keysinterface::KeysInterface for TestKeysInterface { fn get_node_secret(&self, recipient: Recipient) -> Result { self.backing.get_node_secret(recipient) } + fn ecdh(&self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&Scalar>) -> Result { + self.backing.ecdh(recipient, other_key, tweak) + } fn get_inbound_payment_key_material(&self) -> keysinterface::KeyMaterial { self.backing.get_inbound_payment_key_material() } @@ -614,7 +736,6 @@ pub struct TestChainSource { pub utxo_ret: Mutex>, pub watched_txn: Mutex>, pub watched_outputs: Mutex>, - expectations: Mutex>>, } impl TestChainSource { @@ -625,17 +746,8 @@ impl TestChainSource { utxo_ret: Mutex::new(Ok(TxOut { value: u64::max_value(), script_pubkey })), watched_txn: Mutex::new(HashSet::new()), watched_outputs: Mutex::new(HashSet::new()), - expectations: Mutex::new(None), } } - - /// Sets an expectation that [`chain::Filter::register_output`] is called. - pub fn expect(&self, expectation: OnRegisterOutput) -> &Self { - self.expectations.lock().unwrap() - .get_or_insert_with(|| VecDeque::new()) - .push_back(expectation); - self - } } impl chain::Access for TestChainSource { @@ -653,24 +765,8 @@ impl chain::Filter for TestChainSource { self.watched_txn.lock().unwrap().insert((*txid, script_pubkey.clone())); } - fn register_output(&self, output: WatchedOutput) -> Option<(usize, Transaction)> { - let dependent_tx = match &mut *self.expectations.lock().unwrap() { - None => None, - Some(expectations) => match expectations.pop_front() { - None => { - panic!("Unexpected register_output: {:?}", - (output.outpoint, output.script_pubkey)); - }, - Some(expectation) => { - assert_eq!(output.outpoint, expectation.outpoint()); - assert_eq!(&output.script_pubkey, expectation.script_pubkey()); - expectation.returns - }, - }, - }; - + fn register_output(&self, output: WatchedOutput) { self.watched_outputs.lock().unwrap().insert((output.outpoint, output.script_pubkey)); - dependent_tx } } @@ -679,47 +775,6 @@ impl Drop for TestChainSource { if panicking() { return; } - - if let Some(expectations) = &*self.expectations.lock().unwrap() { - if !expectations.is_empty() { - panic!("Unsatisfied expectations: {:?}", expectations); - } - } - } -} - -/// An expectation that [`chain::Filter::register_output`] was called with a transaction output and -/// returns an optional dependent transaction that spends the output in the same block. -pub struct OnRegisterOutput { - /// The transaction output to register. - pub with: TxOutReference, - - /// A dependent transaction spending the output along with its position in the block. - pub returns: Option<(usize, Transaction)>, -} - -/// A transaction output as identified by an index into a transaction's output list. -pub struct TxOutReference(pub Transaction, pub usize); - -impl OnRegisterOutput { - fn outpoint(&self) -> OutPoint { - let txid = self.with.0.txid(); - let index = self.with.1 as u16; - OutPoint { txid, index } - } - - fn script_pubkey(&self) -> &Script { - let index = self.with.1; - &self.with.0.output[index].script_pubkey - } -} - -impl core::fmt::Debug for OnRegisterOutput { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("OnRegisterOutput") - .field("outpoint", &self.outpoint()) - .field("script_pubkey", self.script_pubkey()) - .finish() } }