Expose `skimmed_fee_msat` in `PaymentForwarded`
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index 0dbf0f01cb73ebe3ec4bb373e9038d0654eec963..e53fd6b0c1e61908ab7c8f758c34555852fe6ac6 100644 (file)
@@ -2203,14 +2203,19 @@ macro_rules! expect_payment_path_successful {
 
 pub fn expect_payment_forwarded<CM: AChannelManager, H: NodeHolder<CM=CM>>(
        event: Event, node: &H, prev_node: &H, next_node: &H, expected_fee: Option<u64>,
-       upstream_force_closed: bool, downstream_force_closed: bool
+       expected_extra_fees_msat: Option<u64>, upstream_force_closed: bool,
+       downstream_force_closed: bool
 ) {
        match event {
                Event::PaymentForwarded {
-                       fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id,
-                       outbound_amount_forwarded_msat: _
+                       total_fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id,
+                       outbound_amount_forwarded_msat: _, skimmed_fee_msat
                } => {
-                       assert_eq!(fee_earned_msat, expected_fee);
+                       assert_eq!(total_fee_earned_msat, expected_fee);
+
+                       // Check that the (knowingly) withheld amount is always less or equal to the expected
+                       // overpaid amount.
+                       assert!(skimmed_fee_msat == expected_extra_fees_msat);
                        if !upstream_force_closed {
                                // Is the event prev_channel_id in one of the channels between the two nodes?
                                assert!(node.node().list_channels().iter().any(|x| x.counterparty.node_id == prev_node.node().get_our_node_id() && x.channel_id == prev_channel_id.unwrap()));
@@ -2226,13 +2231,15 @@ pub fn expect_payment_forwarded<CM: AChannelManager, H: NodeHolder<CM=CM>>(
        }
 }
 
+#[macro_export]
 macro_rules! expect_payment_forwarded {
        ($node: expr, $prev_node: expr, $next_node: expr, $expected_fee: expr, $upstream_force_closed: expr, $downstream_force_closed: expr) => {
                let mut events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                $crate::ln::functional_test_utils::expect_payment_forwarded(
-                       events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee,
-                       $upstream_force_closed, $downstream_force_closed);
+                       events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee, None,
+                       $upstream_force_closed, $downstream_force_closed
+               );
        }
 }
 
@@ -2696,11 +2703,17 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
                                                        channel.context().config().forwarding_fee_base_msat
                                                }
                                        };
+
+                                       let mut expected_extra_fee = None;
                                        if $idx == 1 {
                                                fee += expected_extra_fees[i];
                                                fee += expected_min_htlc_overpay[i];
+                                               expected_extra_fee = if expected_extra_fees[i] > 0 { Some(expected_extra_fees[i] as u64) } else { None };
                                        }
-                                       expect_payment_forwarded!(*$node, $next_node, $prev_node, Some(fee as u64), false, false);
+                                       let mut events = $node.node.get_and_clear_pending_events();
+                                       assert_eq!(events.len(), 1);
+                                       expect_payment_forwarded(events.pop().unwrap(), *$node, $next_node, $prev_node,
+                                               Some(fee as u64), expected_extra_fee, false, false);
                                        expected_total_fee_msat += fee as u64;
                                        check_added_monitors!($node, 1);
                                        let new_next_msgs = if $new_msgs {