X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=1dae61ab3ef5b1841241b069612f71c7e5cc8f38;hb=ccf92157620da45032d75f06b5972eaf142c1ce3;hp=b2bf947e67c5c91f9baa8bfcf6fd8af512f8f6c7;hpb=4005724ae6719594e42d11718ce9b83fa0641f3b;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index b2bf947e..1dae61ab 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -7,24 +7,29 @@ // You may not use this file except in accordance with one or both of these // licenses. -use chain; -use chain::WatchedOutput; -use chain::chaininterface; -use chain::chaininterface::ConfirmationTarget; -use chain::chainmonitor; -use chain::chainmonitor::MonitorUpdateId; -use chain::channelmonitor; -use chain::channelmonitor::MonitorEvent; -use chain::transaction::OutPoint; -use chain::keysinterface; -use ln::features::{ChannelFeatures, InitFeatures}; -use ln::{msgs, wire}; -use ln::script::ShutdownScript; -use routing::scoring::FixedPenaltyScorer; -use util::enforcing_trait_impls::{EnforcingSigner, EnforcementState}; -use util::events; -use util::logger::{Logger, Level, Record}; -use util::ser::{Readable, ReadableArgs, Writer, Writeable}; +use crate::chain; +use crate::chain::WatchedOutput; +use crate::chain::chaininterface; +use crate::chain::chaininterface::ConfirmationTarget; +use crate::chain::chainmonitor; +use crate::chain::chainmonitor::MonitorUpdateId; +use crate::chain::channelmonitor; +use crate::chain::channelmonitor::MonitorEvent; +use crate::chain::transaction::OutPoint; +use crate::chain::keysinterface; +use crate::ln::channelmanager; +use crate::ln::features::{ChannelFeatures, InitFeatures, NodeFeatures}; +use crate::ln::{msgs, wire}; +use crate::ln::msgs::LightningError; +use crate::ln::script::ShutdownScript; +use crate::routing::gossip::NetworkGraph; +use crate::routing::router::{find_route, InFlightHtlcs, Route, RouteHop, RouteParameters, Router, ScorerAccountingForInFlightHtlcs}; +use crate::routing::scoring::FixedPenaltyScorer; +use crate::util::config::UserConfig; +use crate::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState}; +use crate::util::events; +use crate::util::logger::{Logger, Level, Record}; +use crate::util::ser::{Readable, ReadableArgs, Writer, Writeable}; use bitcoin::blockdata::constants::genesis_block; use bitcoin::blockdata::transaction::{Transaction, TxOut}; @@ -40,14 +45,14 @@ use bitcoin::secp256k1::ecdsa::RecoverableSignature; use regex; -use io; -use prelude::*; +use crate::io; +use crate::prelude::*; use core::time::Duration; -use sync::{Mutex, Arc}; +use crate::sync::{Mutex, Arc}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use core::{cmp, mem}; +use core::mem; use bitcoin::bech32::u5; -use chain::keysinterface::{InMemorySigner, Recipient, KeyMaterial}; +use crate::chain::keysinterface::{InMemorySigner, Recipient, EntropySource, NodeSigner, SignerProvider}; #[cfg(feature = "std")] use std::time::{SystemTime, UNIX_EPOCH}; @@ -70,21 +75,67 @@ impl chaininterface::FeeEstimator for TestFeeEstimator { } } +pub struct TestRouter<'a> { + pub network_graph: Arc>, + pub next_routes: Mutex>>, +} + +impl<'a> TestRouter<'a> { + pub fn new(network_graph: Arc>) -> Self { + Self { network_graph, next_routes: Mutex::new(VecDeque::new()), } + } + + pub fn expect_find_route(&self, result: Result) { + let mut expected_routes = self.next_routes.lock().unwrap(); + expected_routes.push_back(result); + } +} + +impl<'a> Router for TestRouter<'a> { + fn find_route( + &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&channelmanager::ChannelDetails]>, + inflight_htlcs: &InFlightHtlcs + ) -> Result { + if let Some(find_route_res) = self.next_routes.lock().unwrap().pop_front() { + return find_route_res + } + let logger = TestLogger::new(); + find_route( + payer, params, &self.network_graph, first_hops, &logger, + &ScorerAccountingForInFlightHtlcs::new(TestScorer::with_penalty(0), &inflight_htlcs), + &[42; 32] + ) + } + fn notify_payment_path_failed(&self, _path: &[&RouteHop], _short_channel_id: u64) {} + fn notify_payment_path_successful(&self, _path: &[&RouteHop]) {} + fn notify_payment_probe_successful(&self, _path: &[&RouteHop]) {} + fn notify_payment_probe_failed(&self, _path: &[&RouteHop], _short_channel_id: u64) {} +} + +#[cfg(feature = "std")] // If we put this on the `if`, we get "attributes are not yet allowed on `if` expressions" on 1.41.1 +impl<'a> Drop for TestRouter<'a> { + fn drop(&mut self) { + if std::thread::panicking() { + return; + } + assert!(self.next_routes.lock().unwrap().is_empty()); + } +} + pub struct OnlyReadsKeysInterface {} -impl keysinterface::KeysInterface for OnlyReadsKeysInterface { + +impl EntropySource for OnlyReadsKeysInterface { + fn get_secure_random_bytes(&self) -> [u8; 32] { [0; 32] }} + +impl SignerProvider for OnlyReadsKeysInterface { type Signer = EnforcingSigner; - fn get_node_secret(&self, _recipient: Recipient) -> Result { unreachable!(); } - fn ecdh(&self, _recipient: Recipient, _other_key: &PublicKey, _tweak: Option<&Scalar>) -> Result { unreachable!(); } - fn get_inbound_payment_key_material(&self) -> KeyMaterial { unreachable!(); } - fn get_destination_script(&self) -> Script { unreachable!(); } - fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { unreachable!(); } - fn get_channel_signer(&self, _inbound: bool, _channel_value_satoshis: u64) -> EnforcingSigner { unreachable!(); } - fn get_secure_random_bytes(&self) -> [u8; 32] { [0; 32] } + fn generate_channel_keys_id(&self, _inbound: bool, _channel_value_satoshis: u64, _user_channel_id: u128) -> [u8; 32] { unreachable!(); } + + fn derive_channel_signer(&self, _channel_value_satoshis: u64, _channel_keys_id: [u8; 32]) -> Self::Signer { unreachable!(); } fn read_chan_signer(&self, mut reader: &[u8]) -> Result { - let dummy_sk = SecretKey::from_slice(&[42; 32]).unwrap(); - let inner: InMemorySigner = ReadableArgs::read(&mut reader, dummy_sk)?; + let inner: InMemorySigner = Readable::read(&mut reader)?; let state = Arc::new(Mutex::new(EnforcementState::new())); Ok(EnforcingSigner::new_with_revoked( @@ -93,7 +144,9 @@ impl keysinterface::KeysInterface for OnlyReadsKeysInterface { false )) } - fn sign_invoice(&self, _hrp_bytes: &[u8], _invoice_data: &[u5], _recipient: Recipient) -> Result { unreachable!(); } + + fn get_destination_script(&self) -> Script { unreachable!(); } + fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { unreachable!(); } } pub struct TestChainMonitor<'a> { @@ -125,13 +178,13 @@ impl<'a> TestChainMonitor<'a> { } } impl<'a> chain::Watch for TestChainMonitor<'a> { - fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> Result<(), chain::ChannelMonitorUpdateErr> { + fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> chain::ChannelMonitorUpdateStatus { // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... let mut w = TestVecWriter(Vec::new()); monitor.write(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), self.keys_manager).unwrap().1; + &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager)).unwrap().1; assert!(new_monitor == monitor); self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id(), MonitorUpdateId::from_new_monitor(&monitor))); @@ -139,12 +192,12 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { self.chain_monitor.watch_channel(funding_txo, new_monitor) } - fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), chain::ChannelMonitorUpdateErr> { + fn update_channel(&self, funding_txo: OutPoint, update: &channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus { // Every monitor update should survive roundtrip let mut w = TestVecWriter(Vec::new()); update.write(&mut w).unwrap(); assert!(channelmonitor::ChannelMonitorUpdate::read( - &mut io::Cursor::new(&w.0)).unwrap() == update); + &mut io::Cursor::new(&w.0)).unwrap() == *update); self.monitor_updates.lock().unwrap().entry(funding_txo.to_channel_id()).or_insert(Vec::new()).push(update.clone()); @@ -157,7 +210,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { } self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), - (funding_txo, update.update_id, MonitorUpdateId::from_monitor_update(&update))); + (funding_txo, update.update_id, MonitorUpdateId::from_monitor_update(update))); let update_res = self.chain_monitor.update_channel(funding_txo, update); // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... @@ -165,7 +218,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { w.0.clear(); monitor.write(&mut w).unwrap(); let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), self.keys_manager).unwrap().1; + &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager)).unwrap().1; assert!(new_monitor == *monitor); self.added_monitors.lock().unwrap().push((funding_txo, new_monitor)); update_res @@ -177,10 +230,9 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { } pub struct TestPersister { - pub update_ret: Mutex>, - /// If this is set to Some(), after the next return, we'll always return this until update_ret - /// is changed: - pub next_update_ret: Mutex>>, + /// The queue of update statuses we'll return. If none are queued, ::Completed will always be + /// returned. + pub update_rets: Mutex>, /// When we get an update_persisted_channel call with no ChannelMonitorUpdate, we insert the /// MonitorUpdateId here. pub chain_sync_monitor_persistences: Mutex>>, @@ -191,34 +243,29 @@ pub struct TestPersister { impl TestPersister { pub fn new() -> Self { Self { - update_ret: Mutex::new(Ok(())), - next_update_ret: Mutex::new(None), + update_rets: Mutex::new(VecDeque::new()), chain_sync_monitor_persistences: Mutex::new(HashMap::new()), offchain_monitor_updates: Mutex::new(HashMap::new()), } } - pub fn set_update_ret(&self, ret: Result<(), chain::ChannelMonitorUpdateErr>) { - *self.update_ret.lock().unwrap() = ret; - } - - pub fn set_next_update_ret(&self, next_ret: Option>) { - *self.next_update_ret.lock().unwrap() = next_ret; + /// Queue an update status to return. + pub fn set_update_ret(&self, next_ret: chain::ChannelMonitorUpdateStatus) { + self.update_rets.lock().unwrap().push_back(next_ret); } } -impl chainmonitor::Persist for TestPersister { - fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor, _id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { - let ret = self.update_ret.lock().unwrap().clone(); - if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { - *self.update_ret.lock().unwrap() = next_ret; +impl chainmonitor::Persist for TestPersister { + fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor, _id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus { + if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() { + return update_ret } - ret + chain::ChannelMonitorUpdateStatus::Completed } - fn update_persisted_channel(&self, funding_txo: OutPoint, update: &Option, _data: &channelmonitor::ChannelMonitor, update_id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { - let ret = self.update_ret.lock().unwrap().clone(); - if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { - *self.update_ret.lock().unwrap() = next_ret; + fn update_persisted_channel(&self, funding_txo: OutPoint, update: Option<&channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor, update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus { + let mut ret = chain::ChannelMonitorUpdateStatus::Completed; + if let Some(update_ret) = self.update_rets.lock().unwrap().pop_front() { + ret = update_ret; } if update.is_none() { self.chain_sync_monitor_persistences.lock().unwrap().entry(funding_txo).or_insert(HashSet::new()).insert(update_id); @@ -287,9 +334,9 @@ impl TestChannelMessageHandler { impl Drop for TestChannelMessageHandler { fn drop(&mut self) { - let l = self.expected_recv_msgs.lock().unwrap(); #[cfg(feature = "std")] { + let l = self.expected_recv_msgs.lock().unwrap(); if !std::thread::panicking() { assert!(l.is_none() || l.as_ref().unwrap().is_empty()); } @@ -298,10 +345,10 @@ impl Drop for TestChannelMessageHandler { } impl msgs::ChannelMessageHandler for TestChannelMessageHandler { - fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::OpenChannel) { + fn handle_open_channel(&self, _their_node_id: &PublicKey, msg: &msgs::OpenChannel) { self.received_msg(wire::Message::OpenChannel(msg.clone())); } - fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, msg: &msgs::AcceptChannel) { + fn handle_accept_channel(&self, _their_node_id: &PublicKey, msg: &msgs::AcceptChannel) { self.received_msg(wire::Message::AcceptChannel(msg.clone())); } fn handle_funding_created(&self, _their_node_id: &PublicKey, msg: &msgs::FundingCreated) { @@ -313,7 +360,7 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { fn handle_channel_ready(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReady) { self.received_msg(wire::Message::ChannelReady(msg.clone())); } - fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, msg: &msgs::Shutdown) { + fn handle_shutdown(&self, _their_node_id: &PublicKey, msg: &msgs::Shutdown) { self.received_msg(wire::Message::Shutdown(msg.clone())); } fn handle_closing_signed(&self, _their_node_id: &PublicKey, msg: &msgs::ClosingSigned) { @@ -350,13 +397,20 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler { self.received_msg(wire::Message::ChannelReestablish(msg.clone())); } fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {} - fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) { + fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) -> Result<(), ()> { // Don't bother with `received_msg` for Init as its auto-generated and we don't want to // bother re-generating the expected Init message in all tests. + Ok(()) } fn handle_error(&self, _their_node_id: &PublicKey, msg: &msgs::ErrorMessage) { self.received_msg(wire::Message::Error(msg.clone())); } + fn provided_node_features(&self) -> NodeFeatures { + channelmanager::provided_node_features(&UserConfig::default()) + } + fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures { + channelmanager::provided_init_features(&UserConfig::default()) + } } impl events::MessageSendEventsProvider for TestChannelMessageHandler { @@ -377,7 +431,7 @@ fn get_dummy_channel_announcement(short_chan_id: u64) -> msgs::ChannelAnnounceme let node_1_btckey = SecretKey::from_slice(&[40; 32]).unwrap(); let node_2_btckey = SecretKey::from_slice(&[39; 32]).unwrap(); let unsigned_ann = msgs::UnsignedChannelAnnouncement { - features: ChannelFeatures::known(), + features: ChannelFeatures::empty(), chain_hash: genesis_block(network).header.block_hash(), short_channel_id: short_chan_id, node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_privkey), @@ -447,38 +501,29 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { self.chan_upds_recvd.fetch_add(1, Ordering::AcqRel); Err(msgs::LightningError { err: "".to_owned(), action: msgs::ErrorAction::IgnoreError }) } - fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option, Option)> { - let mut chan_anns = Vec::new(); - const TOTAL_UPDS: u64 = 50; - let end: u64 = cmp::min(starting_point + batch_amount as u64, TOTAL_UPDS); - for i in starting_point..end { - let chan_upd_1 = get_dummy_channel_update(i); - let chan_upd_2 = get_dummy_channel_update(i); - let chan_ann = get_dummy_channel_announcement(i); + fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(msgs::ChannelAnnouncement, Option, Option)> { + let chan_upd_1 = get_dummy_channel_update(starting_point); + let chan_upd_2 = get_dummy_channel_update(starting_point); + let chan_ann = get_dummy_channel_announcement(starting_point); - chan_anns.push((chan_ann, Some(chan_upd_1), Some(chan_upd_2))); - } - - chan_anns + Some((chan_ann, Some(chan_upd_1), Some(chan_upd_2))) } - fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec { - Vec::new() + fn get_next_node_announcement(&self, _starting_point: Option<&PublicKey>) -> Option { + None } - fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) { + fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) -> Result<(), ()> { if !init_msg.features.supports_gossip_queries() { - return (); + return Ok(()); } - let should_request_full_sync = self.request_full_sync.load(Ordering::Acquire); - #[allow(unused_mut, unused_assignments)] let mut gossip_start_time = 0; #[cfg(feature = "std")] { gossip_start_time = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs(); - if should_request_full_sync { + if self.request_full_sync.load(Ordering::Acquire) { gossip_start_time -= 60 * 60 * 24 * 7 * 2; // 2 weeks ago } else { gossip_start_time -= 60 * 60; // an hour ago @@ -494,6 +539,7 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { timestamp_range: u32::max_value(), }, }); + Ok(()) } fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), msgs::LightningError> { @@ -511,6 +557,18 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { fn handle_query_short_channel_ids(&self, _their_node_id: &PublicKey, _msg: msgs::QueryShortChannelIds) -> Result<(), msgs::LightningError> { Ok(()) } + + fn provided_node_features(&self) -> NodeFeatures { + let mut features = NodeFeatures::empty(); + features.set_gossip_queries_optional(); + features + } + + fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures { + let mut features = InitFeatures::empty(); + features.set_gossip_queries_optional(); + features + } } impl events::MessageSendEventsProvider for TestRoutingMessageHandler { @@ -524,10 +582,7 @@ impl events::MessageSendEventsProvider for TestRoutingMessageHandler { pub struct TestLogger { level: Level, - #[cfg(feature = "std")] - id: String, - #[cfg(not(feature = "std"))] - _id: String, + pub(crate) id: String, pub lines: Mutex>, } @@ -538,10 +593,7 @@ impl TestLogger { pub fn with_id(id: String) -> TestLogger { TestLogger { level: Level::Trace, - #[cfg(feature = "std")] id, - #[cfg(not(feature = "std"))] - _id: id, lines: Mutex::new(HashMap::new()) } } @@ -565,10 +617,10 @@ impl TestLogger { assert_eq!(l, count) } - /// Search for the number of occurrences of logged lines which - /// 1. belong to the specified module and - /// 2. match the given regex pattern. - /// Assert that the number of occurrences equals the given `count` + /// Search for the number of occurrences of logged lines which + /// 1. belong to the specified module and + /// 2. match the given regex pattern. + /// Assert that the number of occurrences equals the given `count` pub fn assert_log_regex(&self, module: String, pattern: regex::Regex, count: usize) { let log_entries = self.lines.lock().unwrap(); let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| { @@ -588,6 +640,49 @@ impl Logger for TestLogger { } } +pub struct TestNodeSigner { + node_secret: SecretKey, +} + +impl TestNodeSigner { + pub fn new(node_secret: SecretKey) -> Self { + Self { node_secret } + } +} + +impl NodeSigner for TestNodeSigner { + fn get_inbound_payment_key_material(&self) -> crate::chain::keysinterface::KeyMaterial { + unreachable!() + } + + fn get_node_id(&self, recipient: Recipient) -> Result { + let node_secret = match recipient { + Recipient::Node => Ok(&self.node_secret), + Recipient::PhantomNode => Err(()) + }?; + Ok(PublicKey::from_secret_key(&Secp256k1::signing_only(), node_secret)) + } + + fn ecdh(&self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&bitcoin::secp256k1::Scalar>) -> Result { + let mut node_secret = match recipient { + Recipient::Node => Ok(self.node_secret.clone()), + Recipient::PhantomNode => Err(()) + }?; + if let Some(tweak) = tweak { + node_secret = node_secret.mul_tweak(tweak).map_err(|_| ())?; + } + Ok(SharedSecret::new(other_key, &node_secret)) + } + + fn sign_invoice(&self, _: &[u8], _: &[bitcoin::bech32::u5], _: Recipient) -> Result { + unreachable!() + } + + fn sign_gossip_message(&self, _msg: msgs::UnsignedGossipMessage) -> Result { + unreachable!() + } +} + pub struct TestKeysInterface { pub backing: keysinterface::PhantomKeysManager, pub override_random_bytes: Mutex>, @@ -596,48 +691,55 @@ pub struct TestKeysInterface { expectations: Mutex>>, } -impl keysinterface::KeysInterface for TestKeysInterface { - type Signer = EnforcingSigner; +impl EntropySource for TestKeysInterface { + fn get_secure_random_bytes(&self) -> [u8; 32] { + let override_random_bytes = self.override_random_bytes.lock().unwrap(); + if let Some(bytes) = &*override_random_bytes { + return *bytes; + } + self.backing.get_secure_random_bytes() + } +} - fn get_node_secret(&self, recipient: Recipient) -> Result { - self.backing.get_node_secret(recipient) +impl NodeSigner for TestKeysInterface { + fn get_node_id(&self, recipient: Recipient) -> Result { + self.backing.get_node_id(recipient) } + fn ecdh(&self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&Scalar>) -> Result { self.backing.ecdh(recipient, other_key, tweak) } + fn get_inbound_payment_key_material(&self) -> keysinterface::KeyMaterial { self.backing.get_inbound_payment_key_material() } - fn get_destination_script(&self) -> Script { self.backing.get_destination_script() } - fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { - match &mut *self.expectations.lock().unwrap() { - None => self.backing.get_shutdown_scriptpubkey(), - Some(expectations) => match expectations.pop_front() { - None => panic!("Unexpected get_shutdown_scriptpubkey"), - Some(expectation) => expectation.returns, - }, - } + fn sign_invoice(&self, hrp_bytes: &[u8], invoice_data: &[u5], recipient: Recipient) -> Result { + self.backing.sign_invoice(hrp_bytes, invoice_data, recipient) } - fn get_channel_signer(&self, inbound: bool, channel_value_satoshis: u64) -> EnforcingSigner { - let keys = self.backing.get_channel_signer(inbound, channel_value_satoshis); - let state = self.make_enforcement_state_cell(keys.commitment_seed); - EnforcingSigner::new_with_revoked(keys, state, self.disable_revocation_policy_check) + fn sign_gossip_message(&self, msg: msgs::UnsignedGossipMessage) -> Result { + self.backing.sign_gossip_message(msg) } +} - fn get_secure_random_bytes(&self) -> [u8; 32] { - let override_random_bytes = self.override_random_bytes.lock().unwrap(); - if let Some(bytes) = &*override_random_bytes { - return *bytes; - } - self.backing.get_secure_random_bytes() +impl SignerProvider for TestKeysInterface { + type Signer = EnforcingSigner; + + fn generate_channel_keys_id(&self, inbound: bool, channel_value_satoshis: u64, user_channel_id: u128) -> [u8; 32] { + self.backing.generate_channel_keys_id(inbound, channel_value_satoshis, user_channel_id) + } + + fn derive_channel_signer(&self, channel_value_satoshis: u64, channel_keys_id: [u8; 32]) -> EnforcingSigner { + let keys = self.backing.derive_channel_signer(channel_value_satoshis, channel_keys_id); + let state = self.make_enforcement_state_cell(keys.commitment_seed); + EnforcingSigner::new_with_revoked(keys, state, self.disable_revocation_policy_check) } fn read_chan_signer(&self, buffer: &[u8]) -> Result { let mut reader = io::Cursor::new(buffer); - let inner: InMemorySigner = ReadableArgs::read(&mut reader, self.get_node_secret(Recipient::Node).unwrap())?; + let inner: InMemorySigner = Readable::read(&mut reader)?; let state = self.make_enforcement_state_cell(inner.commitment_seed); Ok(EnforcingSigner::new_with_revoked( @@ -647,8 +749,16 @@ impl keysinterface::KeysInterface for TestKeysInterface { )) } - fn sign_invoice(&self, hrp_bytes: &[u8], invoice_data: &[u5], recipient: Recipient) -> Result { - self.backing.sign_invoice(hrp_bytes, invoice_data, recipient) + fn get_destination_script(&self) -> Script { self.backing.get_destination_script() } + + fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { + match &mut *self.expectations.lock().unwrap() { + None => self.backing.get_shutdown_scriptpubkey(), + Some(expectations) => match expectations.pop_front() { + None => panic!("Unexpected get_shutdown_scriptpubkey"), + Some(expectation) => expectation.returns, + }, + } } } @@ -664,7 +774,7 @@ impl TestKeysInterface { } } - /// Sets an expectation that [`keysinterface::KeysInterface::get_shutdown_scriptpubkey`] is + /// Sets an expectation that [`keysinterface::SignerProvider::get_shutdown_scriptpubkey`] is /// called. pub fn expect(&self, expectation: OnGetShutdownScriptpubkey) -> &Self { self.expectations.lock().unwrap() @@ -712,7 +822,7 @@ impl Drop for TestKeysInterface { } } -/// An expectation that [`keysinterface::KeysInterface::get_shutdown_scriptpubkey`] was called and +/// An expectation that [`keysinterface::SignerProvider::get_shutdown_scriptpubkey`] was called and /// returns a [`ShutdownScript`]. pub struct OnGetShutdownScriptpubkey { /// A shutdown script used to close a channel.