channelmanager: Add retry data to pending_outbound_payments
authorValentine Wallace <vwallace@protonmail.com>
Fri, 24 Sep 2021 20:02:11 +0000 (16:02 -0400)
committerValentine Wallace <vwallace@protonmail.com>
Tue, 28 Sep 2021 23:39:37 +0000 (19:39 -0400)
lightning/src/ln/channelmanager.rs

index ebfe0d2bc80361c3b9344128f79e29ef5e4603fe..e04217a9f29ca58860d0650e39babe35388d604f 100644 (file)
@@ -400,6 +400,63 @@ struct PendingInboundPayment {
        min_value_msat: Option<u64>,
 }
 
+/// Stores the session_priv for each part of a payment that is still pending. For versions 0.0.102
+/// and later, also stores information for retrying the payment.
+enum PendingOutboundPayment {
+       Legacy {
+               session_privs: HashSet<[u8; 32]>,
+       },
+       Retryable {
+               session_privs: HashSet<[u8; 32]>,
+               payment_hash: PaymentHash,
+               payment_secret: Option<PaymentSecret>,
+               pending_amt_msat: u64,
+               /// The total payment amount across all paths, used to verify that a retry is not overpaying.
+               total_msat: u64,
+       },
+}
+
+impl PendingOutboundPayment {
+       fn remove(&mut self, session_priv: &[u8; 32], part_amt_msat: u64) -> bool {
+               let remove_res = match self {
+                       PendingOutboundPayment::Legacy { session_privs } |
+                       PendingOutboundPayment::Retryable { session_privs, .. } => {
+                               session_privs.remove(session_priv)
+                       }
+               };
+               if remove_res {
+                       if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, .. } = self {
+                               *pending_amt_msat -= part_amt_msat;
+                       }
+               }
+               remove_res
+       }
+
+       fn insert(&mut self, session_priv: [u8; 32], part_amt_msat: u64) -> bool {
+               let insert_res = match self {
+                       PendingOutboundPayment::Legacy { session_privs } |
+                       PendingOutboundPayment::Retryable { session_privs, .. } => {
+                               session_privs.insert(session_priv)
+                       }
+               };
+               if insert_res {
+                       if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, .. } = self {
+                               *pending_amt_msat += part_amt_msat;
+                       }
+               }
+               insert_res
+       }
+
+       fn remaining_parts(&self) -> usize {
+               match self {
+                       PendingOutboundPayment::Legacy { session_privs } |
+                       PendingOutboundPayment::Retryable { session_privs, .. } => {
+                               session_privs.len()
+                       }
+               }
+       }
+}
+
 /// SimpleArcChannelManager is useful when you need a ChannelManager with a static lifetime, e.g.
 /// when you're using lightning-net-tokio (since tokio::spawn requires parameters with static
 /// lifetimes). Other times you can afford a reference, which is more efficient, in which case
@@ -486,7 +543,7 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// Locked *after* channel_state.
        pending_inbound_payments: Mutex<HashMap<PaymentHash, PendingInboundPayment>>,
 
-       /// The session_priv bytes of outbound payments which are pending resolution.
+       /// The session_priv bytes and retry metadata of outbound payments which are pending resolution.
        /// The authoritative state of these HTLCs resides either within Channels or ChannelMonitors
        /// (if the channel has been force-closed), however we track them here to prevent duplicative
        /// PaymentSent/PaymentPathFailed events. Specifically, in the case of a duplicative
@@ -495,11 +552,10 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// which may generate a claim event, we may receive similar duplicate claim/fail MonitorEvents
        /// after reloading from disk while replaying blocks against ChannelMonitors.
        ///
-       /// Each payment has each of its MPP part's session_priv bytes in the HashSet of the map (even
-       /// payments over a single path).
+       /// See `PendingOutboundPayment` documentation for more info.
        ///
        /// Locked *after* channel_state.
-       pending_outbound_payments: Mutex<HashMap<PaymentId, HashSet<[u8; 32]>>>,
+       pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
 
        our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
@@ -1894,8 +1950,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
-               let sessions = pending_outbounds.entry(payment_id).or_insert(HashSet::new());
-               assert!(sessions.insert(session_priv_bytes));
+               let payment = pending_outbounds.entry(payment_id).or_insert_with(|| PendingOutboundPayment::Retryable {
+                       session_privs: HashSet::new(),
+                       pending_amt_msat: 0,
+                       payment_hash: *payment_hash,
+                       payment_secret: *payment_secret,
+                       total_msat: total_value,
+               });
+               assert!(payment.insert(session_priv_bytes, path.last().unwrap().fee_msat));
 
                let err: Result<(), _> = loop {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -2883,14 +2945,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        let mut session_priv_bytes = [0; 32];
                                        session_priv_bytes.copy_from_slice(&session_priv[..]);
                                        let mut outbounds = self.pending_outbound_payments.lock().unwrap();
-                                       if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(payment_id) {
-                                               if sessions.get_mut().remove(&session_priv_bytes) {
+                                       if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(payment_id) {
+                                               if payment.get_mut().remove(&session_priv_bytes, path.last().unwrap().fee_msat) {
                                                        self.pending_events.lock().unwrap().push(
                                                                events::Event::PaymentPathFailed {
                                                                        payment_hash,
                                                                        rejected_by_dest: false,
                                                                        network_update: None,
-                                                                       all_paths_failed: sessions.get().len() == 0,
+                                                                       all_paths_failed: payment.get().remaining_parts() == 0,
                                                                        path: path.clone(),
                                                                        #[cfg(test)]
                                                                        error_code: None,
@@ -2898,8 +2960,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                        error_data: None,
                                                                }
                                                        );
-                                                       if sessions.get().len() == 0 {
-                                                               sessions.remove();
+                                                       if payment.get().remaining_parts() == 0 {
+                                                               payment.remove();
                                                        }
                                                }
                                        } else {
@@ -2932,11 +2994,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
                                let mut all_paths_failed = false;
                                if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(payment_id) {
-                                       if !sessions.get_mut().remove(&session_priv_bytes) {
+                                       if !sessions.get_mut().remove(&session_priv_bytes, path.last().unwrap().fee_msat) {
                                                log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                                return;
                                        }
-                                       if sessions.get().len() == 0 {
+                                       if sessions.get().remaining_parts() == 0 {
                                                all_paths_failed = true;
                                                sessions.remove();
                                        }
@@ -3185,13 +3247,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
                match source {
-                       HTLCSource::OutboundRoute { session_priv, payment_id, .. } => {
+                       HTLCSource::OutboundRoute { session_priv, payment_id, path, .. } => {
                                mem::drop(channel_state_lock);
                                let mut session_priv_bytes = [0; 32];
                                session_priv_bytes.copy_from_slice(&session_priv[..]);
                                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
                                let found_payment = if let Some(mut sessions) = outbounds.remove(&payment_id) {
-                                       sessions.remove(&session_priv_bytes)
+                                       sessions.remove(&session_priv_bytes, path.last().unwrap().fee_msat)
                                } else { false };
                                if found_payment {
                                        self.pending_events.lock().unwrap().push(
@@ -5161,6 +5223,19 @@ impl_writeable_tlv_based!(PendingInboundPayment, {
        (8, min_value_msat, required),
 });
 
+impl_writeable_tlv_based_enum!(PendingOutboundPayment,
+       (0, Legacy) => {
+               (0, session_privs, required),
+       },
+       (2, Retryable) => {
+               (0, session_privs, required),
+               (2, payment_hash, required),
+               (4, payment_secret, option),
+               (6, total_msat, required),
+               (8, pending_amt_msat, required),
+       },
+;);
+
 impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable for ChannelManager<Signer, M, T, K, F, L>
        where M::Target: chain::Watch<Signer>,
         T::Target: BroadcasterInterface,
@@ -5251,18 +5326,34 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
                // For backwards compat, write the session privs and their total length.
                let mut num_pending_outbounds_compat: u64 = 0;
-               for (_, outbounds) in pending_outbound_payments.iter() {
-                       num_pending_outbounds_compat += outbounds.len() as u64;
+               for (_, outbound) in pending_outbound_payments.iter() {
+                       num_pending_outbounds_compat += outbound.remaining_parts() as u64;
                }
                num_pending_outbounds_compat.write(writer)?;
-               for (_, outbounds) in pending_outbound_payments.iter() {
-                       for outbound in outbounds.iter() {
-                               outbound.write(writer)?;
+               for (_, outbound) in pending_outbound_payments.iter() {
+                       match outbound {
+                               PendingOutboundPayment::Legacy { session_privs } |
+                               PendingOutboundPayment::Retryable { session_privs, .. } => {
+                                       for session_priv in session_privs.iter() {
+                                               session_priv.write(writer)?;
+                                       }
+                               }
                        }
                }
 
+               // Encode without retry info for 0.0.101 compatibility.
+               let mut pending_outbound_payments_no_retry: HashMap<PaymentId, HashSet<[u8; 32]>> = HashMap::new();
+               for (id, outbound) in pending_outbound_payments.iter() {
+                       match outbound {
+                               PendingOutboundPayment::Legacy { session_privs } |
+                               PendingOutboundPayment::Retryable { session_privs, .. } => {
+                                       pending_outbound_payments_no_retry.insert(*id, session_privs.clone());
+                               }
+                       }
+               }
                write_tlv_fields!(writer, {
-                       (1, pending_outbound_payments, required),
+                       (1, pending_outbound_payments_no_retry, required),
+                       (3, pending_outbound_payments, required),
                });
 
                Ok(())
@@ -5522,21 +5613,33 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                }
 
                let pending_outbound_payments_count_compat: u64 = Readable::read(reader)?;
-               let mut pending_outbound_payments_compat: HashMap<PaymentId, HashSet<[u8; 32]>> =
+               let mut pending_outbound_payments_compat: HashMap<PaymentId, PendingOutboundPayment> =
                        HashMap::with_capacity(cmp::min(pending_outbound_payments_count_compat as usize, MAX_ALLOC_SIZE/32));
                for _ in 0..pending_outbound_payments_count_compat {
                        let session_priv = Readable::read(reader)?;
-                       if pending_outbound_payments_compat.insert(PaymentId(session_priv), [session_priv].iter().cloned().collect()).is_some() {
+                       let payment = PendingOutboundPayment::Legacy {
+                               session_privs: [session_priv].iter().cloned().collect()
+                       };
+                       if pending_outbound_payments_compat.insert(PaymentId(session_priv), payment).is_some() {
                                return Err(DecodeError::InvalidValue)
                        };
                }
 
+               // pending_outbound_payments_no_retry is for compatibility with 0.0.101 clients.
+               let mut pending_outbound_payments_no_retry: Option<HashMap<PaymentId, HashSet<[u8; 32]>>> = None;
                let mut pending_outbound_payments = None;
                read_tlv_fields!(reader, {
-                       (1, pending_outbound_payments, option),
+                       (1, pending_outbound_payments_no_retry, option),
+                       (3, pending_outbound_payments, option),
                });
-               if pending_outbound_payments.is_none() {
+               if pending_outbound_payments.is_none() && pending_outbound_payments_no_retry.is_none() {
                        pending_outbound_payments = Some(pending_outbound_payments_compat);
+               } else if pending_outbound_payments.is_none() {
+                       let mut outbounds = HashMap::new();
+                       for (id, session_privs) in pending_outbound_payments_no_retry.unwrap().drain() {
+                               outbounds.insert(id, PendingOutboundPayment::Legacy { session_privs });
+                       }
+                       pending_outbound_payments = Some(outbounds);
                }
 
                let mut secp_ctx = Secp256k1::new();