Use onion amount `amt_to_forward` for MPP set calculation
authorAlec Chen <alecchendev@gmail.com>
Tue, 7 Mar 2023 01:51:44 +0000 (19:51 -0600)
committerAlec Chen <alecchendev@gmail.com>
Tue, 28 Mar 2023 22:21:09 +0000 (17:21 -0500)
If routing nodes take less fees and pay the final node more than
`amt_to_forward`, the receiver may see that `total_msat` has been met
before all of the sender's intended HTLCs have arrived. The receiver
may then prematurely claim the payment and release the payment hash,
allowing routing nodes to claim the remaining HTLCs. Using the onion
value `amt_to_forward` to determine when `total_msat` has been met
allows the sender to control the set total.

lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_tests.rs

index fb1f6b4a8481d33c82286c5f52ddef7c16f6d463..c8adc721235470dc37bce5a9a2651fc8e74dc8ed 100644 (file)
@@ -120,7 +120,10 @@ pub(super) struct PendingHTLCInfo {
        pub(super) routing: PendingHTLCRouting,
        pub(super) incoming_shared_secret: [u8; 32],
        payment_hash: PaymentHash,
+       /// Amount received
        pub(super) incoming_amt_msat: Option<u64>, // Added in 0.0.113
+       /// Sender intended amount to forward or receive (actual amount received
+       /// may overshoot this in either case)
        pub(super) outgoing_amt_msat: u64,
        pub(super) outgoing_cltv_value: u32,
 }
@@ -192,6 +195,9 @@ struct ClaimableHTLC {
        cltv_expiry: u32,
        /// The amount (in msats) of this MPP part
        value: u64,
+       /// The amount (in msats) that the sender intended to be sent in this MPP
+       /// part (used for validating total MPP amount)
+       sender_intended_value: u64,
        onion_payload: OnionPayload,
        timer_ticks: u8,
        /// The total value received for a payment (sum of all MPP parts if the payment is a MPP).
@@ -2181,7 +2187,7 @@ where
                        payment_hash,
                        incoming_shared_secret: shared_secret,
                        incoming_amt_msat: Some(amt_msat),
-                       outgoing_amt_msat: amt_msat,
+                       outgoing_amt_msat: hop_data.amt_to_forward,
                        outgoing_cltv_value: hop_data.outgoing_cltv_value,
                })
        }
@@ -3261,7 +3267,7 @@ where
                                                        HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
                                                                prev_short_channel_id, prev_htlc_id, prev_funding_outpoint, prev_user_channel_id,
                                                                forward_info: PendingHTLCInfo {
-                                                                       routing, incoming_shared_secret, payment_hash, outgoing_amt_msat, ..
+                                                                       routing, incoming_shared_secret, payment_hash, incoming_amt_msat, outgoing_amt_msat, ..
                                                                }
                                                        }) => {
                                                                let (cltv_expiry, onion_payload, payment_data, phantom_shared_secret) = match routing {
@@ -3283,7 +3289,11 @@ where
                                                                                incoming_packet_shared_secret: incoming_shared_secret,
                                                                                phantom_shared_secret,
                                                                        },
-                                                                       value: outgoing_amt_msat,
+                                                                       // We differentiate the received value from the sender intended value
+                                                                       // if possible so that we don't prematurely mark MPP payments complete
+                                                                       // if routing nodes overpay
+                                                                       value: incoming_amt_msat.unwrap_or(outgoing_amt_msat),
+                                                                       sender_intended_value: outgoing_amt_msat,
                                                                        timer_ticks: 0,
                                                                        total_value_received: None,
                                                                        total_msat: if let Some(data) = &payment_data { data.total_msat } else { outgoing_amt_msat },
@@ -3339,9 +3349,9 @@ where
                                                                                                continue
                                                                                        }
                                                                                }
-                                                                               let mut total_value = claimable_htlc.value;
+                                                                               let mut total_value = claimable_htlc.sender_intended_value;
                                                                                for htlc in htlcs.iter() {
-                                                                                       total_value += htlc.value;
+                                                                                       total_value += htlc.sender_intended_value;
                                                                                        match &htlc.onion_payload {
                                                                                                OnionPayload::Invoice { .. } => {
                                                                                                        if htlc.total_msat != $payment_data.total_msat {
@@ -3354,9 +3364,11 @@ where
                                                                                                _ => unreachable!(),
                                                                                        }
                                                                                }
+                                                                               // The condition determining whether an MPP is complete must
+                                                                               // match exactly the condition used in `timer_tick_occurred`
                                                                                if total_value >= msgs::MAX_VALUE_MSAT {
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
-                                                                               } else if total_value - claimable_htlc.value >= $payment_data.total_msat {
+                                                                               } else if total_value - claimable_htlc.sender_intended_value >= $payment_data.total_msat {
                                                                                        log_trace!(self.logger, "Failing HTLC with payment_hash {} as payment is already claimable",
                                                                                                log_bytes!(payment_hash.0));
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
@@ -3431,7 +3443,7 @@ where
                                                                                                                new_events.push(events::Event::PaymentClaimable {
                                                                                                                        receiver_node_id: Some(receiver_node_id),
                                                                                                                        payment_hash,
-                                                                                                                       amount_msat: outgoing_amt_msat,
+                                                                                                                       amount_msat,
                                                                                                                        purpose,
                                                                                                                        via_channel_id: Some(prev_channel_id),
                                                                                                                        via_user_channel_id: Some(prev_user_channel_id),
@@ -3691,7 +3703,9 @@ where
                                if let OnionPayload::Invoice { .. } = htlcs[0].onion_payload {
                                        // Check if we've received all the parts we need for an MPP (the value of the parts adds to total_msat).
                                        // In this case we're not going to handle any timeouts of the parts here.
-                                       if htlcs[0].total_msat <= htlcs.iter().fold(0, |total, htlc| total + htlc.value) {
+                                       // This condition determining whether the MPP is complete here must match
+                                       // exactly the condition used in `process_pending_htlc_forwards`.
+                                       if htlcs[0].total_msat <= htlcs.iter().fold(0, |total, htlc| total + htlc.sender_intended_value) {
                                                return true;
                                        } else if htlcs.into_iter().any(|htlc| {
                                                htlc.timer_ticks += 1;
@@ -6813,6 +6827,7 @@ impl Writeable for ClaimableHTLC {
                        (0, self.prev_hop, required),
                        (1, self.total_msat, required),
                        (2, self.value, required),
+                       (3, self.sender_intended_value, required),
                        (4, payment_data, option),
                        (5, self.total_value_received, option),
                        (6, self.cltv_expiry, required),
@@ -6826,6 +6841,7 @@ impl Readable for ClaimableHTLC {
        fn read<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
                let mut prev_hop = crate::util::ser::RequiredWrapper(None);
                let mut value = 0;
+               let mut sender_intended_value = None;
                let mut payment_data: Option<msgs::FinalOnionHopData> = None;
                let mut cltv_expiry = 0;
                let mut total_value_received = None;
@@ -6835,6 +6851,7 @@ impl Readable for ClaimableHTLC {
                        (0, prev_hop, required),
                        (1, total_msat, option),
                        (2, value, required),
+                       (3, sender_intended_value, option),
                        (4, payment_data, option),
                        (5, total_value_received, option),
                        (6, cltv_expiry, required),
@@ -6864,6 +6881,7 @@ impl Readable for ClaimableHTLC {
                        prev_hop: prev_hop.0.unwrap(),
                        timer_ticks: 0,
                        value,
+                       sender_intended_value: sender_intended_value.unwrap_or(value),
                        total_value_received,
                        total_msat: total_msat.unwrap(),
                        onion_payload,
index 25b0f792ccec9e91a21f01cd2740183c29f84e38..b3f3c76454711c2f52d77100e4d1be36c30cd58d 100644 (file)
@@ -7926,6 +7926,101 @@ fn test_can_not_accept_unknown_inbound_channel() {
        }
 }
 
+#[test]
+fn test_onion_value_mpp_set_calculation() {
+       // Test that we use the onion value `amt_to_forward` when
+       // calculating whether we've reached the `total_msat` of an MPP
+       // by having a routing node forward more than `amt_to_forward`
+       // and checking that the receiving node doesn't generate
+       // a PaymentClaimable event too early
+       let node_count = 4;
+       let chanmon_cfgs = create_chanmon_cfgs(node_count);
+       let node_cfgs = create_node_cfgs(node_count, &chanmon_cfgs);
+       let node_chanmgrs = create_node_chanmgrs(node_count, &node_cfgs, &vec![None; node_count]);
+       let mut nodes = create_network(node_count, &node_cfgs, &node_chanmgrs);
+
+       let chan_1_id = create_announced_chan_between_nodes(&nodes, 0, 1).0.contents.short_channel_id;
+       let chan_2_id = create_announced_chan_between_nodes(&nodes, 0, 2).0.contents.short_channel_id;
+       let chan_3_id = create_announced_chan_between_nodes(&nodes, 1, 3).0.contents.short_channel_id;
+       let chan_4_id = create_announced_chan_between_nodes(&nodes, 2, 3).0.contents.short_channel_id;
+
+       let total_msat = 100_000;
+       let expected_paths: &[&[&Node]] = &[&[&nodes[1], &nodes[3]], &[&nodes[2], &nodes[3]]];
+       let (mut route, our_payment_hash, our_payment_preimage, our_payment_secret) = get_route_and_payment_hash!(&nodes[0], nodes[3], total_msat);
+       let sample_path = route.paths.pop().unwrap();
+
+       let mut path_1 = sample_path.clone();
+       path_1[0].pubkey = nodes[1].node.get_our_node_id();
+       path_1[0].short_channel_id = chan_1_id;
+       path_1[1].pubkey = nodes[3].node.get_our_node_id();
+       path_1[1].short_channel_id = chan_3_id;
+       path_1[1].fee_msat = 100_000;
+       route.paths.push(path_1);
+
+       let mut path_2 = sample_path.clone();
+       path_2[0].pubkey = nodes[2].node.get_our_node_id();
+       path_2[0].short_channel_id = chan_2_id;
+       path_2[1].pubkey = nodes[3].node.get_our_node_id();
+       path_2[1].short_channel_id = chan_4_id;
+       path_2[1].fee_msat = 1_000;
+       route.paths.push(path_2);
+
+       // Send payment
+       let payment_id = PaymentId(nodes[0].keys_manager.backing.get_secure_random_bytes());
+       let onion_session_privs = nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(our_payment_secret), payment_id, &route).unwrap();
+       nodes[0].node.test_send_payment_internal(&route, our_payment_hash, &Some(our_payment_secret), None, payment_id, Some(total_msat), onion_session_privs).unwrap();
+       check_added_monitors!(nodes[0], expected_paths.len());
+
+       let mut events = nodes[0].node.get_and_clear_pending_msg_events();
+       assert_eq!(events.len(), expected_paths.len());
+
+       // First path
+       let ev = remove_first_msg_event_to_node(&expected_paths[0][0].node.get_our_node_id(), &mut events);
+       let mut payment_event = SendEvent::from_event(ev);
+       let mut prev_node = &nodes[0];
+
+       for (idx, &node) in expected_paths[0].iter().enumerate() {
+               assert_eq!(node.node.get_our_node_id(), payment_event.node_id);
+
+               if idx == 0 { // routing node
+                       let session_priv = [3; 32];
+                       let height = nodes[0].best_block_info().1;
+                       let session_priv = SecretKey::from_slice(&session_priv).unwrap();
+                       let mut onion_keys = onion_utils::construct_onion_keys(&Secp256k1::new(), &route.paths[0], &session_priv).unwrap();
+                       let (mut onion_payloads, _, _) = onion_utils::build_onion_payloads(&route.paths[0], 100_000, &Some(our_payment_secret), height + 1, &None).unwrap();
+                       // Edit amt_to_forward to simulate the sender having set
+                       // the final amount and the routing node taking less fee
+                       onion_payloads[1].amt_to_forward = 99_000;
+                       let new_onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, [0; 32], &our_payment_hash);
+                       payment_event.msgs[0].onion_routing_packet = new_onion_packet;
+               }
+
+               node.node.handle_update_add_htlc(&prev_node.node.get_our_node_id(), &payment_event.msgs[0]);
+               check_added_monitors!(node, 0);
+               commitment_signed_dance!(node, prev_node, payment_event.commitment_msg, false);
+               expect_pending_htlcs_forwardable!(node);
+
+               if idx == 0 {
+                       let mut events_2 = node.node.get_and_clear_pending_msg_events();
+                       assert_eq!(events_2.len(), 1);
+                       check_added_monitors!(node, 1);
+                       payment_event = SendEvent::from_event(events_2.remove(0));
+                       assert_eq!(payment_event.msgs.len(), 1);
+               } else {
+                       let events_2 = node.node.get_and_clear_pending_events();
+                       assert!(events_2.is_empty());
+               }
+
+               prev_node = node;
+       }
+
+       // Second path
+       let ev = remove_first_msg_event_to_node(&expected_paths[1][0].node.get_our_node_id(), &mut events);
+       pass_along_path(&nodes[0], expected_paths[1], 101_000, our_payment_hash.clone(), Some(our_payment_secret), ev, true, None);
+
+       claim_payment_along_route(&nodes[0], expected_paths, false, our_payment_preimage);
+}
+
 fn do_test_overshoot_mpp(msat_amounts: &[u64], total_msat: u64) {
 
        let routing_node_count = msat_amounts.len();