]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Macro-ize checking that the total value of an MPP's parts is sane
authorValentine Wallace <vwallace@protonmail.com>
Mon, 22 Nov 2021 21:53:18 +0000 (16:53 -0500)
committerValentine Wallace <vwallace@protonmail.com>
Thu, 16 Dec 2021 23:30:52 +0000 (15:30 -0800)
This DRY-ed code will be used in upcoming commits when we stop storing inbound
payment data

lightning/src/ln/channelmanager.rs

index 1b0fa19e163a50811d6395e9ca138e2e47bffb8f..b9ee78499bcc6a5fecf749086af81b69b4277212 100644 (file)
@@ -2913,6 +2913,59 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                        }
                                                                }
 
+                                                               macro_rules! check_total_value {
+                                                                       ($payment_data_total_msat: expr, $payment_secret: expr, $payment_preimage: expr) => {{
+                                                                               let mut total_value = 0;
+                                                                               let mut payment_received_generated = false;
+                                                                               let htlcs = channel_state.claimable_htlcs.entry(payment_hash)
+                                                                                       .or_insert(Vec::new());
+                                                                               if htlcs.len() == 1 {
+                                                                                       if let OnionPayload::Spontaneous(_) = htlcs[0].onion_payload {
+                                                                                               log_trace!(self.logger, "Failing new HTLC with payment_hash {} as we already had an existing keysend HTLC with the same payment hash", log_bytes!(payment_hash.0));
+                                                                                               fail_htlc!(claimable_htlc);
+                                                                                               continue
+                                                                                       }
+                                                                               }
+                                                                               htlcs.push(claimable_htlc);
+                                                                               for htlc in htlcs.iter() {
+                                                                                       total_value += htlc.value;
+                                                                                       match &htlc.onion_payload {
+                                                                                               OnionPayload::Invoice(htlc_payment_data) => {
+                                                                                                       if htlc_payment_data.total_msat != $payment_data_total_msat {
+                                                                                                               log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the HTLCs had inconsistent total values (eg {} and {})",
+                                                                                                                       log_bytes!(payment_hash.0), $payment_data_total_msat, htlc_payment_data.total_msat);
+                                                                                                               total_value = msgs::MAX_VALUE_MSAT;
+                                                                                                       }
+                                                                                                       if total_value >= msgs::MAX_VALUE_MSAT { break; }
+                                                                                               },
+                                                                                               _ => unreachable!(),
+                                                                                       }
+                                                                               }
+                                                                               if total_value >= msgs::MAX_VALUE_MSAT || total_value > $payment_data_total_msat {
+                                                                                       log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the total value {} ran over expected value {} (or HTLCs were inconsistent)",
+                                                                                               log_bytes!(payment_hash.0), total_value, $payment_data_total_msat);
+                                                                                       for htlc in htlcs.iter() {
+                                                                                               fail_htlc!(htlc);
+                                                                                       }
+                                                                               } else if total_value == $payment_data_total_msat {
+                                                                                       new_events.push(events::Event::PaymentReceived {
+                                                                                               payment_hash,
+                                                                                               purpose: events::PaymentPurpose::InvoicePayment {
+                                                                                                       payment_preimage: $payment_preimage,
+                                                                                                       payment_secret: $payment_secret,
+                                                                                               },
+                                                                                               amt: total_value,
+                                                                                       });
+                                                                                       payment_received_generated = true;
+                                                                               } else {
+                                                                                       // Nothing to do - we haven't reached the total
+                                                                                       // payment value yet, wait until we receive more
+                                                                                       // MPP parts.
+                                                                               }
+                                                                               payment_received_generated
+                                                                       }}
+                                                               }
+
                                                                // Check that the payment hash and secret are known. Note that we
                                                                // MUST take care to handle the "unknown payment hash" and
                                                                // "incorrect payment secret" cases here identically or we'd expose
@@ -2962,54 +3015,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                                log_bytes!(payment_hash.0), payment_data.total_msat, inbound_payment.get().min_value_msat.unwrap());
                                                                                        fail_htlc!(claimable_htlc);
                                                                                } else {
-                                                                                       let mut total_value = 0;
-                                                                                       let htlcs = channel_state.claimable_htlcs.entry(payment_hash)
-                                                                                               .or_insert(Vec::new());
-                                                                                       if htlcs.len() == 1 {
-                                                                                               if let OnionPayload::Spontaneous(_) = htlcs[0].onion_payload {
-                                                                                                       log_trace!(self.logger, "Failing new HTLC with payment_hash {} as we already had an existing keysend HTLC with the same payment hash", log_bytes!(payment_hash.0));
-                                                                                                       fail_htlc!(claimable_htlc);
-                                                                                                       continue
-                                                                                               }
-                                                                                       }
-                                                                                       htlcs.push(claimable_htlc);
-                                                                                       for htlc in htlcs.iter() {
-                                                                                               total_value += htlc.value;
-                                                                                               match &htlc.onion_payload {
-                                                                                                       OnionPayload::Invoice(htlc_payment_data) => {
-                                                                                                               if htlc_payment_data.total_msat != payment_data.total_msat {
-                                                                                                                       log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the HTLCs had inconsistent total values (eg {} and {})",
-                                                                                                                                                                log_bytes!(payment_hash.0), payment_data.total_msat, htlc_payment_data.total_msat);
-                                                                                                                       total_value = msgs::MAX_VALUE_MSAT;
-                                                                                                               }
-                                                                                                               if total_value >= msgs::MAX_VALUE_MSAT { break; }
-                                                                                                       },
-                                                                                                       _ => unreachable!(),
-                                                                                               }
-                                                                                       }
-                                                                                       if total_value >= msgs::MAX_VALUE_MSAT || total_value > payment_data.total_msat {
-                                                                                               log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the total value {} ran over expected value {} (or HTLCs were inconsistent)",
-                                                                                                       log_bytes!(payment_hash.0), total_value, payment_data.total_msat);
-                                                                                               for htlc in htlcs.iter() {
-                                                                                                       fail_htlc!(htlc);
-                                                                                               }
-                                                                                       } else if total_value == payment_data.total_msat {
-                                                                                               new_events.push(events::Event::PaymentReceived {
-                                                                                                       payment_hash,
-                                                                                                       purpose: events::PaymentPurpose::InvoicePayment {
-                                                                                                               payment_preimage: inbound_payment.get().payment_preimage,
-                                                                                                               payment_secret: payment_data.payment_secret,
-                                                                                                       },
-                                                                                                       amt: total_value,
-                                                                                               });
-                                                                                               // Only ever generate at most one PaymentReceived
-                                                                                               // per registered payment_hash, even if it isn't
-                                                                                               // claimed.
+                                                                                       let payment_received_generated = check_total_value!(payment_data.total_msat, payment_data.payment_secret, inbound_payment.get().payment_preimage);
+                                                                                       if payment_received_generated {
                                                                                                inbound_payment.remove_entry();
-                                                                                       } else {
-                                                                                               // Nothing to do - we haven't reached the total
-                                                                                               // payment value yet, wait until we receive more
-                                                                                               // MPP parts.
                                                                                        }
                                                                                }
                                                                        },