Check custom_tlvs in Event::PaymentClaimed
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index 0877bfd37737172910d7c441cdcd7a37cc1aa8d9..b9688d04348e22a6d111167e62ecdadfc390ed90 100644 (file)
@@ -15,7 +15,7 @@ use crate::chain::channelmonitor::ChannelMonitor;
 use crate::chain::transaction::OutPoint;
 use crate::events::{ClaimedHTLC, ClosureReason, Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, PathFailure, PaymentPurpose, PaymentFailureReason};
 use crate::events::bump_transaction::{BumpTransactionEvent, BumpTransactionEventHandler, Wallet, WalletSource};
-use crate::ln::{ChannelId, PaymentPreimage, PaymentHash, PaymentSecret};
+use crate::ln::types::{ChannelId, PaymentPreimage, PaymentHash, PaymentSecret};
 use crate::ln::channelmanager::{AChannelManager, ChainParameters, ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, RecipientOnionFields, PaymentId, MIN_CLTV_EXPIRY_DELTA};
 use crate::ln::features::InitFeatures;
 use crate::ln::msgs;
@@ -415,6 +415,7 @@ type TestOnionMessenger<'chan_man, 'node_cfg, 'chan_mon_cfg> = OnionMessenger<
        DedicatedEntropy,
        &'node_cfg test_utils::TestKeysInterface,
        &'chan_mon_cfg test_utils::TestLogger,
+       &'chan_man TestChannelManager<'node_cfg, 'chan_mon_cfg>,
        &'node_cfg test_utils::TestMessageRouter<'chan_mon_cfg>,
        &'chan_man TestChannelManager<'node_cfg, 'chan_mon_cfg>,
        IgnoringMessageHandler,
@@ -1553,11 +1554,29 @@ macro_rules! check_warn_msg {
        }}
 }
 
+/// Checks if at least one peer is connected.
+fn is_any_peer_connected(node: &Node) -> bool {
+       let peer_state = node.node.per_peer_state.read().unwrap();
+       for (_, peer_mutex) in peer_state.iter() {
+               let peer = peer_mutex.lock().unwrap();
+               if peer.is_connected { return true; }
+       }
+       false
+}
+
 /// Check that a channel's closing channel update has been broadcasted, and optionally
 /// check whether an error message event has occurred.
 pub fn check_closed_broadcast(node: &Node, num_channels: usize, with_error_msg: bool) -> Vec<msgs::ErrorMessage> {
+       let mut dummy_connected = false;
+       if !is_any_peer_connected(node) {
+               connect_dummy_node(&node);
+               dummy_connected = true;
+       }
        let msg_events = node.node.get_and_clear_pending_msg_events();
        assert_eq!(msg_events.len(), if with_error_msg { num_channels * 2 } else { num_channels });
+       if dummy_connected {
+               disconnect_dummy_node(&node);
+       }
        msg_events.into_iter().filter_map(|msg_event| {
                match msg_event {
                        MessageSendEvent::BroadcastChannelUpdate { ref msg } => {
@@ -2110,7 +2129,15 @@ pub fn check_payment_claimable(
                        assert_eq!(expected_recv_value, *amount_msat);
                        assert_eq!(expected_receiver_node_id, receiver_node_id.unwrap());
                        match purpose {
-                               PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                               PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                       assert_eq!(&expected_payment_preimage, payment_preimage);
+                                       assert_eq!(expected_payment_secret, *payment_secret);
+                               },
+                               PaymentPurpose::Bolt12OfferPayment { payment_preimage, payment_secret, .. } => {
+                                       assert_eq!(&expected_payment_preimage, payment_preimage);
+                                       assert_eq!(expected_payment_secret, *payment_secret);
+                               },
+                               PaymentPurpose::Bolt12RefundPayment { payment_preimage, payment_secret, .. } => {
                                        assert_eq!(&expected_payment_preimage, payment_preimage);
                                        assert_eq!(expected_payment_secret, *payment_secret);
                                },
@@ -2507,6 +2534,8 @@ pub struct PassAlongPathArgs<'a, 'b, 'c, 'd> {
        pub clear_recipient_events: bool,
        pub expected_preimage: Option<PaymentPreimage>,
        pub is_probe: bool,
+       pub custom_tlvs: Vec<(u64, Vec<u8>)>,
+       pub payment_metadata: Option<Vec<u8>>,
 }
 
 impl<'a, 'b, 'c, 'd> PassAlongPathArgs<'a, 'b, 'c, 'd> {
@@ -2517,7 +2546,7 @@ impl<'a, 'b, 'c, 'd> PassAlongPathArgs<'a, 'b, 'c, 'd> {
                Self {
                        origin_node, expected_path, recv_value, payment_hash, payment_secret: None, event,
                        payment_claimable_expected: true, clear_recipient_events: true, expected_preimage: None,
-                       is_probe: false,
+                       is_probe: false, custom_tlvs: Vec::new(), payment_metadata: None,
                }
        }
        pub fn without_clearing_recipient_events(mut self) -> Self {
@@ -2541,13 +2570,21 @@ impl<'a, 'b, 'c, 'd> PassAlongPathArgs<'a, 'b, 'c, 'd> {
                self.expected_preimage = Some(payment_preimage);
                self
        }
+       pub fn with_custom_tlvs(mut self, custom_tlvs: Vec<(u64, Vec<u8>)>) -> Self {
+               self.custom_tlvs = custom_tlvs;
+               self
+       }
+       pub fn with_payment_metadata(mut self, payment_metadata: Vec<u8>) -> Self {
+               self.payment_metadata = Some(payment_metadata);
+               self
+       }
 }
 
 pub fn do_pass_along_path<'a, 'b, 'c>(args: PassAlongPathArgs) -> Option<Event> {
        let PassAlongPathArgs {
                origin_node, expected_path, recv_value, payment_hash: our_payment_hash,
                payment_secret: our_payment_secret, event: ev, payment_claimable_expected,
-               clear_recipient_events, expected_preimage, is_probe
+               clear_recipient_events, expected_preimage, is_probe, custom_tlvs, payment_metadata,
        } = args;
 
        let mut payment_event = SendEvent::from_event(ev);
@@ -2580,8 +2617,20 @@ pub fn do_pass_along_path<'a, 'b, 'c>(args: PassAlongPathArgs) -> Option<Event>
                                                assert_eq!(our_payment_hash, *payment_hash);
                                                assert_eq!(node.node.get_our_node_id(), receiver_node_id.unwrap());
                                                assert!(onion_fields.is_some());
+                                               assert_eq!(onion_fields.as_ref().unwrap().custom_tlvs, custom_tlvs);
+                                               assert_eq!(onion_fields.as_ref().unwrap().payment_metadata, payment_metadata);
                                                match &purpose {
-                                                       PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                                       PaymentPurpose::Bolt11InvoicePayment { payment_preimage, payment_secret, .. } => {
+                                                               assert_eq!(expected_preimage, *payment_preimage);
+                                                               assert_eq!(our_payment_secret.unwrap(), *payment_secret);
+                                                               assert_eq!(Some(*payment_secret), onion_fields.as_ref().unwrap().payment_secret);
+                                                       },
+                                                       PaymentPurpose::Bolt12OfferPayment { payment_preimage, payment_secret, .. } => {
+                                                               assert_eq!(expected_preimage, *payment_preimage);
+                                                               assert_eq!(our_payment_secret.unwrap(), *payment_secret);
+                                                               assert_eq!(Some(*payment_secret), onion_fields.as_ref().unwrap().payment_secret);
+                                                       },
+                                                       PaymentPurpose::Bolt12RefundPayment { payment_preimage, payment_secret, .. } => {
                                                                assert_eq!(expected_preimage, *payment_preimage);
                                                                assert_eq!(our_payment_secret.unwrap(), *payment_secret);
                                                                assert_eq!(Some(*payment_secret), onion_fields.as_ref().unwrap().payment_secret);
@@ -2667,18 +2716,12 @@ pub fn send_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, route: Route
        (our_payment_preimage, our_payment_hash, our_payment_secret, payment_id)
 }
 
-pub fn do_claim_payment_along_route<'a, 'b, 'c>(
-       origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool,
-       our_payment_preimage: PaymentPreimage
-) -> u64 {
-       for path in expected_paths.iter() {
-               assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id());
+pub fn do_claim_payment_along_route(args: ClaimAlongRouteArgs) -> u64 {
+       for path in args.expected_paths.iter() {
+               assert_eq!(path.last().unwrap().node.get_our_node_id(), args.expected_paths[0].last().unwrap().node.get_our_node_id());
        }
-       expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage);
-       pass_claimed_payment_along_route(
-               ClaimAlongRouteArgs::new(origin_node, expected_paths, our_payment_preimage)
-                       .skip_last(skip_last)
-       )
+       args.expected_paths[0].last().unwrap().node.claim_funds(args.payment_preimage);
+       pass_claimed_payment_along_route(args)
 }
 
 pub struct ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
@@ -2688,6 +2731,7 @@ pub struct ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
        pub expected_min_htlc_overpay: Vec<u32>,
        pub skip_last: bool,
        pub payment_preimage: PaymentPreimage,
+       pub custom_tlvs: Vec<(u64, Vec<u8>)>,
        // Allow forwarding nodes to have taken 1 msat more fee than expected based on the downstream
        // fulfill amount.
        //
@@ -2706,7 +2750,7 @@ impl<'a, 'b, 'c, 'd> ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
                Self {
                        origin_node, expected_paths, expected_extra_fees: vec![0; expected_paths.len()],
                        expected_min_htlc_overpay: vec![0; expected_paths.len()], skip_last: false, payment_preimage,
-                       allow_1_msat_fee_overpay: false,
+                       allow_1_msat_fee_overpay: false, custom_tlvs: vec![],
                }
        }
        pub fn skip_last(mut self, skip_last: bool) -> Self {
@@ -2725,12 +2769,16 @@ impl<'a, 'b, 'c, 'd> ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
                self.allow_1_msat_fee_overpay = true;
                self
        }
+       pub fn with_custom_tlvs(mut self, custom_tlvs: Vec<(u64, Vec<u8>)>) -> Self {
+               self.custom_tlvs = custom_tlvs;
+               self
+       }
 }
 
-pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArgs) -> u64 {
+pub fn pass_claimed_payment_along_route(args: ClaimAlongRouteArgs) -> u64 {
        let ClaimAlongRouteArgs {
                origin_node, expected_paths, expected_extra_fees, expected_min_htlc_overpay, skip_last,
-               payment_preimage: our_payment_preimage, allow_1_msat_fee_overpay,
+               payment_preimage: our_payment_preimage, allow_1_msat_fee_overpay, custom_tlvs,
        } = args;
        let claim_event = expected_paths[0].last().unwrap().node.get_and_clear_pending_events();
        assert_eq!(claim_event.len(), 1);
@@ -2738,32 +2786,36 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
        let mut fwd_amt_msat = 0;
        match claim_event[0] {
                Event::PaymentClaimed {
-                       purpose: PaymentPurpose::SpontaneousPayment(preimage),
+                       purpose: PaymentPurpose::SpontaneousPayment(preimage)
+                               | PaymentPurpose::Bolt11InvoicePayment { payment_preimage: Some(preimage), .. }
+                               | PaymentPurpose::Bolt12OfferPayment { payment_preimage: Some(preimage), .. }
+                               | PaymentPurpose::Bolt12RefundPayment { payment_preimage: Some(preimage), .. },
                        amount_msat,
                        ref htlcs,
-                       .. }
-               | Event::PaymentClaimed {
-                       purpose: PaymentPurpose::InvoicePayment { payment_preimage: Some(preimage), ..},
-                       ref htlcs,
-                       amount_msat,
+                       ref onion_fields,
                        ..
                } => {
                        assert_eq!(preimage, our_payment_preimage);
                        assert_eq!(htlcs.len(), expected_paths.len());  // One per path.
                        assert_eq!(htlcs.iter().map(|h| h.value_msat).sum::<u64>(), amount_msat);
+                       assert_eq!(onion_fields.as_ref().unwrap().custom_tlvs, custom_tlvs);
                        expected_paths.iter().zip(htlcs).for_each(|(path, htlc)| check_claimed_htlc_channel(origin_node, path, htlc));
                        fwd_amt_msat = amount_msat;
                },
                Event::PaymentClaimed {
-                       purpose: PaymentPurpose::InvoicePayment { .. },
+                       purpose: PaymentPurpose::Bolt11InvoicePayment { .. }
+                               | PaymentPurpose::Bolt12OfferPayment { .. }
+                               | PaymentPurpose::Bolt12RefundPayment { .. },
                        payment_hash,
                        amount_msat,
                        ref htlcs,
+                       ref onion_fields,
                        ..
                } => {
                        assert_eq!(&payment_hash.0, &Sha256::hash(&our_payment_preimage.0)[..]);
                        assert_eq!(htlcs.len(), expected_paths.len());  // One per path.
                        assert_eq!(htlcs.iter().map(|h| h.value_msat).sum::<u64>(), amount_msat);
+                       assert_eq!(onion_fields.as_ref().unwrap().custom_tlvs, custom_tlvs);
                        expected_paths.iter().zip(htlcs).for_each(|(path, htlc)| check_claimed_htlc_channel(origin_node, path, htlc));
                        fwd_amt_msat = amount_msat;
                }
@@ -2907,15 +2959,20 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
 
        expected_total_fee_msat
 }
-pub fn claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_preimage: PaymentPreimage) {
-       let expected_total_fee_msat = do_claim_payment_along_route(origin_node, expected_paths, skip_last, our_payment_preimage);
+pub fn claim_payment_along_route(args: ClaimAlongRouteArgs) {
+       let origin_node = args.origin_node;
+       let payment_preimage = args.payment_preimage;
+       let skip_last = args.skip_last;
+       let expected_total_fee_msat = do_claim_payment_along_route(args);
        if !skip_last {
-               expect_payment_sent!(origin_node, our_payment_preimage, Some(expected_total_fee_msat));
+               expect_payment_sent!(origin_node, payment_preimage, Some(expected_total_fee_msat));
        }
 }
 
 pub fn claim_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], our_payment_preimage: PaymentPreimage) {
-       claim_payment_along_route(origin_node, &[expected_route], false, our_payment_preimage);
+       claim_payment_along_route(
+               ClaimAlongRouteArgs::new(origin_node, &[expected_route], our_payment_preimage)
+       );
 }
 
 pub const TEST_FINAL_CLTV: u32 = 70;
@@ -3175,8 +3232,8 @@ pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeC
        for i in 0..node_count {
                let dedicated_entropy = DedicatedEntropy(RandomBytes::new([i as u8; 32]));
                let onion_messenger = OnionMessenger::new(
-                       dedicated_entropy, cfgs[i].keys_manager, cfgs[i].logger, &cfgs[i].message_router,
-                       &chan_mgrs[i], IgnoringMessageHandler {},
+                       dedicated_entropy, cfgs[i].keys_manager, cfgs[i].logger, &chan_mgrs[i],
+                       &cfgs[i].message_router, &chan_mgrs[i], IgnoringMessageHandler {},
                );
                let gossip_sync = P2PGossipSync::new(cfgs[i].network_graph.as_ref(), None, cfgs[i].logger);
                let wallet_source = Arc::new(test_utils::TestWalletSource::new(SecretKey::from_slice(&[i as u8 + 1; 32]).unwrap()));
@@ -3224,6 +3281,28 @@ pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeC
        nodes
 }
 
+pub fn connect_dummy_node<'a, 'b: 'a, 'c: 'b>(node: &Node<'a, 'b, 'c>) {
+       let node_id_dummy = PublicKey::from_slice(&[2; 33]).unwrap();
+
+       let mut dummy_init_features = InitFeatures::empty();
+       dummy_init_features.set_static_remote_key_required();
+
+       let init_dummy = msgs::Init {
+               features: dummy_init_features,
+               networks: None,
+               remote_network_address: None
+       };
+
+       node.node.peer_connected(&node_id_dummy, &init_dummy, true).unwrap();
+       node.onion_messenger.peer_connected(&node_id_dummy, &init_dummy, true).unwrap();
+}
+
+pub fn disconnect_dummy_node<'a, 'b: 'a, 'c: 'b>(node: &Node<'a, 'b, 'c>) {
+       let node_id_dummy = PublicKey::from_slice(&[2; 33]).unwrap();
+       node.node.peer_disconnected(&node_id_dummy);
+       node.onion_messenger.peer_disconnected(&node_id_dummy);
+}
+
 // Note that the following only works for CLTV values up to 128
 pub const ACCEPTED_HTLC_SCRIPT_WEIGHT: usize = 137; // Here we have a diff due to HTLC CLTV expiry being < 2^15 in test
 pub const ACCEPTED_HTLC_SCRIPT_WEIGHT_ANCHORS: usize = 140; // Here we have a diff due to HTLC CLTV expiry being < 2^15 in test
@@ -3335,15 +3414,21 @@ pub fn check_preimage_claim<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, prev_txn: &Vec<
 }
 
 pub fn handle_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec<Node<'a, 'b, 'c>>, a: usize, b: usize, needs_err_handle: bool, expected_error: &str)  {
+       let mut dummy_connected = false;
+       if !is_any_peer_connected(&nodes[a]) {
+               connect_dummy_node(&nodes[a]);
+               dummy_connected = true
+       }
+
        let events_1 = nodes[a].node.get_and_clear_pending_msg_events();
        assert_eq!(events_1.len(), 2);
-       let as_update = match events_1[0] {
+       let as_update = match events_1[1] {
                MessageSendEvent::BroadcastChannelUpdate { ref msg } => {
                        msg.clone()
                },
                _ => panic!("Unexpected event"),
        };
-       match events_1[1] {
+       match events_1[0] {
                MessageSendEvent::HandleError { node_id, action: msgs::ErrorAction::SendErrorMessage { ref msg } } => {
                        assert_eq!(node_id, nodes[b].node.get_our_node_id());
                        assert_eq!(msg.data, expected_error);
@@ -3360,17 +3445,24 @@ pub fn handle_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec<Node<'a, '
                },
                _ => panic!("Unexpected event"),
        }
-
+       if dummy_connected {
+               disconnect_dummy_node(&nodes[a]);
+               dummy_connected = false;
+       }
+       if !is_any_peer_connected(&nodes[b]) {
+               connect_dummy_node(&nodes[b]);
+               dummy_connected = true;
+       }
        let events_2 = nodes[b].node.get_and_clear_pending_msg_events();
        assert_eq!(events_2.len(), if needs_err_handle { 1 } else { 2 });
-       let bs_update = match events_2[0] {
+       let bs_update = match events_2.last().unwrap() {
                MessageSendEvent::BroadcastChannelUpdate { ref msg } => {
                        msg.clone()
                },
                _ => panic!("Unexpected event"),
        };
        if !needs_err_handle {
-               match events_2[1] {
+               match events_2[0] {
                        MessageSendEvent::HandleError { node_id, action: msgs::ErrorAction::SendErrorMessage { ref msg } } => {
                                assert_eq!(node_id, nodes[a].node.get_our_node_id());
                                assert_eq!(msg.data, expected_error);
@@ -3382,7 +3474,9 @@ pub fn handle_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec<Node<'a, '
                        _ => panic!("Unexpected event"),
                }
        }
-
+       if dummy_connected {
+               disconnect_dummy_node(&nodes[b]);
+       }
        for node in nodes {
                node.gossip_sync.handle_channel_update(&as_update).unwrap();
                node.gossip_sync.handle_channel_update(&bs_update).unwrap();