Merge pull request #1434 from TheBlueMatt/2022-04-robust-payment-claims
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index d8aeafcd5e3c1f75dc101a7bf42bbfa1d07cf22d..31fdc6ffae4514411b33003574c14c362df69025 100644 (file)
@@ -35,6 +35,8 @@ use bitcoin::blockdata::transaction::{Transaction, TxOut};
 use bitcoin::network::constants::Network;
 
 use bitcoin::hash_types::BlockHash;
+use bitcoin::hashes::sha256::Hash as Sha256;
+use bitcoin::hashes::Hash as _;
 
 use bitcoin::secp256k1::PublicKey;
 
@@ -1305,9 +1307,9 @@ macro_rules! expect_payment_received {
                let events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                match events[0] {
-                       $crate::util::events::Event::PaymentReceived { ref payment_hash, ref purpose, amt } => {
+                       $crate::util::events::Event::PaymentReceived { ref payment_hash, ref purpose, amount_msat } => {
                                assert_eq!($expected_payment_hash, *payment_hash);
-                               assert_eq!($expected_recv_value, amt);
+                               assert_eq!($expected_recv_value, amount_msat);
                                match purpose {
                                        $crate::util::events::PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
                                                assert_eq!(&$expected_payment_preimage, payment_preimage);
@@ -1321,6 +1323,22 @@ macro_rules! expect_payment_received {
        }
 }
 
+#[macro_export]
+#[cfg(any(test, feature = "_bench_unstable", feature = "_test_utils"))]
+macro_rules! expect_payment_claimed {
+       ($node: expr, $expected_payment_hash: expr, $expected_recv_value: expr) => {
+               let events = $node.node.get_and_clear_pending_events();
+               assert_eq!(events.len(), 1);
+               match events[0] {
+                       $crate::util::events::Event::PaymentClaimed { ref payment_hash, amount_msat, .. } => {
+                               assert_eq!($expected_payment_hash, *payment_hash);
+                               assert_eq!($expected_recv_value, amount_msat);
+                       },
+                       _ => panic!("Unexpected event"),
+               }
+       }
+}
+
 #[cfg(test)]
 #[macro_export]
 macro_rules! expect_payment_sent_without_paths {
@@ -1531,7 +1549,7 @@ pub fn send_along_route_with_secret<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
        payment_id
 }
 
-pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path: &[&Node<'a, 'b, 'c>], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: Option<PaymentSecret>, ev: MessageSendEvent, payment_received_expected: bool, expected_preimage: Option<PaymentPreimage>) {
+pub fn do_pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path: &[&Node<'a, 'b, 'c>], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: Option<PaymentSecret>, ev: MessageSendEvent, payment_received_expected: bool, clear_recipient_events: bool, expected_preimage: Option<PaymentPreimage>) {
        let mut payment_event = SendEvent::from_event(ev);
        let mut prev_node = origin_node;
 
@@ -1544,12 +1562,12 @@ pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path
 
                expect_pending_htlcs_forwardable!(node);
 
-               if idx == expected_path.len() - 1 {
+               if idx == expected_path.len() - 1 && clear_recipient_events {
                        let events_2 = node.node.get_and_clear_pending_events();
                        if payment_received_expected {
                                assert_eq!(events_2.len(), 1);
                                match events_2[0] {
-                                       Event::PaymentReceived { ref payment_hash, ref purpose, amt} => {
+                                       Event::PaymentReceived { ref payment_hash, ref purpose, amount_msat } => {
                                                assert_eq!(our_payment_hash, *payment_hash);
                                                match &purpose {
                                                        PaymentPurpose::InvoicePayment { payment_preimage, payment_secret, .. } => {
@@ -1561,14 +1579,14 @@ pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path
                                                                assert!(our_payment_secret.is_none());
                                                        },
                                                }
-                                               assert_eq!(amt, recv_value);
+                                               assert_eq!(amount_msat, recv_value);
                                        },
                                        _ => panic!("Unexpected event"),
                                }
                        } else {
                                assert!(events_2.is_empty());
                        }
-               } else {
+               } else if idx != expected_path.len() - 1 {
                        let mut events_2 = node.node.get_and_clear_pending_msg_events();
                        assert_eq!(events_2.len(), 1);
                        check_added_monitors!(node, 1);
@@ -1580,6 +1598,10 @@ pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path
        }
 }
 
+pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path: &[&Node<'a, 'b, 'c>], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: Option<PaymentSecret>, ev: MessageSendEvent, payment_received_expected: bool, expected_preimage: Option<PaymentPreimage>) {
+       do_pass_along_path(origin_node, expected_path, recv_value, our_payment_hash, our_payment_secret, ev, payment_received_expected, true, expected_preimage);
+}
+
 pub fn pass_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&[&Node<'a, 'b, 'c>]], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: PaymentSecret) {
        let mut events = origin_node.node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), expected_route.len());
@@ -1601,7 +1623,19 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
        for path in expected_paths.iter() {
                assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id());
        }
-       assert!(expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage));
+       expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage);
+
+       let claim_event = expected_paths[0].last().unwrap().node.get_and_clear_pending_events();
+       assert_eq!(claim_event.len(), 1);
+       match claim_event[0] {
+               Event::PaymentClaimed { purpose: PaymentPurpose::SpontaneousPayment(preimage), .. }|
+               Event::PaymentClaimed { purpose: PaymentPurpose::InvoicePayment { payment_preimage: Some(preimage), ..}, .. } =>
+                       assert_eq!(preimage, our_payment_preimage),
+               Event::PaymentClaimed { purpose: PaymentPurpose::InvoicePayment { .. }, payment_hash, .. } =>
+                       assert_eq!(&payment_hash.0, &Sha256::hash(&our_payment_preimage.0)[..]),
+               _ => panic!(),
+       }
+
        check_added_monitors!(expected_paths[0].last().unwrap(), expected_paths.len());
 
        let mut expected_total_fee_msat = 0;
@@ -1696,7 +1730,7 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
        }
 
        // Ensure that claim_funds is idempotent.
-       assert!(!expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage));
+       expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage);
        assert!(expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events().is_empty());
        check_added_monitors!(expected_paths[0].last().unwrap(), 0);
 
@@ -1756,13 +1790,18 @@ pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&No
        claim_payment(&origin, expected_route, our_payment_preimage);
 }
 
-pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths_slice: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash)  {
-       let mut expected_paths: Vec<_> = expected_paths_slice.iter().collect();
+pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash) {
        for path in expected_paths.iter() {
                assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id());
        }
-       assert!(expected_paths[0].last().unwrap().node.fail_htlc_backwards(&our_payment_hash));
+       expected_paths[0].last().unwrap().node.fail_htlc_backwards(&our_payment_hash);
        expect_pending_htlcs_forwardable!(expected_paths[0].last().unwrap());
+
+       pass_failed_payment_back(origin_node, expected_paths, skip_last, our_payment_hash);
+}
+
+pub fn pass_failed_payment_back<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths_slice: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash) {
+       let mut expected_paths: Vec<_> = expected_paths_slice.iter().collect();
        check_added_monitors!(expected_paths[0].last().unwrap(), expected_paths.len());
 
        let mut per_path_msgs: Vec<((msgs::UpdateFailHTLC, msgs::CommitmentSigned), PublicKey)> = Vec::with_capacity(expected_paths.len());
@@ -1861,7 +1900,7 @@ pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expe
        }
 
        // Ensure that fail_htlc_backwards is idempotent.
-       assert!(!expected_paths[0].last().unwrap().node.fail_htlc_backwards(&our_payment_hash));
+       expected_paths[0].last().unwrap().node.fail_htlc_backwards(&our_payment_hash);
        assert!(expected_paths[0].last().unwrap().node.get_and_clear_pending_events().is_empty());
        assert!(expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events().is_empty());
        check_added_monitors!(expected_paths[0].last().unwrap(), 0);