Merge pull request #1106 from TheBlueMatt/2021-10-no-perm-err-broadcast
[rust-lightning] / lightning / src / util / test_utils.rs
index 1d4eccd81e6678ebf207daeda35681887824d89f..9b2f222c519f9c672bf6d538ef23898de2e8020f 100644 (file)
@@ -17,9 +17,9 @@ use chain::channelmonitor;
 use chain::channelmonitor::MonitorEvent;
 use chain::transaction::OutPoint;
 use chain::keysinterface;
-use ln::features::{ChannelFeatures, InitFeatures};
+use ln::channelmanager;
+use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
 use ln::{msgs, wire};
-use ln::msgs::OptionalField;
 use ln::script::ShutdownScript;
 use routing::scoring::FixedPenaltyScorer;
 use util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
@@ -31,11 +31,12 @@ use bitcoin::blockdata::constants::genesis_block;
 use bitcoin::blockdata::transaction::{Transaction, TxOut};
 use bitcoin::blockdata::script::{Builder, Script};
 use bitcoin::blockdata::opcodes;
-use bitcoin::blockdata::block::BlockHeader;
+use bitcoin::blockdata::block::Block;
 use bitcoin::network::constants::Network;
 use bitcoin::hash_types::{BlockHash, Txid};
 
-use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, ecdsa::Signature};
+use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, ecdsa::Signature, Scalar};
+use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::secp256k1::ecdsa::RecoverableSignature;
 
 use regex;
@@ -45,12 +46,13 @@ use prelude::*;
 use core::time::Duration;
 use sync::{Mutex, Arc};
 use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
-use core::{cmp, mem};
+use core::mem;
 use bitcoin::bech32::u5;
 use chain::keysinterface::{InMemorySigner, Recipient, KeyMaterial};
 
 #[cfg(feature = "std")]
 use std::time::{SystemTime, UNIX_EPOCH};
+use bitcoin::Sequence;
 
 pub struct TestVecWriter(pub Vec<u8>);
 impl Writer for TestVecWriter {
@@ -74,6 +76,7 @@ impl keysinterface::KeysInterface for OnlyReadsKeysInterface {
        type Signer = EnforcingSigner;
 
        fn get_node_secret(&self, _recipient: Recipient) -> Result<SecretKey, ()> { unreachable!(); }
+       fn ecdh(&self, _recipient: Recipient, _other_key: &PublicKey, _tweak: Option<&Scalar>) -> Result<SharedSecret, ()> { unreachable!(); }
        fn get_inbound_payment_key_material(&self) -> KeyMaterial { unreachable!(); }
        fn get_destination_script(&self) -> Script { unreachable!(); }
        fn get_shutdown_scriptpubkey(&self) -> ShutdownScript { unreachable!(); }
@@ -123,7 +126,7 @@ impl<'a> TestChainMonitor<'a> {
        }
 }
 impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
-       fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor<EnforcingSigner>) -> Result<(), chain::ChannelMonitorUpdateErr> {
+       fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor<EnforcingSigner>) -> chain::ChannelMonitorUpdateStatus {
                // At every point where we get a monitor update, we should be able to send a useful monitor
                // to a watchtower and disk...
                let mut w = TestVecWriter(Vec::new());
@@ -137,7 +140,7 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                self.chain_monitor.watch_channel(funding_txo, new_monitor)
        }
 
-       fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), chain::ChannelMonitorUpdateErr> {
+       fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> chain::ChannelMonitorUpdateStatus {
                // Every monitor update should survive roundtrip
                let mut w = TestVecWriter(Vec::new());
                update.write(&mut w).unwrap();
@@ -169,16 +172,16 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                update_res
        }
 
-       fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec<MonitorEvent>)> {
+       fn release_pending_monitor_events(&self) -> Vec<(OutPoint, Vec<MonitorEvent>, Option<PublicKey>)> {
                return self.chain_monitor.release_pending_monitor_events();
        }
 }
 
 pub struct TestPersister {
-       pub update_ret: Mutex<Result<(), chain::ChannelMonitorUpdateErr>>,
+       pub update_ret: Mutex<chain::ChannelMonitorUpdateStatus>,
        /// If this is set to Some(), after the next return, we'll always return this until update_ret
        /// is changed:
-       pub next_update_ret: Mutex<Option<Result<(), chain::ChannelMonitorUpdateErr>>>,
+       pub next_update_ret: Mutex<Option<chain::ChannelMonitorUpdateStatus>>,
        /// When we get an update_persisted_channel call with no ChannelMonitorUpdate, we insert the
        /// MonitorUpdateId here.
        pub chain_sync_monitor_persistences: Mutex<HashMap<OutPoint, HashSet<MonitorUpdateId>>>,
@@ -189,23 +192,23 @@ pub struct TestPersister {
 impl TestPersister {
        pub fn new() -> Self {
                Self {
-                       update_ret: Mutex::new(Ok(())),
+                       update_ret: Mutex::new(chain::ChannelMonitorUpdateStatus::Completed),
                        next_update_ret: Mutex::new(None),
                        chain_sync_monitor_persistences: Mutex::new(HashMap::new()),
                        offchain_monitor_updates: Mutex::new(HashMap::new()),
                }
        }
 
-       pub fn set_update_ret(&self, ret: Result<(), chain::ChannelMonitorUpdateErr>) {
+       pub fn set_update_ret(&self, ret: chain::ChannelMonitorUpdateStatus) {
                *self.update_ret.lock().unwrap() = ret;
        }
 
-       pub fn set_next_update_ret(&self, next_ret: Option<Result<(), chain::ChannelMonitorUpdateErr>>) {
+       pub fn set_next_update_ret(&self, next_ret: Option<chain::ChannelMonitorUpdateStatus>) {
                *self.next_update_ret.lock().unwrap() = next_ret;
        }
 }
 impl<Signer: keysinterface::Sign> chainmonitor::Persist<Signer> for TestPersister {
-       fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<Signer>, _id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> {
+       fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<Signer>, _id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
                let ret = self.update_ret.lock().unwrap().clone();
                if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
                        *self.update_ret.lock().unwrap() = next_ret;
@@ -213,7 +216,7 @@ impl<Signer: keysinterface::Sign> chainmonitor::Persist<Signer> for TestPersiste
                ret
        }
 
-       fn update_persisted_channel(&self, funding_txo: OutPoint, update: &Option<channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor<Signer>, update_id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> {
+       fn update_persisted_channel(&self, funding_txo: OutPoint, update: &Option<channelmonitor::ChannelMonitorUpdate>, _data: &channelmonitor::ChannelMonitor<Signer>, update_id: MonitorUpdateId) -> chain::ChannelMonitorUpdateStatus {
                let ret = self.update_ret.lock().unwrap().clone();
                if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
                        *self.update_ret.lock().unwrap() = next_ret;
@@ -229,21 +232,22 @@ impl<Signer: keysinterface::Sign> chainmonitor::Persist<Signer> for TestPersiste
 
 pub struct TestBroadcaster {
        pub txn_broadcasted: Mutex<Vec<Transaction>>,
-       pub blocks: Arc<Mutex<Vec<(BlockHeader, u32)>>>,
+       pub blocks: Arc<Mutex<Vec<(Block, u32)>>>,
 }
 
 impl TestBroadcaster {
-       pub fn new(blocks: Arc<Mutex<Vec<(BlockHeader, u32)>>>) -> TestBroadcaster {
+       pub fn new(blocks: Arc<Mutex<Vec<(Block, u32)>>>) -> TestBroadcaster {
                TestBroadcaster { txn_broadcasted: Mutex::new(Vec::new()), blocks }
        }
 }
 
 impl chaininterface::BroadcasterInterface for TestBroadcaster {
        fn broadcast_transaction(&self, tx: &Transaction) {
-               assert!(tx.lock_time < 1_500_000_000);
-               if tx.lock_time > self.blocks.lock().unwrap().len() as u32 + 1 && tx.lock_time < 500_000_000 {
+               let lock_time = tx.lock_time.0;
+               assert!(lock_time < 1_500_000_000);
+               if lock_time > self.blocks.lock().unwrap().len() as u32 + 1 && lock_time < 500_000_000 {
                        for inp in tx.input.iter() {
-                               if inp.sequence != 0xffffffff {
+                               if inp.sequence != Sequence::MAX {
                                        panic!("We should never broadcast a transaction before its locktime ({})!", tx.lock_time);
                                }
                        }
@@ -284,9 +288,9 @@ impl TestChannelMessageHandler {
 
 impl Drop for TestChannelMessageHandler {
        fn drop(&mut self) {
-               let l = self.expected_recv_msgs.lock().unwrap();
                #[cfg(feature = "std")]
                {
+                       let l = self.expected_recv_msgs.lock().unwrap();
                        if !std::thread::panicking() {
                                assert!(l.is_none() || l.as_ref().unwrap().is_empty());
                        }
@@ -307,8 +311,8 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
        fn handle_funding_signed(&self, _their_node_id: &PublicKey, msg: &msgs::FundingSigned) {
                self.received_msg(wire::Message::FundingSigned(msg.clone()));
        }
-       fn handle_funding_locked(&self, _their_node_id: &PublicKey, msg: &msgs::FundingLocked) {
-               self.received_msg(wire::Message::FundingLocked(msg.clone()));
+       fn handle_channel_ready(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReady) {
+               self.received_msg(wire::Message::ChannelReady(msg.clone()));
        }
        fn handle_shutdown(&self, _their_node_id: &PublicKey, _their_features: &InitFeatures, msg: &msgs::Shutdown) {
                self.received_msg(wire::Message::Shutdown(msg.clone()));
@@ -347,13 +351,20 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
                self.received_msg(wire::Message::ChannelReestablish(msg.clone()));
        }
        fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {}
-       fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) {
+       fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) -> Result<(), ()> {
                // Don't bother with `received_msg` for Init as its auto-generated and we don't want to
                // bother re-generating the expected Init message in all tests.
+               Ok(())
        }
        fn handle_error(&self, _their_node_id: &PublicKey, msg: &msgs::ErrorMessage) {
                self.received_msg(wire::Message::Error(msg.clone()));
        }
+       fn provided_node_features(&self) -> NodeFeatures {
+               channelmanager::provided_node_features()
+       }
+       fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures {
+               channelmanager::provided_init_features()
+       }
 }
 
 impl events::MessageSendEventsProvider for TestChannelMessageHandler {
@@ -374,7 +385,7 @@ fn get_dummy_channel_announcement(short_chan_id: u64) -> msgs::ChannelAnnounceme
        let node_1_btckey = SecretKey::from_slice(&[40; 32]).unwrap();
        let node_2_btckey = SecretKey::from_slice(&[39; 32]).unwrap();
        let unsigned_ann = msgs::UnsignedChannelAnnouncement {
-               features: ChannelFeatures::known(),
+               features: ChannelFeatures::empty(),
                chain_hash: genesis_block(network).header.block_hash(),
                short_channel_id: short_chan_id,
                node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_privkey),
@@ -407,7 +418,7 @@ fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate {
                        flags: 0,
                        cltv_expiry_delta: 0,
                        htlc_minimum_msat: 0,
-                       htlc_maximum_msat: OptionalField::Absent,
+                       htlc_maximum_msat: msgs::MAX_VALUE_MSAT,
                        fee_base_msat: 0,
                        fee_proportional_millionths: 0,
                        excess_data: vec![],
@@ -444,38 +455,29 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
                self.chan_upds_recvd.fetch_add(1, Ordering::AcqRel);
                Err(msgs::LightningError { err: "".to_owned(), action: msgs::ErrorAction::IgnoreError })
        }
-       fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> {
-               let mut chan_anns = Vec::new();
-               const TOTAL_UPDS: u64 = 50;
-               let end: u64 = cmp::min(starting_point + batch_amount as u64, TOTAL_UPDS);
-               for i in starting_point..end {
-                       let chan_upd_1 = get_dummy_channel_update(i);
-                       let chan_upd_2 = get_dummy_channel_update(i);
-                       let chan_ann = get_dummy_channel_announcement(i);
+       fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> {
+               let chan_upd_1 = get_dummy_channel_update(starting_point);
+               let chan_upd_2 = get_dummy_channel_update(starting_point);
+               let chan_ann = get_dummy_channel_announcement(starting_point);
 
-                       chan_anns.push((chan_ann, Some(chan_upd_1), Some(chan_upd_2)));
-               }
-
-               chan_anns
+               Some((chan_ann, Some(chan_upd_1), Some(chan_upd_2)))
        }
 
-       fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec<msgs::NodeAnnouncement> {
-               Vec::new()
+       fn get_next_node_announcement(&self, _starting_point: Option<&PublicKey>) -> Option<msgs::NodeAnnouncement> {
+               None
        }
 
-       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) {
+       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) -> Result<(), ()> {
                if !init_msg.features.supports_gossip_queries() {
-                       return ();
+                       return Ok(());
                }
 
-               let should_request_full_sync = self.request_full_sync.load(Ordering::Acquire);
-
                #[allow(unused_mut, unused_assignments)]
                let mut gossip_start_time = 0;
                #[cfg(feature = "std")]
                {
                        gossip_start_time = SystemTime::now().duration_since(UNIX_EPOCH).expect("Time must be > 1970").as_secs();
-                       if should_request_full_sync {
+                       if self.request_full_sync.load(Ordering::Acquire) {
                                gossip_start_time -= 60 * 60 * 24 * 7 * 2; // 2 weeks ago
                        } else {
                                gossip_start_time -= 60 * 60; // an hour ago
@@ -491,6 +493,7 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
                                timestamp_range: u32::max_value(),
                        },
                });
+               Ok(())
        }
 
        fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), msgs::LightningError> {
@@ -508,6 +511,18 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
        fn handle_query_short_channel_ids(&self, _their_node_id: &PublicKey, _msg: msgs::QueryShortChannelIds) -> Result<(), msgs::LightningError> {
                Ok(())
        }
+
+       fn provided_node_features(&self) -> NodeFeatures {
+               let mut features = NodeFeatures::empty();
+               features.set_gossip_queries_optional();
+               features
+       }
+
+       fn provided_init_features(&self, _their_init_features: &PublicKey) -> InitFeatures {
+               let mut features = InitFeatures::empty();
+               features.set_gossip_queries_optional();
+               features
+       }
 }
 
 impl events::MessageSendEventsProvider for TestRoutingMessageHandler {
@@ -521,10 +536,7 @@ impl events::MessageSendEventsProvider for TestRoutingMessageHandler {
 
 pub struct TestLogger {
        level: Level,
-       #[cfg(feature = "std")]
-       id: String,
-       #[cfg(not(feature = "std"))]
-       _id: String,
+       pub(crate) id: String,
        pub lines: Mutex<HashMap<(String, String), usize>>,
 }
 
@@ -535,10 +547,7 @@ impl TestLogger {
        pub fn with_id(id: String) -> TestLogger {
                TestLogger {
                        level: Level::Trace,
-                       #[cfg(feature = "std")]
                        id,
-                       #[cfg(not(feature = "std"))]
-                       _id: id,
                        lines: Mutex::new(HashMap::new())
                }
        }
@@ -562,10 +571,10 @@ impl TestLogger {
                assert_eq!(l, count)
        }
 
-    /// Search for the number of occurrences of logged lines which
-    /// 1. belong to the specified module and
-    /// 2. match the given regex pattern.
-    /// Assert that the number of occurrences equals the given `count`
+       /// Search for the number of occurrences of logged lines which
+       /// 1. belong to the specified module and
+       /// 2. match the given regex pattern.
+       /// Assert that the number of occurrences equals the given `count`
        pub fn assert_log_regex(&self, module: String, pattern: regex::Regex, count: usize) {
                let log_entries = self.lines.lock().unwrap();
                let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| {
@@ -599,6 +608,9 @@ impl keysinterface::KeysInterface for TestKeysInterface {
        fn get_node_secret(&self, recipient: Recipient) -> Result<SecretKey, ()> {
                self.backing.get_node_secret(recipient)
        }
+       fn ecdh(&self, recipient: Recipient, other_key: &PublicKey, tweak: Option<&Scalar>) -> Result<SharedSecret, ()> {
+               self.backing.ecdh(recipient, other_key, tweak)
+       }
        fn get_inbound_payment_key_material(&self) -> keysinterface::KeyMaterial {
                self.backing.get_inbound_payment_key_material()
        }
@@ -724,7 +736,6 @@ pub struct TestChainSource {
        pub utxo_ret: Mutex<Result<TxOut, chain::AccessError>>,
        pub watched_txn: Mutex<HashSet<(Txid, Script)>>,
        pub watched_outputs: Mutex<HashSet<(OutPoint, Script)>>,
-       expectations: Mutex<Option<VecDeque<OnRegisterOutput>>>,
 }
 
 impl TestChainSource {
@@ -735,17 +746,8 @@ impl TestChainSource {
                        utxo_ret: Mutex::new(Ok(TxOut { value: u64::max_value(), script_pubkey })),
                        watched_txn: Mutex::new(HashSet::new()),
                        watched_outputs: Mutex::new(HashSet::new()),
-                       expectations: Mutex::new(None),
                }
        }
-
-       /// Sets an expectation that [`chain::Filter::register_output`] is called.
-       pub fn expect(&self, expectation: OnRegisterOutput) -> &Self {
-               self.expectations.lock().unwrap()
-                       .get_or_insert_with(|| VecDeque::new())
-                       .push_back(expectation);
-               self
-       }
 }
 
 impl chain::Access for TestChainSource {
@@ -763,24 +765,8 @@ impl chain::Filter for TestChainSource {
                self.watched_txn.lock().unwrap().insert((*txid, script_pubkey.clone()));
        }
 
-       fn register_output(&self, output: WatchedOutput) -> Option<(usize, Transaction)> {
-               let dependent_tx = match &mut *self.expectations.lock().unwrap() {
-                       None => None,
-                       Some(expectations) => match expectations.pop_front() {
-                               None => {
-                                       panic!("Unexpected register_output: {:?}",
-                                               (output.outpoint, output.script_pubkey));
-                               },
-                               Some(expectation) => {
-                                       assert_eq!(output.outpoint, expectation.outpoint());
-                                       assert_eq!(&output.script_pubkey, expectation.script_pubkey());
-                                       expectation.returns
-                               },
-                       },
-               };
-
+       fn register_output(&self, output: WatchedOutput) {
                self.watched_outputs.lock().unwrap().insert((output.outpoint, output.script_pubkey));
-               dependent_tx
        }
 }
 
@@ -789,47 +775,6 @@ impl Drop for TestChainSource {
                if panicking() {
                        return;
                }
-
-               if let Some(expectations) = &*self.expectations.lock().unwrap() {
-                       if !expectations.is_empty() {
-                               panic!("Unsatisfied expectations: {:?}", expectations);
-                       }
-               }
-       }
-}
-
-/// An expectation that [`chain::Filter::register_output`] was called with a transaction output and
-/// returns an optional dependent transaction that spends the output in the same block.
-pub struct OnRegisterOutput {
-       /// The transaction output to register.
-       pub with: TxOutReference,
-
-       /// A dependent transaction spending the output along with its position in the block.
-       pub returns: Option<(usize, Transaction)>,
-}
-
-/// A transaction output as identified by an index into a transaction's output list.
-pub struct TxOutReference(pub Transaction, pub usize);
-
-impl OnRegisterOutput {
-       fn outpoint(&self) -> OutPoint {
-               let txid = self.with.0.txid();
-               let index = self.with.1 as u16;
-               OutPoint { txid, index }
-       }
-
-       fn script_pubkey(&self) -> &Script {
-               let index = self.with.1;
-               &self.with.0.output[index].script_pubkey
-       }
-}
-
-impl core::fmt::Debug for OnRegisterOutput {
-       fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-               f.debug_struct("OnRegisterOutput")
-                       .field("outpoint", &self.outpoint())
-                       .field("script_pubkey", self.script_pubkey())
-                       .finish()
        }
 }