Merge pull request #977 from TheBlueMatt/2021-06-fix-double-claim-close
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index a6d82ccee525ad164023620d0b59d409819c0692..d1d322bfa95f08b3906cea5d68975b66500fe5f9 100644 (file)
@@ -23,7 +23,7 @@ use ln::msgs::{ChannelMessageHandler,RoutingMessageHandler};
 use util::enforcing_trait_impls::EnforcingSigner;
 use util::test_utils;
 use util::test_utils::TestChainMonitor;
-use util::events::{Event, MessageSendEvent, MessageSendEventsProvider};
+use util::events::{Event, MessageSendEvent, MessageSendEventsProvider, PaymentPurpose};
 use util::errors::APIError;
 use util::config::UserConfig;
 use util::ser::{ReadableArgs, Writeable, Readable};
@@ -42,7 +42,7 @@ use bitcoin::secp256k1::key::PublicKey;
 use prelude::*;
 use core::cell::RefCell;
 use std::rc::Rc;
-use std::sync::{Arc, Mutex};
+use sync::{Arc, Mutex};
 use core::mem;
 
 pub const CHAN_CONFIRM_DEPTH: u32 = 10;
@@ -269,7 +269,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                        // Check that if we serialize and then deserialize all our channel monitors we get the
                        // same set of outputs to watch for on chain as we have now. Note that if we write
                        // tests that fully close channels and remove the monitors at some point this may break.
-                       let feeest = test_utils::TestFeeEstimator { sat_per_kw: 253 };
+                       let feeest = test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) };
                        let mut deserialized_monitors = Vec::new();
                        {
                                let old_monitors = self.chain_monitor.chain_monitor.monitors.read().unwrap();
@@ -295,7 +295,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                <(BlockHash, ChannelManager<EnforcingSigner, &test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>)>::read(&mut ::std::io::Cursor::new(w.0), ChannelManagerReadArgs {
                                        default_config: *self.node.get_current_default_configuration(),
                                        keys_manager: self.keys_manager,
-                                       fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: 253 },
+                                       fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) },
                                        chain_monitor: self.chain_monitor,
                                        tx_broadcaster: &test_utils::TestBroadcaster {
                                                txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone()),
@@ -976,11 +976,16 @@ macro_rules! expect_payment_received {
                let events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                match events[0] {
-                       Event::PaymentReceived { ref payment_hash, ref payment_preimage, ref payment_secret, amt, user_payment_id: _ } => {
+                       Event::PaymentReceived { ref payment_hash, ref purpose, amt } => {
                                assert_eq!($expected_payment_hash, *payment_hash);
-                               assert!(payment_preimage.is_none());
-                               assert_eq!($expected_payment_secret, *payment_secret);
                                assert_eq!($expected_recv_value, amt);
+                               match purpose {
+                                       PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                               assert!(payment_preimage.is_none());
+                                               assert_eq!($expected_payment_secret, *payment_secret);
+                                       },
+                                       _ => {},
+                               }
                        },
                        _ => panic!("Unexpected event"),
                }
@@ -1051,7 +1056,7 @@ pub fn send_along_route_with_secret<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
        pass_along_route(origin_node, expected_paths, recv_value, our_payment_hash, our_payment_secret);
 }
 
-pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path: &[&Node<'a, 'b, 'c>], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: PaymentSecret, ev: MessageSendEvent, payment_received_expected: bool) {
+pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path: &[&Node<'a, 'b, 'c>], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: Option<PaymentSecret>, ev: MessageSendEvent, payment_received_expected: bool, expected_preimage: Option<PaymentPreimage>) {
        let mut payment_event = SendEvent::from_event(ev);
        let mut prev_node = origin_node;
 
@@ -1069,10 +1074,18 @@ pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path
                        if payment_received_expected {
                                assert_eq!(events_2.len(), 1);
                                match events_2[0] {
-                                       Event::PaymentReceived { ref payment_hash, ref payment_preimage, ref payment_secret, amt, user_payment_id: _ } => {
+                                       Event::PaymentReceived { ref payment_hash, ref purpose, amt} => {
                                                assert_eq!(our_payment_hash, *payment_hash);
-                                               assert!(payment_preimage.is_none());
-                                               assert_eq!(our_payment_secret, *payment_secret);
+                                               match &purpose {
+                                                       PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                                               assert_eq!(expected_preimage, *payment_preimage);
+                                                               assert_eq!(our_payment_secret.unwrap(), *payment_secret);
+                                                       },
+                                                       PaymentPurpose::SpontaneousPayment(payment_preimage) => {
+                                                               assert_eq!(expected_preimage.unwrap(), *payment_preimage);
+                                                               assert!(our_payment_secret.is_none());
+                                                       },
+                                               }
                                                assert_eq!(amt, recv_value);
                                        },
                                        _ => panic!("Unexpected event"),
@@ -1099,7 +1112,7 @@ pub fn pass_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
                // Once we've gotten through all the HTLCs, the last one should result in a
                // PaymentReceived (but each previous one should not!), .
                let expect_payment = path_idx == expected_route.len() - 1;
-               pass_along_path(origin_node, expected_path, recv_value, our_payment_hash.clone(), our_payment_secret, ev, expect_payment);
+               pass_along_path(origin_node, expected_path, recv_value, our_payment_hash.clone(), Some(our_payment_secret), ev, expect_payment, None);
        }
 }
 
@@ -1206,7 +1219,10 @@ pub const TEST_FINAL_CLTV: u32 = 70;
 pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) {
        let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
        let logger = test_utils::TestLogger::new();
-       let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), recv_value, TEST_FINAL_CLTV, &logger).unwrap();
+       let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(),
+               &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()),
+               Some(&origin_node.node.list_usable_channels().iter().collect::<Vec<_>>()), &[],
+               recv_value, TEST_FINAL_CLTV, &logger).unwrap();
        assert_eq!(route.paths.len(), 1);
        assert_eq!(route.paths[0].len(), expected_route.len());
        for (node, hop) in expected_route.iter().zip(route.paths[0].iter()) {
@@ -1316,7 +1332,7 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec<TestChanMonCfg> {
                        txn_broadcasted: Mutex::new(Vec::new()),
                        blocks: Arc::new(Mutex::new(vec![(genesis_block(Network::Testnet).header, 0)])),
                };
-               let fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
+               let fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) };
                let chain_source = test_utils::TestChainSource::new(Network::Testnet);
                let logger = test_utils::TestLogger::with_id(format!("node {}", i));
                let persister = test_utils::TestPersister::new();
@@ -1341,22 +1357,29 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec<TestChanMon
        nodes
 }
 
+pub fn test_default_channel_config() -> UserConfig {
+       let mut default_config = UserConfig::default();
+       // Set cltv_expiry_delta slightly lower to keep the final CLTV values inside one byte in our
+       // tests so that our script-length checks don't fail (see ACCEPTED_HTLC_SCRIPT_WEIGHT).
+       default_config.channel_options.cltv_expiry_delta = 6*6;
+       default_config.channel_options.announced_channel = true;
+       default_config.peer_channel_config_limits.force_announced_channel_preference = false;
+       // When most of our tests were written, the default HTLC minimum was fixed at 1000.
+       // It now defaults to 1, so we simply set it to the expected value here.
+       default_config.own_channel_config.our_htlc_minimum_msat = 1000;
+       default_config
+}
+
 pub fn create_node_chanmgrs<'a, 'b>(node_count: usize, cfgs: &'a Vec<NodeCfg<'b>>, node_config: &[Option<UserConfig>]) -> Vec<ChannelManager<EnforcingSigner, &'a TestChainMonitor<'b>, &'b test_utils::TestBroadcaster, &'a test_utils::TestKeysInterface, &'b test_utils::TestFeeEstimator, &'b test_utils::TestLogger>> {
        let mut chanmgrs = Vec::new();
        for i in 0..node_count {
-               let mut default_config = UserConfig::default();
-               // Set cltv_expiry_delta slightly lower to keep the final CLTV values inside one byte in our
-               // tests so that our script-length checks don't fail (see ACCEPTED_HTLC_SCRIPT_WEIGHT).
-               default_config.channel_options.cltv_expiry_delta = 6*6;
-               default_config.channel_options.announced_channel = true;
-               default_config.peer_channel_config_limits.force_announced_channel_preference = false;
-               default_config.own_channel_config.our_htlc_minimum_msat = 1000; // sanitization being done by the sender, to exerce receiver logic we need to lift of limit
                let network = Network::Testnet;
                let params = ChainParameters {
                        network,
                        best_block: BestBlock::from_genesis(network),
                };
-               let node = ChannelManager::new(cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger, cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, params);
+               let node = ChannelManager::new(cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger, cfgs[i].keys_manager,
+                       if node_config[i].is_some() { node_config[i].clone().unwrap() } else { test_default_channel_config() }, params);
                chanmgrs.push(node);
        }
 
@@ -1583,18 +1606,20 @@ macro_rules! handle_chan_reestablish_msgs {
                        let mut revoke_and_ack = None;
                        let mut commitment_update = None;
                        let order = if let Some(ev) = msg_events.get(idx) {
-                               idx += 1;
                                match ev {
                                        &MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => {
                                                assert_eq!(*node_id, $dst_node.node.get_our_node_id());
                                                revoke_and_ack = Some(msg.clone());
+                                               idx += 1;
                                                RAACommitmentOrder::RevokeAndACKFirst
                                        },
                                        &MessageSendEvent::UpdateHTLCs { ref node_id, ref updates } => {
                                                assert_eq!(*node_id, $dst_node.node.get_our_node_id());
                                                commitment_update = Some(updates.clone());
+                                               idx += 1;
                                                RAACommitmentOrder::CommitmentFirst
                                        },
+                                       &MessageSendEvent::SendChannelUpdate { .. } => RAACommitmentOrder::CommitmentFirst,
                                        _ => panic!("Unexpected event"),
                                }
                        } else {
@@ -1607,16 +1632,24 @@ macro_rules! handle_chan_reestablish_msgs {
                                                assert_eq!(*node_id, $dst_node.node.get_our_node_id());
                                                assert!(revoke_and_ack.is_none());
                                                revoke_and_ack = Some(msg.clone());
+                                               idx += 1;
                                        },
                                        &MessageSendEvent::UpdateHTLCs { ref node_id, ref updates } => {
                                                assert_eq!(*node_id, $dst_node.node.get_our_node_id());
                                                assert!(commitment_update.is_none());
                                                commitment_update = Some(updates.clone());
+                                               idx += 1;
                                        },
+                                       &MessageSendEvent::SendChannelUpdate { .. } => {},
                                        _ => panic!("Unexpected event"),
                                }
                        }
 
+                       if let Some(&MessageSendEvent::SendChannelUpdate { ref node_id, ref msg }) = msg_events.get(idx) {
+                               assert_eq!(*node_id, $dst_node.node.get_our_node_id());
+                               assert_eq!(msg.contents.flags & 2, 0); // "disabled" flag must not be set as we just reconnected.
+                       }
+
                        (funding_locked, revoke_and_ack, commitment_update, order)
                }
        }