From 2cbb8358f114f52976c9f51ee17f71f36f23f20a Mon Sep 17 00:00:00 2001 From: Devrandom Date: Wed, 2 Dec 2020 18:50:17 +0100 Subject: [PATCH] Use TestKeysInterface in functional tests This allows stateful validation in EnforcingChannelKeys --- lightning-persister/src/lib.rs | 4 +- lightning/src/ln/chanmon_update_fail_tests.rs | 2 +- lightning/src/ln/functional_test_utils.rs | 22 ++++--- lightning/src/ln/functional_tests.rs | 60 +++++++++---------- lightning/src/util/test_utils.rs | 9 ++- 5 files changed, 51 insertions(+), 46 deletions(-) diff --git a/lightning-persister/src/lib.rs b/lightning-persister/src/lib.rs index bb6eb54c..bc718789 100644 --- a/lightning-persister/src/lib.rs +++ b/lightning-persister/src/lib.rs @@ -198,8 +198,8 @@ mod tests { let persister_1 = FilesystemPersister::new("test_filesystem_persister_1".to_string()); let chanmon_cfgs = create_chanmon_cfgs(2); let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs); - let chain_mon_0 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &persister_0); - let chain_mon_1 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[1].chain_source), &chanmon_cfgs[1].tx_broadcaster, &chanmon_cfgs[1].logger, &chanmon_cfgs[1].fee_estimator, &persister_1); + let chain_mon_0 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &persister_0, &node_cfgs[0].keys_manager); + let chain_mon_1 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[1].chain_source), &chanmon_cfgs[1].tx_broadcaster, &chanmon_cfgs[1].logger, &chanmon_cfgs[1].fee_estimator, &persister_1, &node_cfgs[1].keys_manager); node_cfgs[0].chain_monitor = chain_mon_0; node_cfgs[1].chain_monitor = chain_mon_1; let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs index b9376029..426b7cb2 100644 --- a/lightning/src/ln/chanmon_update_fail_tests.rs +++ b/lightning/src/ln/chanmon_update_fail_tests.rs @@ -109,7 +109,7 @@ fn test_monitor_and_persister_update_fail() { let new_monitor = <(BlockHash, ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0), &test_utils::OnlyReadsKeysInterface {}).unwrap().1; assert!(new_monitor == *monitor); - let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); + let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert!(chain_mon.watch_channel(outpoint, new_monitor).is_ok()); chain_mon }; diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 7863286f..e487b152 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -21,7 +21,7 @@ use ln::msgs; use ln::msgs::{ChannelMessageHandler,RoutingMessageHandler}; use util::enforcing_trait_impls::EnforcingChannelKeys; use util::test_utils; -use util::test_utils::{TestChainMonitor, OnlyReadsKeysInterface}; +use util::test_utils::TestChainMonitor; use util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider}; use util::errors::APIError; use util::config::UserConfig; @@ -96,6 +96,8 @@ pub struct TestChanMonCfg { pub chain_source: test_utils::TestChainSource, pub persister: test_utils::TestPersister, pub logger: test_utils::TestLogger, + pub keys_manager: test_utils::TestKeysInterface, + } pub struct NodeCfg<'a> { @@ -103,7 +105,7 @@ pub struct NodeCfg<'a> { pub tx_broadcaster: &'a test_utils::TestBroadcaster, pub fee_estimator: &'a test_utils::TestFeeEstimator, pub chain_monitor: test_utils::TestChainMonitor<'a>, - pub keys_manager: test_utils::TestKeysInterface, + pub keys_manager: &'a test_utils::TestKeysInterface, pub logger: &'a test_utils::TestLogger, pub node_seed: [u8; 32], } @@ -172,7 +174,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { let mut w = test_utils::TestVecWriter(Vec::new()); old_monitor.write(&mut w).unwrap(); let (_, deserialized_monitor) = <(BlockHash, ChannelMonitor)>::read( - &mut ::std::io::Cursor::new(&w.0), &OnlyReadsKeysInterface {}).unwrap(); + &mut ::std::io::Cursor::new(&w.0), self.keys_manager).unwrap(); deserialized_monitors.push(deserialized_monitor); } } @@ -205,7 +207,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone()) }; let chain_source = test_utils::TestChainSource::new(Network::Testnet); - let chain_monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &broadcaster, &self.logger, &feeest, &persister); + let chain_monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &broadcaster, &self.logger, &feeest, &persister, &self.keys_manager); for deserialized_monitor in deserialized_monitors.drain(..) { if let Err(_) = chain_monitor.watch_channel(deserialized_monitor.get_funding_txo().0, deserialized_monitor) { panic!(); @@ -1132,7 +1134,10 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec { let chain_source = test_utils::TestChainSource::new(Network::Testnet); let logger = test_utils::TestLogger::with_id(format!("node {}", i)); let persister = test_utils::TestPersister::new(); - chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister }); + let seed = [i as u8; 32]; + let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet); + + chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager }); } chan_mon_cfgs @@ -1142,10 +1147,9 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec(node_count: usize, cfgs: &'a Vec default_config.channel_options.announced_channel = true; default_config.peer_channel_config_limits.force_announced_channel_preference = false; default_config.own_channel_config.our_htlc_minimum_msat = 1000; // sanitization being done by the sender, to exerce receiver logic we need to lift of limit - let node = ChannelManager::new(Network::Testnet, cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger, &cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, 0); + let node = ChannelManager::new(Network::Testnet, cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger, cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, 0); chanmgrs.push(node); } diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index a9367d9c..9575fd7f 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -53,7 +53,7 @@ use regex; use std::collections::{BTreeSet, HashMap, HashSet}; use std::default::Default; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use std::sync::atomic::Ordering; use std::mem; @@ -4284,9 +4284,9 @@ fn test_no_txn_manager_serialize_deserialize() { logger = test_utils::TestLogger::new(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; persister = test_utils::TestPersister::new(); - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); - nodes[0].chain_monitor = &new_chain_monitor; keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet); + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister, &keys_manager); + nodes[0].chain_monitor = &new_chain_monitor; let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..]; let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor)>::read( &mut chan_0_monitor_read, &keys_manager).unwrap(); @@ -4394,8 +4394,8 @@ fn test_manager_serialize_deserialize_events() { fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; logger = test_utils::TestLogger::new(); persister = test_utils::TestPersister::new(); - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet); + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister, &keys_manager); nodes[0].chain_monitor = &new_chain_monitor; let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..]; let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor)>::read( @@ -4470,7 +4470,7 @@ fn test_simple_manager_serialize_deserialize() { let fee_estimator: test_utils::TestFeeEstimator; let persister: test_utils::TestPersister; let new_chain_monitor: test_utils::TestChainMonitor; - let keys_manager: test_utils::TestKeysInterface; + let keys_manager: &test_utils::TestKeysInterface; let nodes_0_deserialized: ChannelManager; let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs); create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); @@ -4487,12 +4487,12 @@ fn test_simple_manager_serialize_deserialize() { logger = test_utils::TestLogger::new(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; persister = test_utils::TestPersister::new(); - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); - keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet); + keys_manager = &chanmon_cfgs[0].keys_manager; + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister, keys_manager); nodes[0].chain_monitor = &new_chain_monitor; let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..]; let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor)>::read( - &mut chan_0_monitor_read, &keys_manager).unwrap(); + &mut chan_0_monitor_read, keys_manager).unwrap(); assert!(chan_0_monitor_read.is_empty()); let mut nodes_0_read = &nodes_0_serialized[..]; @@ -4501,7 +4501,7 @@ fn test_simple_manager_serialize_deserialize() { channel_monitors.insert(chan_0_monitor.get_funding_txo().0, &mut chan_0_monitor); <(BlockHash, ChannelManager)>::read(&mut nodes_0_read, ChannelManagerReadArgs { default_config: UserConfig::default(), - keys_manager: &keys_manager, + keys_manager, fee_estimator: &fee_estimator, chain_monitor: nodes[0].chain_monitor, tx_broadcaster: nodes[0].tx_broadcaster.clone(), @@ -4532,7 +4532,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let fee_estimator: test_utils::TestFeeEstimator; let persister: test_utils::TestPersister; let new_chain_monitor: test_utils::TestChainMonitor; - let keys_manager: test_utils::TestKeysInterface; + let keys_manager: &test_utils::TestKeysInterface; let nodes_0_deserialized: ChannelManager; let mut nodes = create_network(4, &node_cfgs, &node_chanmgrs); create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); @@ -4568,15 +4568,15 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { logger = test_utils::TestLogger::new(); fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; persister = test_utils::TestPersister::new(); - new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister); + keys_manager = &chanmon_cfgs[0].keys_manager; + new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister, keys_manager); nodes[0].chain_monitor = &new_chain_monitor; - keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet); let mut node_0_stale_monitors = Vec::new(); for serialized in node_0_stale_monitors_serialized.iter() { let mut read = &serialized[..]; - let (_, monitor) = <(BlockHash, ChannelMonitor)>::read(&mut read, &keys_manager).unwrap(); + let (_, monitor) = <(BlockHash, ChannelMonitor)>::read(&mut read, keys_manager).unwrap(); assert!(read.is_empty()); node_0_stale_monitors.push(monitor); } @@ -4584,7 +4584,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let mut node_0_monitors = Vec::new(); for serialized in node_0_monitors_serialized.iter() { let mut read = &serialized[..]; - let (_, monitor) = <(BlockHash, ChannelMonitor)>::read(&mut read, &keys_manager).unwrap(); + let (_, monitor) = <(BlockHash, ChannelMonitor)>::read(&mut read, keys_manager).unwrap(); assert!(read.is_empty()); node_0_monitors.push(monitor); } @@ -4593,7 +4593,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { if let Err(msgs::DecodeError::InvalidValue) = <(BlockHash, ChannelManager)>::read(&mut nodes_0_read, ChannelManagerReadArgs { default_config: UserConfig::default(), - keys_manager: &keys_manager, + keys_manager, fee_estimator: &fee_estimator, chain_monitor: nodes[0].chain_monitor, tx_broadcaster: nodes[0].tx_broadcaster.clone(), @@ -4607,7 +4607,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() { let (_, nodes_0_deserialized_tmp) = <(BlockHash, ChannelManager)>::read(&mut nodes_0_read, ChannelManagerReadArgs { default_config: UserConfig::default(), - keys_manager: &keys_manager, + keys_manager, fee_estimator: &fee_estimator, chain_monitor: nodes[0].chain_monitor, tx_broadcaster: nodes[0].tx_broadcaster.clone(), @@ -5713,8 +5713,8 @@ fn test_key_derivation_params() { // We manually create the node configuration to backup the seed. let seed = [42; 32]; let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet); - let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &chanmon_cfgs[0].persister); - let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, chain_monitor, keys_manager, node_seed: seed }; + let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &chanmon_cfgs[0].persister, &keys_manager); + let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, chain_monitor, keys_manager: &keys_manager, node_seed: seed }; let mut node_cfgs = create_node_cfgs(3, &chanmon_cfgs); node_cfgs.remove(0); node_cfgs.insert(0, node); @@ -7325,8 +7325,7 @@ fn test_user_configurable_csv_delay() { let nodes = create_network(2, &node_cfgs, &node_chanmgrs); // We test config.our_to_self > BREAKDOWN_TIMEOUT is enforced in Channel::new_outbound() - let keys_manager = Arc::new(test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet)); - if let Err(error) = Channel::new_outbound(&&test_utils::TestFeeEstimator { sat_per_kw: 253 }, &keys_manager, nodes[1].node.get_our_node_id(), 1000000, 1000000, 0, &low_our_to_self_config) { + if let Err(error) = Channel::new_outbound(&&test_utils::TestFeeEstimator { sat_per_kw: 253 }, &nodes[0].keys_manager, nodes[1].node.get_our_node_id(), 1000000, 1000000, 0, &low_our_to_self_config) { match error { APIError::APIMisuseError { err } => { assert!(regex::Regex::new(r"Configured with an unreasonable our_to_self_delay \(\d+\) putting user funds at risks").unwrap().is_match(err.as_str())); }, _ => panic!("Unexpected event"), @@ -7337,7 +7336,7 @@ fn test_user_configurable_csv_delay() { nodes[1].node.create_channel(nodes[0].node.get_our_node_id(), 1000000, 1000000, 42, None).unwrap(); let mut open_channel = get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, nodes[0].node.get_our_node_id()); open_channel.to_self_delay = 200; - if let Err(error) = Channel::new_from_req(&&test_utils::TestFeeEstimator { sat_per_kw: 253 }, &keys_manager, nodes[1].node.get_our_node_id(), InitFeatures::known(), &open_channel, 0, &low_our_to_self_config) { + if let Err(error) = Channel::new_from_req(&&test_utils::TestFeeEstimator { sat_per_kw: 253 }, &nodes[0].keys_manager, nodes[1].node.get_our_node_id(), InitFeatures::known(), &open_channel, 0, &low_our_to_self_config) { match error { ChannelError::Close(err) => { assert!(regex::Regex::new(r"Configured with an unreasonable our_to_self_delay \(\d+\) putting user funds at risks").unwrap().is_match(err.as_str())); }, _ => panic!("Unexpected event"), @@ -7363,7 +7362,7 @@ fn test_user_configurable_csv_delay() { nodes[1].node.create_channel(nodes[0].node.get_our_node_id(), 1000000, 1000000, 42, None).unwrap(); let mut open_channel = get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, nodes[0].node.get_our_node_id()); open_channel.to_self_delay = 200; - if let Err(error) = Channel::new_from_req(&&test_utils::TestFeeEstimator { sat_per_kw: 253 }, &keys_manager, nodes[1].node.get_our_node_id(), InitFeatures::known(), &open_channel, 0, &high_their_to_self_config) { + if let Err(error) = Channel::new_from_req(&&test_utils::TestFeeEstimator { sat_per_kw: 253 }, &nodes[0].keys_manager, nodes[1].node.get_our_node_id(), InitFeatures::known(), &open_channel, 0, &high_their_to_self_config) { match error { ChannelError::Close(err) => { assert!(regex::Regex::new(r"They wanted our payments to be delayed by a needlessly long period\. Upper limit: \d+\. Actual: \d+").unwrap().is_match(err.as_str())); }, _ => panic!("Unexpected event"), @@ -7377,15 +7376,15 @@ fn test_data_loss_protect() { // * we don't broadcast our Local Commitment Tx in case of fallen behind // * we close channel in case of detecting other being fallen behind // * we are able to claim our own outputs thanks to to_remote being static - let keys_manager; let persister; let logger; let fee_estimator; let tx_broadcaster; let chain_source; + let chanmon_cfgs = create_chanmon_cfgs(2); + let keys_manager = &chanmon_cfgs[0].keys_manager; let monitor; let node_state_0; - let chanmon_cfgs = create_chanmon_cfgs(2); let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs); @@ -7405,18 +7404,17 @@ fn test_data_loss_protect() { // Restore node A from previous state logger = test_utils::TestLogger::with_id(format!("node {}", 0)); - keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet); - let mut chain_monitor = <(BlockHash, ChannelMonitor)>::read(&mut ::std::io::Cursor::new(previous_chain_monitor_state.0), &keys_manager).unwrap().1; + let mut chain_monitor = <(BlockHash, ChannelMonitor)>::read(&mut ::std::io::Cursor::new(previous_chain_monitor_state.0), keys_manager).unwrap().1; chain_source = test_utils::TestChainSource::new(Network::Testnet); tx_broadcaster = test_utils::TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new())}; fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 }; persister = test_utils::TestPersister::new(); - monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &fee_estimator, &persister); + monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &fee_estimator, &persister, keys_manager); node_state_0 = { let mut channel_monitors = HashMap::new(); channel_monitors.insert(OutPoint { txid: chan.3.txid(), index: 0 }, &mut chain_monitor); <(BlockHash, ChannelManager)>::read(&mut ::std::io::Cursor::new(previous_node_state), ChannelManagerReadArgs { - keys_manager: &keys_manager, + keys_manager: keys_manager, fee_estimator: &fee_estimator, chain_monitor: &monitor, logger: &logger, @@ -8281,7 +8279,7 @@ fn test_update_err_monitor_lockdown() { let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0), &test_utils::OnlyReadsKeysInterface {}).unwrap().1; assert!(new_monitor == *monitor); - let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); + let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok()); watchtower }; @@ -8340,7 +8338,7 @@ fn test_concurrent_monitor_claim() { let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0), &test_utils::OnlyReadsKeysInterface {}).unwrap().1; assert!(new_monitor == *monitor); - let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); + let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok()); watchtower }; @@ -8366,7 +8364,7 @@ fn test_concurrent_monitor_claim() { let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0), &test_utils::OnlyReadsKeysInterface {}).unwrap().1; assert!(new_monitor == *monitor); - let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister); + let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok()); watchtower }; diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 3680a724..b55c19e3 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -39,6 +39,7 @@ use std::sync::Mutex; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::{cmp, mem}; use std::collections::{HashMap, HashSet}; +use chain::keysinterface::InMemoryChannelKeys; pub struct TestVecWriter(pub Vec); impl Writer for TestVecWriter { @@ -79,17 +80,19 @@ pub struct TestChainMonitor<'a> { pub added_monitors: Mutex)>>, pub latest_monitor_update_id: Mutex>, pub chain_monitor: chainmonitor::ChainMonitor>, + pub keys_manager: &'a TestKeysInterface, pub update_ret: Mutex>>, // If this is set to Some(), after the next return, we'll always return this until update_ret // is changed: pub next_update_ret: Mutex>>, } impl<'a> TestChainMonitor<'a> { - pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, persister: &'a channelmonitor::Persist) -> Self { + pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, persister: &'a channelmonitor::Persist, keys_manager: &'a TestKeysInterface) -> Self { Self { added_monitors: Mutex::new(Vec::new()), latest_monitor_update_id: Mutex::new(HashMap::new()), chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister), + keys_manager, update_ret: Mutex::new(None), next_update_ret: Mutex::new(None), } @@ -104,7 +107,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { let mut w = TestVecWriter(Vec::new()); monitor.write(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut ::std::io::Cursor::new(&w.0), &OnlyReadsKeysInterface {}).unwrap().1; + &mut ::std::io::Cursor::new(&w.0), self.keys_manager).unwrap().1; assert!(new_monitor == monitor); self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id())); self.added_monitors.lock().unwrap().push((funding_txo, monitor)); @@ -137,7 +140,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { w.0.clear(); monitor.write(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut ::std::io::Cursor::new(&w.0), &OnlyReadsKeysInterface {}).unwrap().1; + &mut ::std::io::Cursor::new(&w.0), self.keys_manager).unwrap().1; assert!(new_monitor == *monitor); self.added_monitors.lock().unwrap().push((funding_txo, new_monitor)); -- 2.30.2