]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Make `claimable_payments` map value a struct, rather than a tuple
authorMatt Corallo <git@bluematt.me>
Fri, 7 Apr 2023 20:19:03 +0000 (20:19 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 19 Apr 2023 02:57:19 +0000 (02:57 +0000)
This makes the `claimable_payments` code more upgradable allowing
us to add new fields in the coming commit(s).

lightning/src/ln/channelmanager.rs

index 4ca93418582e1da2d94fbeeff63959c5c3b06e5e..d51209d7e0ec5c88d7bbac48e52b9f81cdab5640 100644 (file)
@@ -470,6 +470,11 @@ impl_writeable_tlv_based!(ClaimingPayment, {
        (4, receiver_node_id, required),
 });
 
+struct ClaimablePayment {
+       purpose: events::PaymentPurpose,
+       htlcs: Vec<ClaimableHTLC>,
+}
+
 /// Information about claimable or being-claimed payments
 struct ClaimablePayments {
        /// Map from payment hash to the payment data and any HTLCs which are to us and can be
@@ -480,7 +485,7 @@ struct ClaimablePayments {
        ///
        /// When adding to the map, [`Self::pending_claiming_payments`] must also be checked to ensure
        /// we don't get a duplicate payment.
-       claimable_htlcs: HashMap<PaymentHash, (events::PaymentPurpose, Vec<ClaimableHTLC>)>,
+       claimable_payments: HashMap<PaymentHash, ClaimablePayment>,
 
        /// Map from payment hash to the payment data for HTLCs which we have begun claiming, but which
        /// are waiting on a [`ChannelMonitorUpdate`] to complete in order to be surfaced to the user
@@ -1668,7 +1673,7 @@ where
                        pending_inbound_payments: Mutex::new(HashMap::new()),
                        pending_outbound_payments: OutboundPayments::new(),
                        forward_htlcs: Mutex::new(HashMap::new()),
-                       claimable_payments: Mutex::new(ClaimablePayments { claimable_htlcs: HashMap::new(), pending_claiming_payments: HashMap::new() }),
+                       claimable_payments: Mutex::new(ClaimablePayments { claimable_payments: HashMap::new(), pending_claiming_payments: HashMap::new() }),
                        pending_intercepted_htlcs: Mutex::new(HashMap::new()),
                        id_to_peer: Mutex::new(HashMap::new()),
                        short_to_chan_info: FairRwLock::new(HashMap::new()),
@@ -3349,8 +3354,13 @@ where
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
                                                                                        continue
                                                                                }
-                                                                               let (_, ref mut htlcs) = claimable_payments.claimable_htlcs.entry(payment_hash)
-                                                                                       .or_insert_with(|| (purpose(), Vec::new()));
+                                                                               let ref mut claimable_payment = claimable_payments.claimable_payments
+                                                                                       .entry(payment_hash)
+                                                                                       // Note that if we insert here we MUST NOT fail_htlc!()
+                                                                                       .or_insert_with(|| ClaimablePayment {
+                                                                                               purpose: purpose(), htlcs: Vec::new()
+                                                                                       });
+                                                                               let ref mut htlcs = &mut claimable_payment.htlcs;
                                                                                if htlcs.len() == 1 {
                                                                                        if let OnionPayload::Spontaneous(_) = htlcs[0].onion_payload {
                                                                                                log_trace!(self.logger, "Failing new HTLC with payment_hash {} as we already had an existing keysend HTLC with the same payment hash", log_bytes!(payment_hash.0));
@@ -3445,13 +3455,16 @@ where
                                                                                                        fail_htlc!(claimable_htlc, payment_hash);
                                                                                                        continue
                                                                                                }
-                                                                                               match claimable_payments.claimable_htlcs.entry(payment_hash) {
+                                                                                               match claimable_payments.claimable_payments.entry(payment_hash) {
                                                                                                        hash_map::Entry::Vacant(e) => {
                                                                                                                let amount_msat = claimable_htlc.value;
                                                                                                                claimable_htlc.total_value_received = Some(amount_msat);
                                                                                                                let claim_deadline = Some(claimable_htlc.cltv_expiry - HTLC_FAIL_BACK_BUFFER);
                                                                                                                let purpose = events::PaymentPurpose::SpontaneousPayment(preimage);
-                                                                                                               e.insert((purpose.clone(), vec![claimable_htlc]));
+                                                                                                               e.insert(ClaimablePayment {
+                                                                                                                       purpose: purpose.clone(),
+                                                                                                                       htlcs: vec![claimable_htlc],
+                                                                                                               });
                                                                                                                let prev_channel_id = prev_funding_outpoint.to_channel_id();
                                                                                                                new_events.push(events::Event::PaymentClaimable {
                                                                                                                        receiver_node_id: Some(receiver_node_id),
@@ -3708,24 +3721,27 @@ where
                                }
                        }
 
-                       self.claimable_payments.lock().unwrap().claimable_htlcs.retain(|payment_hash, (_, htlcs)| {
-                               if htlcs.is_empty() {
+                       self.claimable_payments.lock().unwrap().claimable_payments.retain(|payment_hash, payment| {
+                               if payment.htlcs.is_empty() {
                                        // This should be unreachable
                                        debug_assert!(false);
                                        return false;
                                }
-                               if let OnionPayload::Invoice { .. } = htlcs[0].onion_payload {
+                               if let OnionPayload::Invoice { .. } = payment.htlcs[0].onion_payload {
                                        // Check if we've received all the parts we need for an MPP (the value of the parts adds to total_msat).
                                        // In this case we're not going to handle any timeouts of the parts here.
                                        // This condition determining whether the MPP is complete here must match
                                        // exactly the condition used in `process_pending_htlc_forwards`.
-                                       if htlcs[0].total_msat <= htlcs.iter().fold(0, |total, htlc| total + htlc.sender_intended_value) {
+                                       if payment.htlcs[0].total_msat <= payment.htlcs.iter()
+                                               .fold(0, |total, htlc| total + htlc.sender_intended_value)
+                                       {
                                                return true;
-                                       } else if htlcs.into_iter().any(|htlc| {
+                                       } else if payment.htlcs.iter_mut().any(|htlc| {
                                                htlc.timer_ticks += 1;
                                                return htlc.timer_ticks >= MPP_TIMEOUT_TICKS
                                        }) {
-                                               timed_out_mpp_htlcs.extend(htlcs.drain(..).map(|htlc: ClaimableHTLC| (htlc.prev_hop, *payment_hash)));
+                                               timed_out_mpp_htlcs.extend(payment.htlcs.drain(..)
+                                                       .map(|htlc: ClaimableHTLC| (htlc.prev_hop, *payment_hash)));
                                                return false;
                                        }
                                }
@@ -3780,9 +3796,9 @@ where
        pub fn fail_htlc_backwards_with_reason(&self, payment_hash: &PaymentHash, failure_code: FailureCode) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let removed_source = self.claimable_payments.lock().unwrap().claimable_htlcs.remove(payment_hash);
-               if let Some((_, mut sources)) = removed_source {
-                       for htlc in sources.drain(..) {
+               let removed_source = self.claimable_payments.lock().unwrap().claimable_payments.remove(payment_hash);
+               if let Some(payment) = removed_source {
+                       for htlc in payment.htlcs {
                                let reason = self.get_htlc_fail_reason_from_failure_code(failure_code, &htlc);
                                let source = HTLCSource::PreviousHopData(htlc.prev_hop);
                                let receiver = HTLCDestination::FailedPayment { payment_hash: *payment_hash };
@@ -3959,9 +3975,9 @@ where
 
                let mut sources = {
                        let mut claimable_payments = self.claimable_payments.lock().unwrap();
-                       if let Some((payment_purpose, sources)) = claimable_payments.claimable_htlcs.remove(&payment_hash) {
+                       if let Some(payment) = claimable_payments.claimable_payments.remove(&payment_hash) {
                                let mut receiver_node_id = self.our_network_pubkey;
-                               for htlc in sources.iter() {
+                               for htlc in payment.htlcs.iter() {
                                        if htlc.prev_hop.phantom_shared_secret.is_some() {
                                                let phantom_pubkey = self.node_signer.get_node_id(Recipient::PhantomNode)
                                                        .expect("Failed to get node_id for phantom node recipient");
@@ -3971,15 +3987,15 @@ where
                                }
 
                                let dup_purpose = claimable_payments.pending_claiming_payments.insert(payment_hash,
-                                       ClaimingPayment { amount_msat: sources.iter().map(|source| source.value).sum(),
-                                       payment_purpose, receiver_node_id,
+                                       ClaimingPayment { amount_msat: payment.htlcs.iter().map(|source| source.value).sum(),
+                                       payment_purpose: payment.purpose, receiver_node_id,
                                });
                                if dup_purpose.is_some() {
                                        debug_assert!(false, "Shouldn't get a duplicate pending claim event ever");
                                        log_error!(self.logger, "Got a duplicate pending claimable event on payment hash {}! Please report this bug",
                                                log_bytes!(payment_hash.0));
                                }
-                               sources
+                               payment.htlcs
                        } else { return; }
                };
                debug_assert!(!sources.is_empty());
@@ -6091,8 +6107,8 @@ where
                }
 
                if let Some(height) = height_opt {
-                       self.claimable_payments.lock().unwrap().claimable_htlcs.retain(|payment_hash, (_, htlcs)| {
-                               htlcs.retain(|htlc| {
+                       self.claimable_payments.lock().unwrap().claimable_payments.retain(|payment_hash, payment| {
+                               payment.htlcs.retain(|htlc| {
                                        // If height is approaching the number of blocks we think it takes us to get
                                        // our commitment transaction confirmed before the HTLC expires, plus the
                                        // number of blocks we generally consider it to take to do a commitment update,
@@ -6107,7 +6123,7 @@ where
                                                false
                                        } else { true }
                                });
-                               !htlcs.is_empty() // Only retain this entry if htlcs has at least one entry.
+                               !payment.htlcs.is_empty() // Only retain this entry if htlcs has at least one entry.
                        });
 
                        let mut intercepted_htlcs = self.pending_intercepted_htlcs.lock().unwrap();
@@ -7028,14 +7044,14 @@ where
                let pending_outbound_payments = self.pending_outbound_payments.pending_outbound_payments.lock().unwrap();
 
                let mut htlc_purposes: Vec<&events::PaymentPurpose> = Vec::new();
-               (claimable_payments.claimable_htlcs.len() as u64).write(writer)?;
-               for (payment_hash, (purpose, previous_hops)) in claimable_payments.claimable_htlcs.iter() {
+               (claimable_payments.claimable_payments.len() as u64).write(writer)?;
+               for (payment_hash, payment) in claimable_payments.claimable_payments.iter() {
                        payment_hash.write(writer)?;
-                       (previous_hops.len() as u64).write(writer)?;
-                       for htlc in previous_hops.iter() {
+                       (payment.htlcs.len() as u64).write(writer)?;
+                       for htlc in payment.htlcs.iter() {
                                htlc.write(writer)?;
                        }
-                       htlc_purposes.push(purpose);
+                       htlc_purposes.push(&payment.purpose);
                }
 
                let mut monitor_update_blocked_actions_per_peer = None;
@@ -7688,22 +7704,25 @@ where
                let inbound_pmt_key_material = args.node_signer.get_inbound_payment_key_material();
                let expanded_inbound_key = inbound_payment::ExpandedKey::new(&inbound_pmt_key_material);
 
-               let mut claimable_htlcs = HashMap::with_capacity(claimable_htlcs_list.len());
+               let mut claimable_payments = HashMap::with_capacity(claimable_htlcs_list.len());
                if let Some(mut purposes) = claimable_htlc_purposes {
                        if purposes.len() != claimable_htlcs_list.len() {
                                return Err(DecodeError::InvalidValue);
                        }
-                       for (purpose, (payment_hash, previous_hops)) in purposes.drain(..).zip(claimable_htlcs_list.drain(..)) {
-                               claimable_htlcs.insert(payment_hash, (purpose, previous_hops));
+                       for (purpose, (payment_hash, htlcs)) in purposes.drain(..).zip(claimable_htlcs_list.drain(..)) {
+                               let existing_payment = claimable_payments.insert(payment_hash, ClaimablePayment {
+                                       purpose, htlcs,
+                               });
+                               if existing_payment.is_some() { return Err(DecodeError::InvalidValue); }
                        }
                } else {
                        // LDK versions prior to 0.0.107 did not write a `pending_htlc_purposes`, but do
                        // include a `_legacy_hop_data` in the `OnionPayload`.
-                       for (payment_hash, previous_hops) in claimable_htlcs_list.drain(..) {
-                               if previous_hops.is_empty() {
+                       for (payment_hash, htlcs) in claimable_htlcs_list.drain(..) {
+                               if htlcs.is_empty() {
                                        return Err(DecodeError::InvalidValue);
                                }
-                               let purpose = match &previous_hops[0].onion_payload {
+                               let purpose = match &htlcs[0].onion_payload {
                                        OnionPayload::Invoice { _legacy_hop_data } => {
                                                if let Some(hop_data) = _legacy_hop_data {
                                                        events::PaymentPurpose::InvoicePayment {
@@ -7724,7 +7743,9 @@ where
                                        OnionPayload::Spontaneous(payment_preimage) =>
                                                events::PaymentPurpose::SpontaneousPayment(*payment_preimage),
                                };
-                               claimable_htlcs.insert(payment_hash, (purpose, previous_hops));
+                               claimable_payments.insert(payment_hash, ClaimablePayment {
+                                       purpose, htlcs,
+                               });
                        }
                }
 
@@ -7776,17 +7797,17 @@ where
 
                for (_, monitor) in args.channel_monitors.iter() {
                        for (payment_hash, payment_preimage) in monitor.get_stored_preimages() {
-                               if let Some((payment_purpose, claimable_htlcs)) = claimable_htlcs.remove(&payment_hash) {
+                               if let Some(payment) = claimable_payments.remove(&payment_hash) {
                                        log_info!(args.logger, "Re-claiming HTLCs with payment hash {} as we've released the preimage to a ChannelMonitor!", log_bytes!(payment_hash.0));
                                        let mut claimable_amt_msat = 0;
                                        let mut receiver_node_id = Some(our_network_pubkey);
-                                       let phantom_shared_secret = claimable_htlcs[0].prev_hop.phantom_shared_secret;
+                                       let phantom_shared_secret = payment.htlcs[0].prev_hop.phantom_shared_secret;
                                        if phantom_shared_secret.is_some() {
                                                let phantom_pubkey = args.node_signer.get_node_id(Recipient::PhantomNode)
                                                        .expect("Failed to get node_id for phantom node recipient");
                                                receiver_node_id = Some(phantom_pubkey)
                                        }
-                                       for claimable_htlc in claimable_htlcs {
+                                       for claimable_htlc in payment.htlcs {
                                                claimable_amt_msat += claimable_htlc.value;
 
                                                // Add a holding-cell claim of the payment to the Channel, which should be
@@ -7820,7 +7841,7 @@ where
                                        pending_events_read.push(events::Event::PaymentClaimed {
                                                receiver_node_id,
                                                payment_hash,
-                                               purpose: payment_purpose,
+                                               purpose: payment.purpose,
                                                amount_msat: claimable_amt_msat,
                                        });
                                }
@@ -7851,7 +7872,7 @@ where
                        pending_intercepted_htlcs: Mutex::new(pending_intercepted_htlcs.unwrap()),
 
                        forward_htlcs: Mutex::new(forward_htlcs),
-                       claimable_payments: Mutex::new(ClaimablePayments { claimable_htlcs, pending_claiming_payments: pending_claiming_payments.unwrap() }),
+                       claimable_payments: Mutex::new(ClaimablePayments { claimable_payments, pending_claiming_payments: pending_claiming_payments.unwrap() }),
                        outbound_scid_aliases: Mutex::new(outbound_scid_aliases),
                        id_to_peer: Mutex::new(id_to_peer),
                        short_to_chan_info: FairRwLock::new(short_to_chan_info),