Store total payment amount in ClaimableHTLC explicitly
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 268676173df8a17174fd24a8b11dfb7438f45fc3..23426e2194545e9b96527a96f8ba3733cf8a5124 100644 (file)
@@ -459,6 +459,7 @@ struct ClaimableHTLC {
        cltv_expiry: u32,
        value: u64,
        onion_payload: OnionPayload,
+       total_msat: u64,
 }
 
 /// A payment identifier used to uniquely identify a payment to LDK.
@@ -3172,11 +3173,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                        HTLCForwardInfo::AddHTLC { prev_short_channel_id, prev_htlc_id, forward_info: PendingHTLCInfo {
                                                                        routing, incoming_shared_secret, payment_hash, amt_to_forward, .. },
                                                                        prev_funding_outpoint } => {
-                                                               let (cltv_expiry, onion_payload) = match routing {
+                                                               let (cltv_expiry, total_msat, onion_payload) = match routing {
                                                                        PendingHTLCRouting::Receive { payment_data, incoming_cltv_expiry } =>
-                                                                               (incoming_cltv_expiry, OnionPayload::Invoice(payment_data)),
+                                                                               (incoming_cltv_expiry, payment_data.total_msat, OnionPayload::Invoice(payment_data)),
                                                                        PendingHTLCRouting::ReceiveKeysend { payment_preimage, incoming_cltv_expiry } =>
-                                                                               (incoming_cltv_expiry, OnionPayload::Spontaneous(payment_preimage)),
+                                                                               (incoming_cltv_expiry, amt_to_forward, OnionPayload::Spontaneous(payment_preimage)),
                                                                        _ => {
                                                                                panic!("short_channel_id == 0 should imply any pending_forward entries are of type Receive");
                                                                        }
@@ -3189,6 +3190,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                incoming_packet_shared_secret: incoming_shared_secret,
                                                                        },
                                                                        value: amt_to_forward,
+                                                                       total_msat,
                                                                        cltv_expiry,
                                                                        onion_payload,
                                                                };
@@ -3212,7 +3214,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                                                                macro_rules! check_total_value {
                                                                        ($payment_data_total_msat: expr, $payment_secret: expr, $payment_preimage: expr) => {{
-                                                                               let mut total_value = 0;
                                                                                let mut payment_received_generated = false;
                                                                                let htlcs = channel_state.claimable_htlcs.entry(payment_hash)
                                                                                        .or_insert(Vec::new());
@@ -3223,14 +3224,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                                continue
                                                                                        }
                                                                                }
-                                                                               htlcs.push(claimable_htlc);
+                                                                               let mut total_value = claimable_htlc.value;
                                                                                for htlc in htlcs.iter() {
                                                                                        total_value += htlc.value;
                                                                                        match &htlc.onion_payload {
-                                                                                               OnionPayload::Invoice(htlc_payment_data) => {
-                                                                                                       if htlc_payment_data.total_msat != $payment_data_total_msat {
+                                                                                               OnionPayload::Invoice(_) => {
+                                                                                                       if htlc.total_msat != claimable_htlc.total_msat {
                                                                                                                log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the HTLCs had inconsistent total values (eg {} and {})",
-                                                                                                                       log_bytes!(payment_hash.0), $payment_data_total_msat, htlc_payment_data.total_msat);
+                                                                                                                       log_bytes!(payment_hash.0), claimable_htlc.total_msat, htlc.total_msat);
                                                                                                                total_value = msgs::MAX_VALUE_MSAT;
                                                                                                        }
                                                                                                        if total_value >= msgs::MAX_VALUE_MSAT { break; }
@@ -3238,13 +3239,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                                _ => unreachable!(),
                                                                                        }
                                                                                }
-                                                                               if total_value >= msgs::MAX_VALUE_MSAT || total_value > $payment_data_total_msat {
+                                                                               if total_value >= msgs::MAX_VALUE_MSAT || total_value > claimable_htlc.total_msat {
                                                                                        log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the total value {} ran over expected value {} (or HTLCs were inconsistent)",
-                                                                                               log_bytes!(payment_hash.0), total_value, $payment_data_total_msat);
+                                                                                               log_bytes!(payment_hash.0), total_value, claimable_htlc.total_msat);
                                                                                        for htlc in htlcs.iter() {
                                                                                                fail_htlc!(htlc);
                                                                                        }
-                                                                               } else if total_value == $payment_data_total_msat {
+                                                                               } else if total_value == claimable_htlc.total_msat {
                                                                                        new_events.push(events::Event::PaymentReceived {
                                                                                                payment_hash,
                                                                                                purpose: events::PaymentPurpose::InvoicePayment {
@@ -3259,6 +3260,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                        // payment value yet, wait until we receive more
                                                                                        // MPP parts.
                                                                                }
+                                                                               htlcs.push(claimable_htlc);
                                                                                payment_received_generated
                                                                        }}
                                                                }
@@ -5927,13 +5929,14 @@ impl Writeable for ClaimableHTLC {
                        OnionPayload::Invoice(_) => None,
                        OnionPayload::Spontaneous(preimage) => Some(preimage.clone()),
                };
-               write_tlv_fields!
-               (writer,
-                {
-                  (0, self.prev_hop, required), (2, self.value, required),
-                  (4, payment_data, option), (6, self.cltv_expiry, required),
-                        (8, keysend_preimage, option),
-                });
+               write_tlv_fields!(writer, {
+                       (0, self.prev_hop, required),
+                       (1, self.total_msat, required),
+                       (2, self.value, required),
+                       (4, payment_data, option),
+                       (6, self.cltv_expiry, required),
+                       (8, keysend_preimage, option),
+               });
                Ok(())
        }
 }
@@ -5944,31 +5947,40 @@ impl Readable for ClaimableHTLC {
                let mut value = 0;
                let mut payment_data: Option<msgs::FinalOnionHopData> = None;
                let mut cltv_expiry = 0;
+               let mut total_msat = None;
                let mut keysend_preimage: Option<PaymentPreimage> = None;
-               read_tlv_fields!
-               (reader,
-                {
-                  (0, prev_hop, required), (2, value, required),
-                  (4, payment_data, option), (6, cltv_expiry, required),
-                        (8, keysend_preimage, option)
-                });
+               read_tlv_fields!(reader, {
+                       (0, prev_hop, required),
+                       (1, total_msat, option),
+                       (2, value, required),
+                       (4, payment_data, option),
+                       (6, cltv_expiry, required),
+                       (8, keysend_preimage, option)
+               });
                let onion_payload = match keysend_preimage {
                        Some(p) => {
                                if payment_data.is_some() {
                                        return Err(DecodeError::InvalidValue)
                                }
+                               if total_msat.is_none() {
+                                       total_msat = Some(value);
+                               }
                                OnionPayload::Spontaneous(p)
                        },
                        None => {
                                if payment_data.is_none() {
                                        return Err(DecodeError::InvalidValue)
                                }
+                               if total_msat.is_none() {
+                                       total_msat = Some(payment_data.as_ref().unwrap().total_msat);
+                               }
                                OnionPayload::Invoice(payment_data.unwrap())
                        },
                };
                Ok(Self {
                        prev_hop: prev_hop.0.unwrap(),
                        value,
+                       total_msat: total_msat.unwrap(),
                        onion_payload,
                        cltv_expiry,
                })