Merge pull request #2924 from tnull/2024-03-add-user-channel-id-to-payment-forwarded
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index 5246737b5c747c29cb3068ff9d2163a710f547b9..5840fead943049a6ac2a26eb4e784253ad3bcffb 100644 (file)
@@ -2223,31 +2223,60 @@ macro_rules! expect_payment_path_successful {
        }
 }
 
+/// Returns the total fee earned by this HTLC forward, in msat.
 pub fn expect_payment_forwarded<CM: AChannelManager, H: NodeHolder<CM=CM>>(
        event: Event, node: &H, prev_node: &H, next_node: &H, expected_fee: Option<u64>,
        expected_extra_fees_msat: Option<u64>, upstream_force_closed: bool,
-       downstream_force_closed: bool
-) {
+       downstream_force_closed: bool, allow_1_msat_fee_overpay: bool,
+) -> Option<u64> {
        match event {
                Event::PaymentForwarded {
-                       total_fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id,
-                       outbound_amount_forwarded_msat: _, skimmed_fee_msat
+                       prev_channel_id, next_channel_id, prev_user_channel_id, next_user_channel_id,
+                       total_fee_earned_msat, skimmed_fee_msat, claim_from_onchain_tx, ..
                } => {
-                       assert_eq!(total_fee_earned_msat, expected_fee);
+                       if allow_1_msat_fee_overpay {
+                               // Aggregating fees for blinded paths may result in a rounding error, causing slight
+                               // overpayment in fees.
+                               let actual_fee = total_fee_earned_msat.unwrap();
+                               let expected_fee = expected_fee.unwrap();
+                               assert!(actual_fee == expected_fee || actual_fee == expected_fee + 1);
+                       } else {
+                               assert_eq!(total_fee_earned_msat, expected_fee);
+                       }
 
                        // Check that the (knowingly) withheld amount is always less or equal to the expected
                        // overpaid amount.
                        assert!(skimmed_fee_msat == expected_extra_fees_msat);
                        if !upstream_force_closed {
                                // Is the event prev_channel_id in one of the channels between the two nodes?
-                               assert!(node.node().list_channels().iter().any(|x| x.counterparty.node_id == prev_node.node().get_our_node_id() && x.channel_id == prev_channel_id.unwrap()));
+                               assert!(node.node().list_channels().iter().any(|x|
+                                       x.counterparty.node_id == prev_node.node().get_our_node_id() &&
+                                       x.channel_id == prev_channel_id.unwrap() &&
+                                       x.user_channel_id == prev_user_channel_id.unwrap()
+                               ));
                        }
                        // We check for force closures since a force closed channel is removed from the
                        // node's channel list
                        if !downstream_force_closed {
-                               assert!(node.node().list_channels().iter().any(|x| x.counterparty.node_id == next_node.node().get_our_node_id() && x.channel_id == next_channel_id.unwrap()));
+                               // As documented, `next_user_channel_id` will only be `Some` if we didn't settle via an
+                               // onchain transaction, just as the `total_fee_earned_msat` field. Rather than
+                               // introducing yet another variable, we use the latter's state as a flag to detect
+                               // this and only check if it's `Some`.
+                               if total_fee_earned_msat.is_none() {
+                                       assert!(node.node().list_channels().iter().any(|x|
+                                               x.counterparty.node_id == next_node.node().get_our_node_id() &&
+                                               x.channel_id == next_channel_id.unwrap()
+                                       ));
+                               } else {
+                                       assert!(node.node().list_channels().iter().any(|x|
+                                               x.counterparty.node_id == next_node.node().get_our_node_id() &&
+                                               x.channel_id == next_channel_id.unwrap() &&
+                                               x.user_channel_id == next_user_channel_id.unwrap()
+                                       ));
+                               }
                        }
                        assert_eq!(claim_from_onchain_tx, downstream_force_closed);
+                       total_fee_earned_msat
                },
                _ => panic!("Unexpected event"),
        }
@@ -2260,7 +2289,7 @@ macro_rules! expect_payment_forwarded {
                assert_eq!(events.len(), 1);
                $crate::ln::functional_test_utils::expect_payment_forwarded(
                        events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee, None,
-                       $upstream_force_closed, $downstream_force_closed
+                       $upstream_force_closed, $downstream_force_closed, false
                );
        }
 }
@@ -2664,6 +2693,14 @@ pub struct ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
        pub expected_min_htlc_overpay: Vec<u32>,
        pub skip_last: bool,
        pub payment_preimage: PaymentPreimage,
+       // Allow forwarding nodes to have taken 1 msat more fee than expected based on the downstream
+       // fulfill amount.
+       //
+       // Necessary because our test utils calculate the expected fee for an intermediate node based on
+       // the amount was claimed in their downstream peer's fulfill, but blinded intermediate nodes
+       // calculate their fee based on the inbound amount from their upstream peer, causing a difference
+       // in rounding.
+       pub allow_1_msat_fee_overpay: bool,
 }
 
 impl<'a, 'b, 'c, 'd> ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
@@ -2674,6 +2711,7 @@ impl<'a, 'b, 'c, 'd> ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
                Self {
                        origin_node, expected_paths, expected_extra_fees: vec![0; expected_paths.len()],
                        expected_min_htlc_overpay: vec![0; expected_paths.len()], skip_last: false, payment_preimage,
+                       allow_1_msat_fee_overpay: false,
                }
        }
        pub fn skip_last(mut self, skip_last: bool) -> Self {
@@ -2688,12 +2726,16 @@ impl<'a, 'b, 'c, 'd> ClaimAlongRouteArgs<'a, 'b, 'c, 'd> {
                self.expected_min_htlc_overpay = extra_fees;
                self
        }
+       pub fn allow_1_msat_fee_overpay(mut self) -> Self {
+               self.allow_1_msat_fee_overpay = true;
+               self
+       }
 }
 
 pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArgs) -> u64 {
        let ClaimAlongRouteArgs {
                origin_node, expected_paths, expected_extra_fees, expected_min_htlc_overpay, skip_last,
-               payment_preimage: our_payment_preimage
+               payment_preimage: our_payment_preimage, allow_1_msat_fee_overpay,
        } = args;
        let claim_event = expected_paths[0].last().unwrap().node.get_and_clear_pending_events();
        assert_eq!(claim_event.len(), 1);
@@ -2813,10 +2855,10 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
                                        }
                                        let mut events = $node.node.get_and_clear_pending_events();
                                        assert_eq!(events.len(), 1);
-                                       expect_payment_forwarded(events.pop().unwrap(), *$node, $next_node, $prev_node,
-                                               Some(fee as u64), expected_extra_fee, false, false);
-                                       expected_total_fee_msat += fee as u64;
-                                       fwd_amt_msat += fee as u64;
+                                       let actual_fee = expect_payment_forwarded(events.pop().unwrap(), *$node, $next_node, $prev_node,
+                                               Some(fee as u64), expected_extra_fee, false, false, allow_1_msat_fee_overpay);
+                                       expected_total_fee_msat += actual_fee.unwrap();
+                                       fwd_amt_msat += actual_fee.unwrap();
                                        check_added_monitors!($node, 1);
                                        let new_next_msgs = if $new_msgs {
                                                let events = $node.node.get_and_clear_pending_msg_events();