Allow overshooting `total_msat` for an MPP
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 199beeae525a6a01bcb21894120ac495030704b2..617ac3968c41df38a7ed6547cfb308950965e4b9 100644 (file)
@@ -2663,7 +2663,7 @@ where
        }
 
        #[cfg(test)]
-       fn test_send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>, payment_id: PaymentId, recv_value_msat: Option<u64>, onion_session_privs: Vec<[u8; 32]>) -> Result<(), PaymentSendFailure> {
+       pub(super) fn test_send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>, payment_id: PaymentId, recv_value_msat: Option<u64>, onion_session_privs: Vec<[u8; 32]>) -> Result<(), PaymentSendFailure> {
                let best_block_height = self.best_block.read().unwrap().height();
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                self.pending_outbound_payments.test_send_payment_internal(route, payment_hash, payment_secret, keysend_preimage, payment_id, recv_value_msat, onion_session_privs, &self.node_signer, best_block_height,
@@ -3354,11 +3354,13 @@ where
                                                                                                _ => unreachable!(),
                                                                                        }
                                                                                }
-                                                                               if total_value >= msgs::MAX_VALUE_MSAT || total_value > $payment_data.total_msat {
-                                                                                       log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the total value {} ran over expected value {} (or HTLCs were inconsistent)",
-                                                                                               log_bytes!(payment_hash.0), total_value, $payment_data.total_msat);
+                                                                               if total_value >= msgs::MAX_VALUE_MSAT {
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
-                                                                               } else if total_value == $payment_data.total_msat {
+                                                                               } else if total_value - claimable_htlc.value >= $payment_data.total_msat {
+                                                                                       log_trace!(self.logger, "Failing HTLC with payment_hash {} as payment is already claimable",
+                                                                                               log_bytes!(payment_hash.0));
+                                                                                       fail_htlc!(claimable_htlc, payment_hash);
+                                                                               } else if total_value >= $payment_data.total_msat {
                                                                                        let prev_channel_id = prev_funding_outpoint.to_channel_id();
                                                                                        htlcs.push(claimable_htlc);
                                                                                        let amount_msat = htlcs.iter().map(|htlc| htlc.value).sum();
@@ -3689,7 +3691,7 @@ where
                                if let OnionPayload::Invoice { .. } = htlcs[0].onion_payload {
                                        // Check if we've received all the parts we need for an MPP (the value of the parts adds to total_msat).
                                        // In this case we're not going to handle any timeouts of the parts here.
-                                       if htlcs[0].total_msat == htlcs.iter().fold(0, |total, htlc| total + htlc.value) {
+                                       if htlcs[0].total_msat <= htlcs.iter().fold(0, |total, htlc| total + htlc.value) {
                                                return true;
                                        } else if htlcs.into_iter().any(|htlc| {
                                                htlc.timer_ticks += 1;