X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Fln%2Ffunctional_test_utils.rs;h=170d5b0b022b5b06a313c048e61cf828287181bb;hb=e594021052e251659e33c0f5e82c7ec2b9e99c18;hp=e7fc68924efbda8209ba103b94d8eb54b932be3b;hpb=7f177bb6dd6fde177176b7be1ec62ef533038dd3;p=rust-lightning diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index e7fc6892..170d5b0b 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -2203,14 +2203,19 @@ macro_rules! expect_payment_path_successful { pub fn expect_payment_forwarded>( event: Event, node: &H, prev_node: &H, next_node: &H, expected_fee: Option, - upstream_force_closed: bool, downstream_force_closed: bool + expected_extra_fees_msat: Option, upstream_force_closed: bool, + downstream_force_closed: bool ) { match event { Event::PaymentForwarded { - fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id, - outbound_amount_forwarded_msat: _ + total_fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id, + outbound_amount_forwarded_msat: _, skimmed_fee_msat } => { - assert_eq!(fee_earned_msat, expected_fee); + assert_eq!(total_fee_earned_msat, expected_fee); + + // Check that the (knowingly) withheld amount is always less or equal to the expected + // overpaid amount. + assert!(skimmed_fee_msat == expected_extra_fees_msat); if !upstream_force_closed { // Is the event prev_channel_id in one of the channels between the two nodes? assert!(node.node().list_channels().iter().any(|x| x.counterparty.node_id == prev_node.node().get_our_node_id() && x.channel_id == prev_channel_id.unwrap())); @@ -2226,13 +2231,15 @@ pub fn expect_payment_forwarded>( } } +#[macro_export] macro_rules! expect_payment_forwarded { ($node: expr, $prev_node: expr, $next_node: expr, $expected_fee: expr, $upstream_force_closed: expr, $downstream_force_closed: expr) => { let mut events = $node.node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); $crate::ln::functional_test_utils::expect_payment_forwarded( - events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee, - $upstream_force_closed, $downstream_force_closed); + events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee, None, + $upstream_force_closed, $downstream_force_closed + ); } } @@ -2552,24 +2559,54 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>( origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_preimage: PaymentPreimage ) -> u64 { - let extra_fees = vec![0; expected_paths.len()]; - do_claim_payment_along_route_with_extra_penultimate_hop_fees(origin_node, expected_paths, - &extra_fees[..], skip_last, our_payment_preimage) -} - -pub fn do_claim_payment_along_route_with_extra_penultimate_hop_fees<'a, 'b, 'c>( - origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], expected_extra_fees: - &[u32], skip_last: bool, our_payment_preimage: PaymentPreimage -) -> u64 { - assert_eq!(expected_paths.len(), expected_extra_fees.len()); for path in expected_paths.iter() { assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id()); } expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage); - pass_claimed_payment_along_route(origin_node, expected_paths, expected_extra_fees, skip_last, our_payment_preimage) + pass_claimed_payment_along_route( + ClaimAlongRouteArgs::new(origin_node, expected_paths, our_payment_preimage) + .skip_last(skip_last) + ) +} + +pub struct ClaimAlongRouteArgs<'a, 'b, 'c, 'd> { + pub origin_node: &'a Node<'b, 'c, 'd>, + pub expected_paths: &'a [&'a [&'a Node<'b, 'c, 'd>]], + pub expected_extra_fees: Vec, + pub expected_min_htlc_overpay: Vec, + pub skip_last: bool, + pub payment_preimage: PaymentPreimage, } -pub fn pass_claimed_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], expected_extra_fees: &[u32], skip_last: bool, our_payment_preimage: PaymentPreimage) -> u64 { +impl<'a, 'b, 'c, 'd> ClaimAlongRouteArgs<'a, 'b, 'c, 'd> { + pub fn new( + origin_node: &'a Node<'b, 'c, 'd>, expected_paths: &'a [&'a [&'a Node<'b, 'c, 'd>]], + payment_preimage: PaymentPreimage, + ) -> Self { + Self { + origin_node, expected_paths, expected_extra_fees: vec![0; expected_paths.len()], + expected_min_htlc_overpay: vec![0; expected_paths.len()], skip_last: false, payment_preimage, + } + } + pub fn skip_last(mut self, skip_last: bool) -> Self { + self.skip_last = skip_last; + self + } + pub fn with_expected_extra_fees(mut self, extra_fees: Vec) -> Self { + self.expected_extra_fees = extra_fees; + self + } + pub fn with_expected_min_htlc_overpay(mut self, extra_fees: Vec) -> Self { + self.expected_min_htlc_overpay = extra_fees; + self + } +} + +pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArgs) -> u64 { + let ClaimAlongRouteArgs { + origin_node, expected_paths, expected_extra_fees, expected_min_htlc_overpay, skip_last, + payment_preimage: our_payment_preimage + } = args; let claim_event = expected_paths[0].last().unwrap().node.get_and_clear_pending_events(); assert_eq!(claim_event.len(), 1); match claim_event[0] { @@ -2666,8 +2703,17 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, ' channel.context().config().forwarding_fee_base_msat } }; - if $idx == 1 { fee += expected_extra_fees[i]; } - expect_payment_forwarded!(*$node, $next_node, $prev_node, Some(fee as u64), false, false); + + let mut expected_extra_fee = None; + if $idx == 1 { + fee += expected_extra_fees[i]; + fee += expected_min_htlc_overpay[i]; + expected_extra_fee = if expected_extra_fees[i] > 0 { Some(expected_extra_fees[i] as u64) } else { None }; + } + let mut events = $node.node.get_and_clear_pending_events(); + assert_eq!(events.len(), 1); + expect_payment_forwarded(events.pop().unwrap(), *$node, $next_node, $prev_node, + Some(fee as u64), expected_extra_fee, false, false); expected_total_fee_msat += fee as u64; check_added_monitors!($node, 1); let new_next_msgs = if $new_msgs { @@ -2934,7 +2980,7 @@ pub fn create_node_cfgs_with_persisters<'a>(node_count: usize, chanmon_cfgs: &'a tx_broadcaster: &chanmon_cfgs[i].tx_broadcaster, fee_estimator: &chanmon_cfgs[i].fee_estimator, router: test_utils::TestRouter::new(network_graph.clone(), &chanmon_cfgs[i].logger, &chanmon_cfgs[i].scorer), - message_router: test_utils::TestMessageRouter::new(network_graph.clone()), + message_router: test_utils::TestMessageRouter::new(network_graph.clone(), &chanmon_cfgs[i].keys_manager), chain_monitor, keys_manager: &chanmon_cfgs[i].keys_manager, node_seed: seed,