X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=src%2Futil%2Ftest_utils.rs;h=4360c5701c40e93927eab81bcdc3a4963236a34e;hb=efcfb93ebe0ecf0889b0cb3393cbe4d3c1e13ab6;hp=3acb0d02e78e2bdec6214dd82b5ea117e341d16d;hpb=832fc4fd4435fa236f15d3e737bebf64619ff60e;p=rust-lightning diff --git a/src/util/test_utils.rs b/src/util/test_utils.rs index 3acb0d02..4360c570 100644 --- a/src/util/test_utils.rs +++ b/src/util/test_utils.rs @@ -1,6 +1,7 @@ use chain::chaininterface; use chain::chaininterface::ConfirmationTarget; use chain::transaction::OutPoint; +use chain::keysinterface; use ln::channelmonitor; use ln::msgs; use ln::msgs::{HandleError}; @@ -10,15 +11,17 @@ use util::logger::{Logger, Level, Record}; use util::ser::{ReadableArgs, Writer}; use bitcoin::blockdata::transaction::Transaction; +use bitcoin::blockdata::script::Script; use bitcoin::util::hash::Sha256dHash; +use bitcoin::network::constants::Network; -use secp256k1::PublicKey; +use secp256k1::{SecretKey, PublicKey}; use std::sync::{Arc,Mutex}; use std::{mem}; -struct VecWriter(Vec); -impl Writer for VecWriter { +pub struct TestVecWriter(pub Vec); +impl Writer for TestVecWriter { fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> { self.0.extend_from_slice(buf); Ok(()) @@ -55,7 +58,7 @@ impl channelmonitor::ManyChannelMonitor for TestChannelMonitor { fn add_update_monitor(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> { // At every point where we get a monitor update, we should be able to send a useful monitor // to a watchtower and disk... - let mut w = VecWriter(Vec::new()); + let mut w = TestVecWriter(Vec::new()); monitor.write_for_disk(&mut w).unwrap(); assert!(<(Sha256dHash, channelmonitor::ChannelMonitor)>::read( &mut ::std::io::Cursor::new(&w.0), Arc::new(TestLogger::new())).unwrap().1 == monitor); @@ -208,3 +211,40 @@ impl Logger for TestLogger { } } } + +pub struct TestKeysInterface { + backing: keysinterface::KeysManager, + pub override_session_priv: Mutex>, + pub override_channel_id_priv: Mutex>, +} + +impl keysinterface::KeysInterface for TestKeysInterface { + fn get_node_secret(&self) -> SecretKey { self.backing.get_node_secret() } + fn get_destination_script(&self) -> Script { self.backing.get_destination_script() } + fn get_shutdown_pubkey(&self) -> PublicKey { self.backing.get_shutdown_pubkey() } + fn get_channel_keys(&self, inbound: bool) -> keysinterface::ChannelKeys { self.backing.get_channel_keys(inbound) } + + fn get_session_key(&self) -> SecretKey { + match *self.override_session_priv.lock().unwrap() { + Some(key) => key.clone(), + None => self.backing.get_session_key() + } + } + + fn get_channel_id(&self) -> [u8; 32] { + match *self.override_channel_id_priv.lock().unwrap() { + Some(key) => key.clone(), + None => self.backing.get_channel_id() + } + } +} + +impl TestKeysInterface { + pub fn new(seed: &[u8; 32], network: Network, logger: Arc) -> Self { + Self { + backing: keysinterface::KeysManager::new(seed, network, logger), + override_session_priv: Mutex::new(None), + override_channel_id_priv: Mutex::new(None), + } + } +}