Track payments after they resolve until all HTLCs are finalized
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 019d61c57c09a2d3e8f0606a3353a49198cc5425..187a95b5300c6299250ea00bc8f84ef08acd26a7 100644 (file)
@@ -416,19 +416,51 @@ pub(crate) enum PendingOutboundPayment {
                /// Our best known block height at the time this payment was initiated.
                starting_block_height: u32,
        },
+       /// When a pending payment is fulfilled, we continue tracking it until all pending HTLCs have
+       /// been resolved. This ensures we don't look up pending payments in ChannelMonitors on restart
+       /// and add a pending payment that was already fulfilled.
+       Fulfilled {
+               session_privs: HashSet<[u8; 32]>,
+       },
 }
 
 impl PendingOutboundPayment {
-       fn remove(&mut self, session_priv: &[u8; 32], part_amt_msat: u64) -> bool {
+       fn is_retryable(&self) -> bool {
+               match self {
+                       PendingOutboundPayment::Retryable { .. } => true,
+                       _ => false,
+               }
+       }
+       fn is_fulfilled(&self) -> bool {
+               match self {
+                       PendingOutboundPayment::Fulfilled { .. } => true,
+                       _ => false,
+               }
+       }
+
+       fn mark_fulfilled(&mut self) {
+               let mut session_privs = HashSet::new();
+               core::mem::swap(&mut session_privs, match self {
+                       PendingOutboundPayment::Legacy { session_privs } |
+                       PendingOutboundPayment::Retryable { session_privs, .. } |
+                       PendingOutboundPayment::Fulfilled { session_privs }
+                               => session_privs
+               });
+               *self = PendingOutboundPayment::Fulfilled { session_privs };
+       }
+
+       /// panics if part_amt_msat is None and !self.is_fulfilled
+       fn remove(&mut self, session_priv: &[u8; 32], part_amt_msat: Option<u64>) -> bool {
                let remove_res = match self {
                        PendingOutboundPayment::Legacy { session_privs } |
-                       PendingOutboundPayment::Retryable { session_privs, .. } => {
+                       PendingOutboundPayment::Retryable { session_privs, .. } |
+                       PendingOutboundPayment::Fulfilled { session_privs } => {
                                session_privs.remove(session_priv)
                        }
                };
                if remove_res {
                        if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, .. } = self {
-                               *pending_amt_msat -= part_amt_msat;
+                               *pending_amt_msat -= part_amt_msat.expect("We must only not provide an amount if the payment was already fulfilled");
                        }
                }
                remove_res
@@ -440,6 +472,7 @@ impl PendingOutboundPayment {
                        PendingOutboundPayment::Retryable { session_privs, .. } => {
                                session_privs.insert(session_priv)
                        }
+                       PendingOutboundPayment::Fulfilled { .. } => false
                };
                if insert_res {
                        if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, .. } = self {
@@ -452,7 +485,8 @@ impl PendingOutboundPayment {
        fn remaining_parts(&self) -> usize {
                match self {
                        PendingOutboundPayment::Legacy { session_privs } |
-                       PendingOutboundPayment::Retryable { session_privs, .. } => {
+                       PendingOutboundPayment::Retryable { session_privs, .. } |
+                       PendingOutboundPayment::Fulfilled { session_privs } => {
                                session_privs.len()
                        }
                }
@@ -1983,6 +2017,17 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                let err: Result<(), _> = loop {
                        let mut channel_lock = self.channel_state.lock().unwrap();
+
+                       let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
+                       let payment_entry = pending_outbounds.entry(payment_id);
+                       if let hash_map::Entry::Occupied(payment) = &payment_entry {
+                               if !payment.get().is_retryable() {
+                                       return Err(APIError::RouteError {
+                                               err: "Payment already completed"
+                                       });
+                               }
+                       }
+
                        let id = match channel_lock.short_to_id.get(&path.first().unwrap().short_channel_id) {
                                None => return Err(APIError::ChannelUnavailable{err: "No channel available with first hop!".to_owned()}),
                                Some(id) => id.clone(),
@@ -2006,8 +2051,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                }, onion_packet, &self.logger),
                                        channel_state, chan);
 
-                                       let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
-                                       let payment = pending_outbounds.entry(payment_id).or_insert_with(|| PendingOutboundPayment::Retryable {
+                                       let payment = payment_entry.or_insert_with(|| PendingOutboundPayment::Retryable {
                                                session_privs: HashSet::new(),
                                                pending_amt_msat: 0,
                                                payment_hash: *payment_hash,
@@ -2203,7 +2247,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                return Err(PaymentSendFailure::ParameterError(APIError::APIMisuseError {
                                                        err: "Unable to retry payments that were initially sent on LDK versions prior to 0.0.102".to_string()
                                                }))
-                                       }
+                                       },
+                                       PendingOutboundPayment::Fulfilled { .. } => {
+                                               return Err(PaymentSendFailure::ParameterError(APIError::RouteError {
+                                                       err: "Payment already completed"
+                                               }));
+                                       },
                                }
                        } else {
                                return Err(PaymentSendFailure::ParameterError(APIError::APIMisuseError {
@@ -3031,7 +3080,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        session_priv_bytes.copy_from_slice(&session_priv[..]);
                                        let mut outbounds = self.pending_outbound_payments.lock().unwrap();
                                        if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(payment_id) {
-                                               if payment.get_mut().remove(&session_priv_bytes, path.last().unwrap().fee_msat) {
+                                               if payment.get_mut().remove(&session_priv_bytes, Some(path.last().unwrap().fee_msat)) &&
+                                                       !payment.get().is_fulfilled()
+                                               {
                                                        self.pending_events.lock().unwrap().push(
                                                                events::Event::PaymentPathFailed {
                                                                        payment_hash,
@@ -3077,10 +3128,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
                                let mut all_paths_failed = false;
                                if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(payment_id) {
-                                       if !sessions.get_mut().remove(&session_priv_bytes, path.last().unwrap().fee_msat) {
+                                       if !sessions.get_mut().remove(&session_priv_bytes, Some(path.last().unwrap().fee_msat)) {
                                                log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                                return;
                                        }
+                                       if sessions.get().is_fulfilled() {
+                                               log_trace!(self.logger, "Received failure of HTLC with payment_hash {} after payment completion", log_bytes!(payment_hash.0));
+                                               return;
+                                       }
                                        if sessions.get().remaining_parts() == 0 {
                                                all_paths_failed = true;
                                        }
@@ -3329,6 +3384,23 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                } else { unreachable!(); }
        }
 
+       fn finalize_claims(&self, mut sources: Vec<HTLCSource>) {
+               for source in sources.drain(..) {
+                       if let HTLCSource::OutboundRoute { session_priv, payment_id, .. } = source {
+                               let mut session_priv_bytes = [0; 32];
+                               session_priv_bytes.copy_from_slice(&session_priv[..]);
+                               let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                               if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(payment_id) {
+                                       assert!(sessions.get().is_fulfilled());
+                                       sessions.get_mut().remove(&session_priv_bytes, None);
+                                       if sessions.get().remaining_parts() == 0 {
+                                               sessions.remove();
+                                       }
+                               }
+                       }
+               }
+       }
+
        fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
                match source {
                        HTLCSource::OutboundRoute { session_priv, payment_id, path, .. } => {
@@ -3336,8 +3408,22 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                let mut session_priv_bytes = [0; 32];
                                session_priv_bytes.copy_from_slice(&session_priv[..]);
                                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
-                               let found_payment = if let Some(mut sessions) = outbounds.remove(&payment_id) {
-                                       sessions.remove(&session_priv_bytes, path.last().unwrap().fee_msat)
+                               let found_payment = if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(payment_id) {
+                                       let found_payment = !sessions.get().is_fulfilled();
+                                       sessions.get_mut().mark_fulfilled();
+                                       if from_onchain {
+                                               // We currently immediately remove HTLCs which were fulfilled on-chain.
+                                               // This could potentially lead to removing a pending payment too early,
+                                               // with a reorg of one block causing us to re-add the fulfilled payment on
+                                               // restart.
+                                               // TODO: We should have a second monitor event that informs us of payments
+                                               // irrevocably fulfilled.
+                                               sessions.get_mut().remove(&session_priv_bytes, Some(path.last().unwrap().fee_msat));
+                                               if sessions.get().remaining_parts() == 0 {
+                                                       sessions.remove();
+                                               }
+                                       }
+                                       found_payment
                                } else { false };
                                if found_payment {
                                        let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner());
@@ -3412,7 +3498,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
                let chan_restoration_res;
-               let mut pending_failures = {
+               let (mut pending_failures, finalized_claims) = {
                        let mut channel_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_lock;
                        let mut channel = match channel_state.by_id.entry(funding_txo.to_channel_id()) {
@@ -3434,14 +3520,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        msg: self.get_channel_update_for_unicast(channel.get()).unwrap(),
                                })
                        } else { None };
-                       // TODO: Handle updates.finalized_claimed_htlcs!
                        chan_restoration_res = handle_chan_restoration_locked!(self, channel_lock, channel_state, channel, updates.raa, updates.commitment_update, updates.order, None, updates.accepted_htlcs, updates.funding_broadcastable, updates.funding_locked);
                        if let Some(upd) = channel_update {
                                channel_state.pending_msg_events.push(upd);
                        }
-                       updates.failed_htlcs
+                       (updates.failed_htlcs, updates.finalized_claimed_htlcs)
                };
                post_handle_chan_restoration!(self, chan_restoration_res);
+               self.finalize_claims(finalized_claims);
                for failure in pending_failures.drain(..) {
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2);
                }
@@ -3962,6 +4048,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                });
                                        }
                                        break Ok((raa_updates.accepted_htlcs, raa_updates.failed_htlcs,
+                                                       raa_updates.finalized_claimed_htlcs,
                                                        chan.get().get_short_channel_id()
                                                                .expect("RAA should only work on a short-id-available channel"),
                                                        chan.get().get_funding_txo().unwrap()))
@@ -3971,11 +4058,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                };
                self.fail_holding_cell_htlcs(htlcs_to_fail, msg.channel_id);
                match res {
-                       Ok((pending_forwards, mut pending_failures, short_channel_id, channel_outpoint)) => {
+                       Ok((pending_forwards, mut pending_failures, finalized_claim_htlcs,
+                               short_channel_id, channel_outpoint)) =>
+                       {
                                for failure in pending_failures.drain(..) {
                                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2);
                                }
                                self.forward_htlcs(&mut [(short_channel_id, channel_outpoint, pending_forwards)]);
+                               self.finalize_claims(finalized_claim_htlcs);
                                Ok(())
                        },
                        Err(e) => Err(e)
@@ -5338,10 +5428,13 @@ impl_writeable_tlv_based!(PendingInboundPayment, {
        (8, min_value_msat, required),
 });
 
-impl_writeable_tlv_based_enum!(PendingOutboundPayment,
+impl_writeable_tlv_based_enum_upgradable!(PendingOutboundPayment,
        (0, Legacy) => {
                (0, session_privs, required),
        },
+       (1, Fulfilled) => {
+               (0, session_privs, required),
+       },
        (2, Retryable) => {
                (0, session_privs, required),
                (2, payment_hash, required),
@@ -5350,7 +5443,7 @@ impl_writeable_tlv_based_enum!(PendingOutboundPayment,
                (8, pending_amt_msat, required),
                (10, starting_block_height, required),
        },
-;);
+);
 
 impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable for ChannelManager<Signer, M, T, K, F, L>
        where M::Target: chain::Watch<Signer>,
@@ -5443,7 +5536,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                // For backwards compat, write the session privs and their total length.
                let mut num_pending_outbounds_compat: u64 = 0;
                for (_, outbound) in pending_outbound_payments.iter() {
-                       num_pending_outbounds_compat += outbound.remaining_parts() as u64;
+                       if !outbound.is_fulfilled() {
+                               num_pending_outbounds_compat += outbound.remaining_parts() as u64;
+                       }
                }
                num_pending_outbounds_compat.write(writer)?;
                for (_, outbound) in pending_outbound_payments.iter() {
@@ -5454,6 +5549,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                                                session_priv.write(writer)?;
                                        }
                                }
+                               PendingOutboundPayment::Fulfilled { .. } => {},
                        }
                }
 
@@ -5464,7 +5560,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                                PendingOutboundPayment::Legacy { session_privs } |
                                PendingOutboundPayment::Retryable { session_privs, .. } => {
                                        pending_outbound_payments_no_retry.insert(*id, session_privs.clone());
-                               }
+                               },
+                               _ => {},
                        }
                }
                write_tlv_fields!(writer, {