Merge pull request #2441 from arik-so/2023-07-taproot-signer-wrapped
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index 84bc1a1b3f0656476aa1d00771f0368d97003cd0..1db4e873458001b820db2ba4d0284e89a95b20ce 100644 (file)
@@ -14,10 +14,10 @@ use crate::chain::{BestBlock, ChannelMonitorUpdateStatus, Confirm, Listen, Watch
 use crate::sign::EntropySource;
 use crate::chain::channelmonitor::ChannelMonitor;
 use crate::chain::transaction::OutPoint;
-use crate::events::{ClosureReason, Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, PathFailure, PaymentPurpose, PaymentFailureReason};
+use crate::events::{ClaimedHTLC, ClosureReason, Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, PathFailure, PaymentPurpose, PaymentFailureReason};
 use crate::events::bump_transaction::{BumpTransactionEventHandler, Wallet, WalletSource};
 use crate::ln::{PaymentPreimage, PaymentHash, PaymentSecret};
-use crate::ln::channelmanager::{AChannelManager, ChainParameters, ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, RecipientOnionFields, PaymentId, MIN_CLTV_EXPIRY_DELTA};
+use crate::ln::channelmanager::{self, AChannelManager, ChainParameters, ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, RecipientOnionFields, PaymentId, MIN_CLTV_EXPIRY_DELTA};
 use crate::routing::gossip::{P2PGossipSync, NetworkGraph, NetworkUpdate};
 use crate::routing::router::{self, PaymentParameters, Route};
 use crate::ln::features::InitFeatures;
@@ -368,31 +368,40 @@ pub struct NodeCfg<'a> {
        pub override_init_features: Rc<RefCell<Option<InitFeatures>>>,
 }
 
-type TestChannelManager<'a, 'b, 'c> = ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger>;
-
-pub struct Node<'a, 'b: 'a, 'c: 'b> {
-       pub chain_source: &'c test_utils::TestChainSource,
-       pub tx_broadcaster: &'c test_utils::TestBroadcaster,
-       pub fee_estimator: &'c test_utils::TestFeeEstimator,
-       pub router: &'b test_utils::TestRouter<'c>,
-       pub chain_monitor: &'b test_utils::TestChainMonitor<'c>,
-       pub keys_manager: &'b test_utils::TestKeysInterface,
-       pub node: &'a TestChannelManager<'a, 'b, 'c>,
-       pub network_graph: &'a NetworkGraph<&'c test_utils::TestLogger>,
-       pub gossip_sync: P2PGossipSync<&'b NetworkGraph<&'c test_utils::TestLogger>, &'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
+type TestChannelManager<'node_cfg, 'chan_mon_cfg> = ChannelManager<
+       &'node_cfg TestChainMonitor<'chan_mon_cfg>,
+       &'chan_mon_cfg test_utils::TestBroadcaster,
+       &'node_cfg test_utils::TestKeysInterface,
+       &'node_cfg test_utils::TestKeysInterface,
+       &'node_cfg test_utils::TestKeysInterface,
+       &'chan_mon_cfg test_utils::TestFeeEstimator,
+       &'node_cfg test_utils::TestRouter<'chan_mon_cfg>,
+       &'chan_mon_cfg test_utils::TestLogger,
+>;
+
+pub struct Node<'chan_man, 'node_cfg: 'chan_man, 'chan_mon_cfg: 'node_cfg> {
+       pub chain_source: &'chan_mon_cfg test_utils::TestChainSource,
+       pub tx_broadcaster: &'chan_mon_cfg test_utils::TestBroadcaster,
+       pub fee_estimator: &'chan_mon_cfg test_utils::TestFeeEstimator,
+       pub router: &'node_cfg test_utils::TestRouter<'chan_mon_cfg>,
+       pub chain_monitor: &'node_cfg test_utils::TestChainMonitor<'chan_mon_cfg>,
+       pub keys_manager: &'chan_mon_cfg test_utils::TestKeysInterface,
+       pub node: &'chan_man TestChannelManager<'node_cfg, 'chan_mon_cfg>,
+       pub network_graph: &'node_cfg NetworkGraph<&'chan_mon_cfg test_utils::TestLogger>,
+       pub gossip_sync: P2PGossipSync<&'node_cfg NetworkGraph<&'chan_mon_cfg test_utils::TestLogger>, &'chan_mon_cfg test_utils::TestChainSource, &'chan_mon_cfg test_utils::TestLogger>,
        pub node_seed: [u8; 32],
        pub network_payment_count: Rc<RefCell<u8>>,
        pub network_chan_count: Rc<RefCell<u32>>,
-       pub logger: &'c test_utils::TestLogger,
+       pub logger: &'chan_mon_cfg test_utils::TestLogger,
        pub blocks: Arc<Mutex<Vec<(Block, u32)>>>,
        pub connect_style: Rc<RefCell<ConnectStyle>>,
        pub override_init_features: Rc<RefCell<Option<InitFeatures>>>,
        pub wallet_source: Arc<test_utils::TestWalletSource>,
        pub bump_tx_handler: BumpTransactionEventHandler<
-               &'c test_utils::TestBroadcaster,
-               Arc<Wallet<Arc<test_utils::TestWalletSource>, &'c test_utils::TestLogger>>,
-               &'b test_utils::TestKeysInterface,
-               &'c test_utils::TestLogger,
+               &'chan_mon_cfg test_utils::TestBroadcaster,
+               Arc<Wallet<Arc<test_utils::TestWalletSource>, &'chan_mon_cfg test_utils::TestLogger>>,
+               &'chan_mon_cfg test_utils::TestKeysInterface,
+               &'chan_mon_cfg test_utils::TestLogger,
        >,
 }
 impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
@@ -448,8 +457,8 @@ impl<H: NodeHolder> NodeHolder for &H {
        fn chain_monitor(&self) -> Option<&test_utils::TestChainMonitor> { (*self).chain_monitor() }
 }
 impl<'a, 'b: 'a, 'c: 'b> NodeHolder for Node<'a, 'b, 'c> {
-       type CM = TestChannelManager<'a, 'b, 'c>;
-       fn node(&self) -> &TestChannelManager<'a, 'b, 'c> { &self.node }
+       type CM = TestChannelManager<'b, 'c>;
+       fn node(&self) -> &TestChannelManager<'b, 'c> { &self.node }
        fn chain_monitor(&self) -> Option<&test_utils::TestChainMonitor> { Some(self.chain_monitor) }
 }
 
@@ -924,7 +933,22 @@ macro_rules! check_added_monitors {
        }
 }
 
-pub fn _reload_node<'a, 'b, 'c, 'd>(node: &'a Node<'b, 'c, 'd>, default_config: UserConfig, chanman_encoded: &[u8], monitors_encoded: &[&[u8]]) -> ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger> {
+/// Checks whether the claimed HTLC for the specified path has the correct channel information.
+///
+/// This will panic if the path is empty, if the HTLC's channel ID is not actually a channel that
+/// connects the final two nodes in the path, or if the `user_channel_id` is incorrect.
+pub fn check_claimed_htlc_channel<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, path: &[&Node<'a, 'b, 'c>], htlc: &ClaimedHTLC) {
+       let mut nodes = path.iter().rev();
+       let dest = nodes.next().expect("path should have a destination").node;
+       let prev = nodes.next().unwrap_or(&origin_node).node;
+       let dest_channels = dest.list_channels();
+       let ch = dest_channels.iter().find(|ch| ch.channel_id == htlc.channel_id)
+               .expect("HTLC's channel should be one of destination node's channels");
+       assert_eq!(htlc.user_channel_id, ch.user_channel_id);
+       assert_eq!(ch.counterparty.node_id, prev.get_our_node_id());
+}
+
+pub fn _reload_node<'a, 'b, 'c>(node: &'a Node<'a, 'b, 'c>, default_config: UserConfig, chanman_encoded: &[u8], monitors_encoded: &[&[u8]]) -> TestChannelManager<'b, 'c> {
        let mut monitors_read = Vec::with_capacity(monitors_encoded.len());
        for encoded in monitors_encoded {
                let mut monitor_read = &encoded[..];
@@ -940,7 +964,7 @@ pub fn _reload_node<'a, 'b, 'c, 'd>(node: &'a Node<'b, 'c, 'd>, default_config:
                for monitor in monitors_read.iter_mut() {
                        assert!(channel_monitors.insert(monitor.get_funding_txo().0, monitor).is_none());
                }
-               <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut node_read, ChannelManagerReadArgs {
+               <(BlockHash, TestChannelManager<'b, 'c>)>::read(&mut node_read, ChannelManagerReadArgs {
                        default_config,
                        entropy_source: node.keys_manager,
                        node_signer: node.keys_manager,
@@ -1418,14 +1442,18 @@ macro_rules! check_closed_broadcast {
 }
 
 /// Check that a channel's closing channel events has been issued
-pub fn check_closed_event(node: &Node, events_count: usize, expected_reason: ClosureReason, is_check_discard_funding: bool) {
+pub fn check_closed_event(node: &Node, events_count: usize, expected_reason: ClosureReason, is_check_discard_funding: bool,
+       expected_counterparty_node_ids: &[PublicKey], expected_channel_capacity: u64) {
        let events = node.node.get_and_clear_pending_events();
        assert_eq!(events.len(), events_count, "{:?}", events);
        let mut issues_discard_funding = false;
-       for event in events {
+       for (idx, event) in events.into_iter().enumerate() {
                match event {
-                       Event::ChannelClosed { ref reason, .. } => {
+                       Event::ChannelClosed { ref reason, counterparty_node_id, 
+                               channel_capacity_sats, .. } => {
                                assert_eq!(*reason, expected_reason);
+                               assert_eq!(counterparty_node_id.unwrap(), expected_counterparty_node_ids[idx]);
+                               assert_eq!(channel_capacity_sats.unwrap(), expected_channel_capacity);
                        },
                        Event::DiscardFunding { .. } => {
                                issues_discard_funding = true;
@@ -1441,11 +1469,12 @@ pub fn check_closed_event(node: &Node, events_count: usize, expected_reason: Clo
 /// Don't use this, use the identically-named function instead.
 #[macro_export]
 macro_rules! check_closed_event {
-       ($node: expr, $events: expr, $reason: expr) => {
-               check_closed_event!($node, $events, $reason, false);
+       ($node: expr, $events: expr, $reason: expr, $counterparty_node_ids: expr, $channel_capacity: expr) => {
+               check_closed_event!($node, $events, $reason, false, $counterparty_node_ids, $channel_capacity);
        };
-       ($node: expr, $events: expr, $reason: expr, $is_check_discard_funding: expr) => {
-               $crate::ln::functional_test_utils::check_closed_event(&$node, $events, $reason, $is_check_discard_funding);
+       ($node: expr, $events: expr, $reason: expr, $is_check_discard_funding: expr, $counterparty_node_ids: expr, $channel_capacity: expr) => {
+               $crate::ln::functional_test_utils::check_closed_event(&$node, $events, $reason, 
+                       $is_check_discard_funding, &$counterparty_node_ids, $channel_capacity);
        }
 }
 
@@ -1670,8 +1699,8 @@ macro_rules! commitment_signed_dance {
                        bs_revoke_and_ack
                }
        };
-       ($node_a: expr, $node_b: expr, (), $fail_backwards: expr, true /* skip last step */, false /* no extra message */) => {
-               assert!($crate::ln::functional_test_utils::commitment_signed_dance_through_cp_raa(&$node_a, &$node_b, $fail_backwards).is_none());
+       ($node_a: expr, $node_b: expr, (), $fail_backwards: expr, true /* skip last step */, false /* no extra message */, $incl_claim: expr) => {
+               assert!($crate::ln::functional_test_utils::commitment_signed_dance_through_cp_raa(&$node_a, &$node_b, $fail_backwards, $incl_claim).is_none());
        };
        ($node_a: expr, $node_b: expr, $commitment_signed: expr, $fail_backwards: expr) => {
                $crate::ln::functional_test_utils::do_commitment_signed_dance(&$node_a, &$node_b, &$commitment_signed, $fail_backwards, false);
@@ -1682,11 +1711,16 @@ macro_rules! commitment_signed_dance {
 /// the initiator's `revoke_and_ack` response. i.e. [`do_main_commitment_signed_dance`] plus the
 /// `revoke_and_ack` response to it.
 ///
+/// An HTLC claim on one channel blocks the RAA channel monitor update for the outbound edge
+/// channel until the inbound edge channel preimage monitor update completes. Thus, when checking
+/// for channel monitor updates, we need to know if an `update_fulfill_htlc` was included in the
+/// the commitment we're exchanging. `includes_claim` provides that information.
+///
 /// Returns any additional message `node_b` generated in addition to the `revoke_and_ack` response.
-pub fn commitment_signed_dance_through_cp_raa(node_a: &Node<'_, '_, '_>, node_b: &Node<'_, '_, '_>, fail_backwards: bool) -> Option<MessageSendEvent> {
+pub fn commitment_signed_dance_through_cp_raa(node_a: &Node<'_, '_, '_>, node_b: &Node<'_, '_, '_>, fail_backwards: bool, includes_claim: bool) -> Option<MessageSendEvent> {
        let (extra_msg_option, bs_revoke_and_ack) = do_main_commitment_signed_dance(node_a, node_b, fail_backwards);
        node_a.node.handle_revoke_and_ack(&node_b.node.get_our_node_id(), &bs_revoke_and_ack);
-       check_added_monitors(node_a, 1);
+       check_added_monitors(node_a, if includes_claim { 0 } else { 1 });
        extra_msg_option
 }
 
@@ -1733,7 +1767,23 @@ pub fn do_commitment_signed_dance(node_a: &Node<'_, '_, '_>, node_b: &Node<'_, '
        node_a.node.handle_commitment_signed(&node_b.node.get_our_node_id(), commitment_signed);
        check_added_monitors!(node_a, 1);
 
-       commitment_signed_dance!(node_a, node_b, (), fail_backwards, true, false);
+       // If this commitment signed dance was due to a claim, don't check for an RAA monitor update.
+       let got_claim = node_a.node.pending_events.lock().unwrap().iter().any(|(ev, action)| {
+               let matching_action = if let Some(channelmanager::EventCompletionAction::ReleaseRAAChannelMonitorUpdate
+                       { channel_funding_outpoint, counterparty_node_id }) = action
+               {
+                       if channel_funding_outpoint.to_channel_id() == commitment_signed.channel_id {
+                               assert_eq!(*counterparty_node_id, node_b.node.get_our_node_id());
+                               true
+                       } else { false }
+               } else { false };
+               if matching_action {
+                       if let Event::PaymentSent { .. } = ev {} else { panic!(); }
+               }
+               matching_action
+       });
+       if fail_backwards { assert!(!got_claim); }
+       commitment_signed_dance!(node_a, node_b, (), fail_backwards, true, false, got_claim);
 
        if skip_last_step { return; }
 
@@ -1878,7 +1928,7 @@ macro_rules! expect_payment_claimed {
 
 pub fn expect_payment_sent<CM: AChannelManager, H: NodeHolder<CM=CM>>(node: &H,
        expected_payment_preimage: PaymentPreimage, expected_fee_msat_opt: Option<Option<u64>>,
-       expect_per_path_claims: bool,
+       expect_per_path_claims: bool, expect_post_ev_mon_update: bool,
 ) {
        let events = node.node().get_and_clear_pending_events();
        let expected_payment_hash = PaymentHash(
@@ -1888,6 +1938,9 @@ pub fn expect_payment_sent<CM: AChannelManager, H: NodeHolder<CM=CM>>(node: &H,
        } else {
                assert_eq!(events.len(), 1);
        }
+       if expect_post_ev_mon_update {
+               check_added_monitors(node, 1);
+       }
        let expected_payment_id = match events[0] {
                Event::PaymentSent { ref payment_id, ref payment_preimage, ref payment_hash, ref fee_paid_msat } => {
                        assert_eq!(expected_payment_preimage, *payment_preimage);
@@ -1914,17 +1967,6 @@ pub fn expect_payment_sent<CM: AChannelManager, H: NodeHolder<CM=CM>>(node: &H,
        }
 }
 
-#[cfg(test)]
-#[macro_export]
-macro_rules! expect_payment_sent_without_paths {
-       ($node: expr, $expected_payment_preimage: expr) => {
-               expect_payment_sent!($node, $expected_payment_preimage, None::<u64>, false);
-       };
-       ($node: expr, $expected_payment_preimage: expr, $expected_fee_msat_opt: expr) => {
-               expect_payment_sent!($node, $expected_payment_preimage, $expected_fee_msat_opt, false);
-       }
-}
-
 #[macro_export]
 macro_rules! expect_payment_sent {
        ($node: expr, $expected_payment_preimage: expr) => {
@@ -1935,7 +1977,7 @@ macro_rules! expect_payment_sent {
        };
        ($node: expr, $expected_payment_preimage: expr, $expected_fee_msat_opt: expr, $expect_paths: expr) => {
                $crate::ln::functional_test_utils::expect_payment_sent(&$node, $expected_payment_preimage,
-                       $expected_fee_msat_opt.map(|o| Some(o)), $expect_paths);
+                       $expected_fee_msat_opt.map(|o| Some(o)), $expect_paths, true);
        }
 }
 
@@ -2250,15 +2292,41 @@ pub fn do_claim_payment_along_route_with_extra_penultimate_hop_fees<'a, 'b, 'c>(
                assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id());
        }
        expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage);
+       pass_claimed_payment_along_route(origin_node, expected_paths, expected_extra_fees, skip_last, our_payment_preimage)
+}
 
+pub fn pass_claimed_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], expected_extra_fees: &[u32], skip_last: bool, our_payment_preimage: PaymentPreimage) -> u64 {
        let claim_event = expected_paths[0].last().unwrap().node.get_and_clear_pending_events();
        assert_eq!(claim_event.len(), 1);
        match claim_event[0] {
-               Event::PaymentClaimed { purpose: PaymentPurpose::SpontaneousPayment(preimage), .. }|
-               Event::PaymentClaimed { purpose: PaymentPurpose::InvoicePayment { payment_preimage: Some(preimage), ..}, .. } =>
-                       assert_eq!(preimage, our_payment_preimage),
-               Event::PaymentClaimed { purpose: PaymentPurpose::InvoicePayment { .. }, payment_hash, .. } =>
-                       assert_eq!(&payment_hash.0, &Sha256::hash(&our_payment_preimage.0)[..]),
+               Event::PaymentClaimed {
+                       purpose: PaymentPurpose::SpontaneousPayment(preimage),
+                       amount_msat,
+                       ref htlcs,
+                       .. }
+               | Event::PaymentClaimed {
+                       purpose: PaymentPurpose::InvoicePayment { payment_preimage: Some(preimage), ..},
+                       ref htlcs,
+                       amount_msat,
+                       ..
+               } => {
+                       assert_eq!(preimage, our_payment_preimage);
+                       assert_eq!(htlcs.len(), expected_paths.len());  // One per path.
+                       assert_eq!(htlcs.iter().map(|h| h.value_msat).sum::<u64>(), amount_msat);
+                       expected_paths.iter().zip(htlcs).for_each(|(path, htlc)| check_claimed_htlc_channel(origin_node, path, htlc));
+               },
+               Event::PaymentClaimed {
+                       purpose: PaymentPurpose::InvoicePayment { .. },
+                       payment_hash,
+                       amount_msat,
+                       ref htlcs,
+                       ..
+               } => {
+                       assert_eq!(&payment_hash.0, &Sha256::hash(&our_payment_preimage.0)[..]);
+                       assert_eq!(htlcs.len(), expected_paths.len());  // One per path.
+                       assert_eq!(htlcs.iter().map(|h| h.value_msat).sum::<u64>(), amount_msat);
+                       expected_paths.iter().zip(htlcs).for_each(|(path, htlc)| check_claimed_htlc_channel(origin_node, path, htlc));
+               }
                _ => panic!(),
        }
 
@@ -2945,9 +3013,41 @@ macro_rules! handle_chan_reestablish_msgs {
        }
 }
 
+pub struct ReconnectArgs<'a, 'b, 'c, 'd> {
+       pub node_a: &'a Node<'b, 'c, 'd>,
+       pub node_b: &'a Node<'b, 'c, 'd>,
+       pub send_channel_ready: (bool, bool),
+       pub pending_htlc_adds: (i64, i64),
+       pub pending_htlc_claims: (usize, usize),
+       pub pending_htlc_fails: (usize, usize),
+       pub pending_cell_htlc_claims: (usize, usize),
+       pub pending_cell_htlc_fails: (usize, usize),
+       pub pending_raa: (bool, bool),
+}
+
+impl<'a, 'b, 'c, 'd> ReconnectArgs<'a, 'b, 'c, 'd> {
+       pub fn new(node_a: &'a Node<'b, 'c, 'd>, node_b: &'a Node<'b, 'c, 'd>) -> Self {
+               Self {
+                       node_a,
+                       node_b,
+                       send_channel_ready: (false, false),
+                       pending_htlc_adds: (0, 0),
+                       pending_htlc_claims: (0, 0),
+                       pending_htlc_fails: (0, 0),
+                       pending_cell_htlc_claims: (0, 0),
+                       pending_cell_htlc_fails: (0, 0),
+                       pending_raa: (false, false),
+               }
+       }
+}
+
 /// pending_htlc_adds includes both the holding cell and in-flight update_add_htlcs, whereas
 /// for claims/fails they are separated out.
-pub fn reconnect_nodes<'a, 'b, 'c>(node_a: &Node<'a, 'b, 'c>, node_b: &Node<'a, 'b, 'c>, send_channel_ready: (bool, bool), pending_htlc_adds: (i64, i64), pending_htlc_claims: (usize, usize), pending_htlc_fails: (usize, usize), pending_cell_htlc_claims: (usize, usize), pending_cell_htlc_fails: (usize, usize), pending_raa: (bool, bool))  {
+pub fn reconnect_nodes<'a, 'b, 'c, 'd>(args: ReconnectArgs<'a, 'b, 'c, 'd>) {
+       let ReconnectArgs {
+               node_a, node_b, send_channel_ready, pending_htlc_adds, pending_htlc_claims, pending_htlc_fails,
+               pending_cell_htlc_claims, pending_cell_htlc_fails, pending_raa
+       } = args;
        node_a.node.peer_connected(&node_b.node.get_our_node_id(), &msgs::Init {
                features: node_b.node.init_features(), networks: None, remote_network_address: None
        }, true).unwrap();