Merge pull request #2062 from alecchendev/2023-02-allow-overshoot-mpp
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 6bd9148e98734d5ce833f5e6d9e7db3d14256280..c8adc721235470dc37bce5a9a2651fc8e74dc8ed 100644 (file)
@@ -120,7 +120,10 @@ pub(super) struct PendingHTLCInfo {
        pub(super) routing: PendingHTLCRouting,
        pub(super) incoming_shared_secret: [u8; 32],
        payment_hash: PaymentHash,
+       /// Amount received
        pub(super) incoming_amt_msat: Option<u64>, // Added in 0.0.113
+       /// Sender intended amount to forward or receive (actual amount received
+       /// may overshoot this in either case)
        pub(super) outgoing_amt_msat: u64,
        pub(super) outgoing_cltv_value: u32,
 }
@@ -192,9 +195,15 @@ struct ClaimableHTLC {
        cltv_expiry: u32,
        /// The amount (in msats) of this MPP part
        value: u64,
+       /// The amount (in msats) that the sender intended to be sent in this MPP
+       /// part (used for validating total MPP amount)
+       sender_intended_value: u64,
        onion_payload: OnionPayload,
        timer_ticks: u8,
-       /// The sum total of all MPP parts
+       /// The total value received for a payment (sum of all MPP parts if the payment is a MPP).
+       /// Gets set to the amount reported when pushing [`Event::PaymentClaimable`].
+       total_value_received: Option<u64>,
+       /// The sender intended sum total of all MPP parts specified in the onion
        total_msat: u64,
 }
 
@@ -2092,9 +2101,9 @@ where
                payment_hash: PaymentHash, amt_msat: u64, cltv_expiry: u32, phantom_shared_secret: Option<[u8; 32]>) -> Result<PendingHTLCInfo, ReceiveError>
        {
                // final_incorrect_cltv_expiry
-               if hop_data.outgoing_cltv_value != cltv_expiry {
+               if hop_data.outgoing_cltv_value > cltv_expiry {
                        return Err(ReceiveError {
-                               msg: "Upstream node set CLTV to the wrong value",
+                               msg: "Upstream node set CLTV to less than the CLTV set by the sender",
                                err_code: 18,
                                err_data: cltv_expiry.to_be_bytes().to_vec()
                        })
@@ -2178,7 +2187,7 @@ where
                        payment_hash,
                        incoming_shared_secret: shared_secret,
                        incoming_amt_msat: Some(amt_msat),
-                       outgoing_amt_msat: amt_msat,
+                       outgoing_amt_msat: hop_data.amt_to_forward,
                        outgoing_cltv_value: hop_data.outgoing_cltv_value,
                })
        }
@@ -2660,7 +2669,7 @@ where
        }
 
        #[cfg(test)]
-       fn test_send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>, payment_id: PaymentId, recv_value_msat: Option<u64>, onion_session_privs: Vec<[u8; 32]>) -> Result<(), PaymentSendFailure> {
+       pub(super) fn test_send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>, payment_id: PaymentId, recv_value_msat: Option<u64>, onion_session_privs: Vec<[u8; 32]>) -> Result<(), PaymentSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments.test_send_payment_internal(route, payment_hash, payment_secret, keysend_preimage, payment_id, recv_value_msat, onion_session_privs, &self.node_signer, best_block_height,
@@ -3258,7 +3267,7 @@ where
                                                        HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
                                                                prev_short_channel_id, prev_htlc_id, prev_funding_outpoint, prev_user_channel_id,
                                                                forward_info: PendingHTLCInfo {
-                                                                       routing, incoming_shared_secret, payment_hash, outgoing_amt_msat, ..
+                                                                       routing, incoming_shared_secret, payment_hash, incoming_amt_msat, outgoing_amt_msat, ..
                                                                }
                                                        }) => {
                                                                let (cltv_expiry, onion_payload, payment_data, phantom_shared_secret) = match routing {
@@ -3272,7 +3281,7 @@ where
                                                                                panic!("short_channel_id == 0 should imply any pending_forward entries are of type Receive");
                                                                        }
                                                                };
-                                                               let claimable_htlc = ClaimableHTLC {
+                                                               let mut claimable_htlc = ClaimableHTLC {
                                                                        prev_hop: HTLCPreviousHopData {
                                                                                short_channel_id: prev_short_channel_id,
                                                                                outpoint: prev_funding_outpoint,
@@ -3280,8 +3289,13 @@ where
                                                                                incoming_packet_shared_secret: incoming_shared_secret,
                                                                                phantom_shared_secret,
                                                                        },
-                                                                       value: outgoing_amt_msat,
+                                                                       // We differentiate the received value from the sender intended value
+                                                                       // if possible so that we don't prematurely mark MPP payments complete
+                                                                       // if routing nodes overpay
+                                                                       value: incoming_amt_msat.unwrap_or(outgoing_amt_msat),
+                                                                       sender_intended_value: outgoing_amt_msat,
                                                                        timer_ticks: 0,
+                                                                       total_value_received: None,
                                                                        total_msat: if let Some(data) = &payment_data { data.total_msat } else { outgoing_amt_msat },
                                                                        cltv_expiry,
                                                                        onion_payload,
@@ -3326,7 +3340,7 @@ where
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
                                                                                        continue
                                                                                }
-                                                                               let (_, htlcs) = claimable_payments.claimable_htlcs.entry(payment_hash)
+                                                                               let (_, ref mut htlcs) = claimable_payments.claimable_htlcs.entry(payment_hash)
                                                                                        .or_insert_with(|| (purpose(), Vec::new()));
                                                                                if htlcs.len() == 1 {
                                                                                        if let OnionPayload::Spontaneous(_) = htlcs[0].onion_payload {
@@ -3335,9 +3349,9 @@ where
                                                                                                continue
                                                                                        }
                                                                                }
-                                                                               let mut total_value = claimable_htlc.value;
+                                                                               let mut total_value = claimable_htlc.sender_intended_value;
                                                                                for htlc in htlcs.iter() {
-                                                                                       total_value += htlc.value;
+                                                                                       total_value += htlc.sender_intended_value;
                                                                                        match &htlc.onion_payload {
                                                                                                OnionPayload::Invoice { .. } => {
                                                                                                        if htlc.total_msat != $payment_data.total_msat {
@@ -3350,18 +3364,24 @@ where
                                                                                                _ => unreachable!(),
                                                                                        }
                                                                                }
-                                                                               if total_value >= msgs::MAX_VALUE_MSAT || total_value > $payment_data.total_msat {
-                                                                                       log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the total value {} ran over expected value {} (or HTLCs were inconsistent)",
-                                                                                               log_bytes!(payment_hash.0), total_value, $payment_data.total_msat);
+                                                                               // The condition determining whether an MPP is complete must
+                                                                               // match exactly the condition used in `timer_tick_occurred`
+                                                                               if total_value >= msgs::MAX_VALUE_MSAT {
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
-                                                                               } else if total_value == $payment_data.total_msat {
+                                                                               } else if total_value - claimable_htlc.sender_intended_value >= $payment_data.total_msat {
+                                                                                       log_trace!(self.logger, "Failing HTLC with payment_hash {} as payment is already claimable",
+                                                                                               log_bytes!(payment_hash.0));
+                                                                                       fail_htlc!(claimable_htlc, payment_hash);
+                                                                               } else if total_value >= $payment_data.total_msat {
                                                                                        let prev_channel_id = prev_funding_outpoint.to_channel_id();
                                                                                        htlcs.push(claimable_htlc);
+                                                                                       let amount_msat = htlcs.iter().map(|htlc| htlc.value).sum();
+                                                                                       htlcs.iter_mut().for_each(|htlc| htlc.total_value_received = Some(amount_msat));
                                                                                        new_events.push(events::Event::PaymentClaimable {
                                                                                                receiver_node_id: Some(receiver_node_id),
                                                                                                payment_hash,
                                                                                                purpose: purpose(),
-                                                                                               amount_msat: total_value,
+                                                                                               amount_msat,
                                                                                                via_channel_id: Some(prev_channel_id),
                                                                                                via_user_channel_id: Some(prev_user_channel_id),
                                                                                        });
@@ -3415,13 +3435,15 @@ where
                                                                                                }
                                                                                                match claimable_payments.claimable_htlcs.entry(payment_hash) {
                                                                                                        hash_map::Entry::Vacant(e) => {
+                                                                                                               let amount_msat = claimable_htlc.value;
+                                                                                                               claimable_htlc.total_value_received = Some(amount_msat);
                                                                                                                let purpose = events::PaymentPurpose::SpontaneousPayment(preimage);
                                                                                                                e.insert((purpose.clone(), vec![claimable_htlc]));
                                                                                                                let prev_channel_id = prev_funding_outpoint.to_channel_id();
                                                                                                                new_events.push(events::Event::PaymentClaimable {
                                                                                                                        receiver_node_id: Some(receiver_node_id),
                                                                                                                        payment_hash,
-                                                                                                                       amount_msat: outgoing_amt_msat,
+                                                                                                                       amount_msat,
                                                                                                                        purpose,
                                                                                                                        via_channel_id: Some(prev_channel_id),
                                                                                                                        via_user_channel_id: Some(prev_user_channel_id),
@@ -3681,7 +3703,9 @@ where
                                if let OnionPayload::Invoice { .. } = htlcs[0].onion_payload {
                                        // Check if we've received all the parts we need for an MPP (the value of the parts adds to total_msat).
                                        // In this case we're not going to handle any timeouts of the parts here.
-                                       if htlcs[0].total_msat == htlcs.iter().fold(0, |total, htlc| total + htlc.value) {
+                                       // This condition determining whether the MPP is complete here must match
+                                       // exactly the condition used in `process_pending_htlc_forwards`.
+                                       if htlcs[0].total_msat <= htlcs.iter().fold(0, |total, htlc| total + htlc.sender_intended_value) {
                                                return true;
                                        } else if htlcs.into_iter().any(|htlc| {
                                                htlc.timer_ticks += 1;
@@ -3960,6 +3984,7 @@ where
                // provide the preimage, so worrying too much about the optimal handling isn't worth
                // it.
                let mut claimable_amt_msat = 0;
+               let mut prev_total_msat = None;
                let mut expected_amt_msat = None;
                let mut valid_mpp = true;
                let mut errs = Vec::new();
@@ -3987,14 +4012,22 @@ where
                                break;
                        }
 
-                       if expected_amt_msat.is_some() && expected_amt_msat != Some(htlc.total_msat) {
-                               log_error!(self.logger, "Somehow ended up with an MPP payment with different total amounts - this should not be reachable!");
+                       if prev_total_msat.is_some() && prev_total_msat != Some(htlc.total_msat) {
+                               log_error!(self.logger, "Somehow ended up with an MPP payment with different expected total amounts - this should not be reachable!");
+                               debug_assert!(false);
+                               valid_mpp = false;
+                               break;
+                       }
+                       prev_total_msat = Some(htlc.total_msat);
+
+                       if expected_amt_msat.is_some() && expected_amt_msat != htlc.total_value_received {
+                               log_error!(self.logger, "Somehow ended up with an MPP payment with different received total amounts - this should not be reachable!");
                                debug_assert!(false);
                                valid_mpp = false;
                                break;
                        }
+                       expected_amt_msat = htlc.total_value_received;
 
-                       expected_amt_msat = Some(htlc.total_msat);
                        if let OnionPayload::Spontaneous(_) = &htlc.onion_payload {
                                // We don't currently support MPP for spontaneous payments, so just check
                                // that there's one payment here and move on.
@@ -6794,7 +6827,9 @@ impl Writeable for ClaimableHTLC {
                        (0, self.prev_hop, required),
                        (1, self.total_msat, required),
                        (2, self.value, required),
+                       (3, self.sender_intended_value, required),
                        (4, payment_data, option),
+                       (5, self.total_value_received, option),
                        (6, self.cltv_expiry, required),
                        (8, keysend_preimage, option),
                });
@@ -6806,15 +6841,19 @@ impl Readable for ClaimableHTLC {
        fn read<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
                let mut prev_hop = crate::util::ser::RequiredWrapper(None);
                let mut value = 0;
+               let mut sender_intended_value = None;
                let mut payment_data: Option<msgs::FinalOnionHopData> = None;
                let mut cltv_expiry = 0;
+               let mut total_value_received = None;
                let mut total_msat = None;
                let mut keysend_preimage: Option<PaymentPreimage> = None;
                read_tlv_fields!(reader, {
                        (0, prev_hop, required),
                        (1, total_msat, option),
                        (2, value, required),
+                       (3, sender_intended_value, option),
                        (4, payment_data, option),
+                       (5, total_value_received, option),
                        (6, cltv_expiry, required),
                        (8, keysend_preimage, option)
                });
@@ -6842,6 +6881,8 @@ impl Readable for ClaimableHTLC {
                        prev_hop: prev_hop.0.unwrap(),
                        timer_ticks: 0,
                        value,
+                       sender_intended_value: sender_intended_value.unwrap_or(value),
+                       total_value_received,
                        total_msat: total_msat.unwrap(),
                        onion_payload,
                        cltv_expiry,