Add MPP ID to pending_outbound_htlcs
authorValentine Wallace <vwallace@protonmail.com>
Thu, 19 Aug 2021 23:56:53 +0000 (19:56 -0400)
committerValentine Wallace <vwallace@protonmail.com>
Fri, 17 Sep 2021 19:36:21 +0000 (15:36 -0400)
We'll use this to correlate MPP shards in upcoming commits

lightning/src/ln/channelmanager.rs

index 583ebfb9b73b82352c55a8f775c7d7c3920a30eb..9b0f78b22bc013b6f7c66097279a390091542ad4 100644 (file)
@@ -491,8 +491,11 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// which may generate a claim event, we may receive similar duplicate claim/fail MonitorEvents
        /// after reloading from disk while replaying blocks against ChannelMonitors.
        ///
+       /// Each payment has each of its MPP part's session_priv bytes in the HashSet of the map (even
+       /// payments over a single path).
+       ///
        /// Locked *after* channel_state.
-       pending_outbound_payments: Mutex<HashSet<[u8; 32]>>,
+       pending_outbound_payments: Mutex<HashMap<MppId, HashSet<[u8; 32]>>>,
 
        our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
@@ -1156,7 +1159,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(HashMap::new()),
-                       pending_outbound_payments: Mutex::new(HashSet::new()),
+                       pending_outbound_payments: Mutex::new(HashMap::new()),
 
                        our_network_key: keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &keys_manager.get_node_secret()),
@@ -1853,7 +1856,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash);
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-               assert!(self.pending_outbound_payments.lock().unwrap().insert(session_priv_bytes));
+               let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
+               let sessions = pending_outbounds.entry(mpp_id).or_insert(HashSet::new());
+               assert!(sessions.insert(session_priv_bytes));
 
                let err: Result<(), _> = loop {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -2832,23 +2837,27 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        self.fail_htlc_backwards_internal(channel_state,
                                                htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data});
                                },
-                               HTLCSource::OutboundRoute { session_priv, .. } => {
-                                       if {
-                                               let mut session_priv_bytes = [0; 32];
-                                               session_priv_bytes.copy_from_slice(&session_priv[..]);
-                                               self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
-                                       } {
-                                               self.pending_events.lock().unwrap().push(
-                                                       events::Event::PaymentFailed {
-                                                               payment_hash,
-                                                               rejected_by_dest: false,
-                                                               network_update: None,
-#[cfg(test)]
-                                                               error_code: None,
-#[cfg(test)]
-                                                               error_data: None,
+                               HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
+                                       let mut session_priv_bytes = [0; 32];
+                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
+                                       let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                                       if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(mpp_id) {
+                                               if sessions.get_mut().remove(&session_priv_bytes) {
+                                                       self.pending_events.lock().unwrap().push(
+                                                               events::Event::PaymentFailed {
+                                                                       payment_hash,
+                                                                       rejected_by_dest: false,
+                                                                       network_update: None,
+                                                                       #[cfg(test)]
+                                                                       error_code: None,
+                                                                       #[cfg(test)]
+                                                                       error_data: None,
+                                                               }
+                                                       );
+                                                       if sessions.get().len() == 0 {
+                                                               sessions.remove();
                                                        }
-                                               )
+                                               }
                                        } else {
                                                log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        }
@@ -2873,12 +2882,19 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // from block_connected which may run during initialization prior to the chain_monitor
                // being fully configured. See the docs for `ChannelManagerReadArgs` for more.
                match source {
-                       HTLCSource::OutboundRoute { ref path, session_priv, .. } => {
-                               if {
-                                       let mut session_priv_bytes = [0; 32];
-                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
-                                       !self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
-                               } {
+                       HTLCSource::OutboundRoute { ref path, session_priv, mpp_id, .. } => {
+                               let mut session_priv_bytes = [0; 32];
+                               session_priv_bytes.copy_from_slice(&session_priv[..]);
+                               let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                               if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(mpp_id) {
+                                       if !sessions.get_mut().remove(&session_priv_bytes) {
+                                               log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
+                                               return;
+                                       }
+                                       if sessions.get().len() == 0 {
+                                               sessions.remove();
+                                       }
+                               } else {
                                        log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        return;
                                }
@@ -3119,17 +3135,22 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
                match source {
-                       HTLCSource::OutboundRoute { session_priv, .. } => {
+                       HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
                                mem::drop(channel_state_lock);
-                               if {
-                                       let mut session_priv_bytes = [0; 32];
-                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
-                                       self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
-                               } {
-                                       let mut pending_events = self.pending_events.lock().unwrap();
-                                       pending_events.push(events::Event::PaymentSent {
-                                               payment_preimage
-                                       });
+                               let mut session_priv_bytes = [0; 32];
+                               session_priv_bytes.copy_from_slice(&session_priv[..]);
+                               let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                               if let Some(sessions) = outbounds.get_mut(&mpp_id) {
+                                       if sessions.remove(&session_priv_bytes) {
+                                               self.pending_events.lock().unwrap().push(
+                                                       events::Event::PaymentSent { payment_preimage }
+                                               );
+                                               if sessions.len() == 0 {
+                                                       outbounds.remove(&mpp_id);
+                                               }
+                                       } else {
+                                               log_trace!(self.logger, "Received duplicative fulfill for HTLC with payment_preimage {}", log_bytes!(payment_preimage.0));
+                                       }
                                } else {
                                        log_trace!(self.logger, "Received duplicative fulfill for HTLC with payment_preimage {}", log_bytes!(payment_preimage.0));
                                }
@@ -5105,12 +5126,21 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                }
 
                let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
-               (pending_outbound_payments.len() as u64).write(writer)?;
-               for session_priv in pending_outbound_payments.iter() {
-                       session_priv.write(writer)?;
+               // For backwards compat, write the session privs and their total length.
+               let mut num_pending_outbounds_compat: u64 = 0;
+               for (_, outbounds) in pending_outbound_payments.iter() {
+                       num_pending_outbounds_compat += outbounds.len() as u64;
+               }
+               num_pending_outbounds_compat.write(writer)?;
+               for (_, outbounds) in pending_outbound_payments.iter() {
+                       for outbound in outbounds.iter() {
+                               outbound.write(writer)?;
+                       }
                }
 
-               write_tlv_fields!(writer, {});
+               write_tlv_fields!(writer, {
+                       (1, pending_outbound_payments, required),
+               });
 
                Ok(())
        }
@@ -5363,15 +5393,23 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        }
                }
 
-               let pending_outbound_payments_count: u64 = Readable::read(reader)?;
-               let mut pending_outbound_payments: HashSet<[u8; 32]> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
-               for _ in 0..pending_outbound_payments_count {
-                       if !pending_outbound_payments.insert(Readable::read(reader)?) {
-                               return Err(DecodeError::InvalidValue);
-                       }
+               let pending_outbound_payments_count_compat: u64 = Readable::read(reader)?;
+               let mut pending_outbound_payments_compat: HashMap<MppId, HashSet<[u8; 32]>> =
+                       HashMap::with_capacity(cmp::min(pending_outbound_payments_count_compat as usize, MAX_ALLOC_SIZE/32));
+               for _ in 0..pending_outbound_payments_count_compat {
+                       let session_priv = Readable::read(reader)?;
+                       if pending_outbound_payments_compat.insert(MppId(session_priv), [session_priv].iter().cloned().collect()).is_some() {
+                               return Err(DecodeError::InvalidValue)
+                       };
                }
 
-               read_tlv_fields!(reader, {});
+               let mut pending_outbound_payments = None;
+               read_tlv_fields!(reader, {
+                       (1, pending_outbound_payments, option),
+               });
+               if pending_outbound_payments.is_none() {
+                       pending_outbound_payments = Some(pending_outbound_payments_compat);
+               }
 
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());
@@ -5392,7 +5430,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(pending_inbound_payments),
-                       pending_outbound_payments: Mutex::new(pending_outbound_payments),
+                       pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
 
                        our_network_key: args.keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &args.keys_manager.get_node_secret()),