Ensure payments don't ever duplicatively fail/succeed on reload
diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs
index 1b16fcb2b54fbede50b7cba4f9c6940afe59a03a..d44091bed3617d6c513e4eed445bf884487239e3 100644
--- a/lightning/src/ln/channelmanager.rs
+++ b/lightning/src/ln/channelmanager.rs
@@ -142,7 +142,7 @@ pub(super) enum HTLCForwardInfo {
 }
 
 /// Tracks the inbound corresponding to an outbound HTLC
-#[derive(Clone, PartialEq)]
+#[derive(Clone, PartialEq, Eq)]
 pub(crate) struct HTLCPreviousHopData {
        short_channel_id: u64,
        htlc_id: u64,
@@ -164,7 +164,7 @@ struct ClaimableHTLC {
 }
 
 /// Tracks the inbound corresponding to an outbound HTLC
-#[derive(Clone, PartialEq)]
+#[derive(Clone, Eq)]
 pub(crate) enum HTLCSource {
        PreviousHopData(HTLCPreviousHopData),
        OutboundRoute {
@@ -175,6 +175,64 @@ pub(crate) enum HTLCSource {
                first_hop_htlc_msat: u64,
        },
 }
+
+// Clippy gets mad if we implement Hash manually but not PartialEq, and for good reason - they must
+// match for use in a HashSet.
+// Instead, we opt here to have a PartialEq that matches Hash but, when debug_assertions are
+// enabled, panics if the remaining fields do not also match - something which should never happen.
+impl PartialEq for HTLCSource {
+       fn eq(&self, o: &Self) -> bool {
+               match self {
+                       HTLCSource::PreviousHopData(prev_hop_data) => {
+                               match o {
+                                       HTLCSource::PreviousHopData(o_prev_hop_data) => {
+                                               if prev_hop_data.short_channel_id == o_prev_hop_data.short_channel_id &&
+                                                  prev_hop_data.htlc_id == o_prev_hop_data.htlc_id {
+                                                       debug_assert!(prev_hop_data.incoming_packet_shared_secret == o_prev_hop_data.incoming_packet_shared_secret);
+                                                       debug_assert_eq!(prev_hop_data.outpoint, o_prev_hop_data.outpoint);
+                                                       true
+                                               } else {
+                                                       false
+                                               }
+                                       }
+                                       _ => false
+                               }
+                       }
+                       HTLCSource::OutboundRoute { ref path, ref session_priv, ref first_hop_htlc_msat } => {
+                               match o {
+                                       HTLCSource::OutboundRoute { path: o_path, session_priv: o_session_priv, first_hop_htlc_msat: o_first_hop_htlc_msat } => {
+                                               if session_priv == o_session_priv {
+                                                       debug_assert!(path == o_path);
+                                                       debug_assert_eq!(session_priv, o_session_priv);
+                                                       debug_assert_eq!(first_hop_htlc_msat, o_first_hop_htlc_msat);
+                                                       true
+                                               } else {
+                                                       false
+                                               }
+                                       }
+                                       _ => false
+                               }
+                       }
+               }
+       }
+}
+
+impl std::hash::Hash for HTLCSource {
+       fn hash<H>(&self, hasher: &mut H) where H: std::hash::Hasher {
+               match self {
+                       HTLCSource::PreviousHopData(prev_hop_data) => {
+                               hasher.write(&[0u8]);
+                               hasher.write(&byte_utils::le64_to_array(prev_hop_data.short_channel_id));
+                               hasher.write(&byte_utils::le64_to_array(prev_hop_data.htlc_id));
+                       },
+                       HTLCSource::OutboundRoute { ref session_priv, .. } => {
+                               hasher.write(&[1u8]);
+                               hasher.write(&session_priv[..]);
+                       },
+               }
+       }
+}
+
 #[cfg(test)]
 impl HTLCSource {
        pub fn dummy() -> Self {
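
For reference, the contract the comment above relies on (Hash and PartialEq agreeing on which fields form the key) can be illustrated with a small standalone sketch. This is not part of the patch; `Source`, `id` and `metadata` are hypothetical stand-ins for `HTLCSource` and its fields:

```rust
use std::collections::HashSet;
use std::hash::{Hash, Hasher};

// Hypothetical stand-in for HTLCSource: only `id` is the identity key; `metadata`
// mirrors the fields (like the route path) that are checked only via debug asserts.
struct Source {
    id: u64,
    metadata: u64,
}

impl PartialEq for Source {
    fn eq(&self, o: &Self) -> bool {
        if self.id == o.id {
            // Non-key fields should always agree when the key does; catch bugs in debug builds.
            debug_assert_eq!(self.metadata, o.metadata);
            true
        } else {
            false
        }
    }
}
impl Eq for Source {}

impl Hash for Source {
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        // Hash exactly the fields PartialEq compares, or HashSet lookups silently break.
        hasher.write(&self.id.to_le_bytes());
    }
}

fn main() {
    let mut set = HashSet::new();
    assert!(set.insert(Source { id: 1, metadata: 7 }));
    // Same key, so it's a duplicate and the insert is rejected.
    assert!(!set.insert(Source { id: 1, metadata: 7 }));
}
```
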
@@ -440,6 +498,14 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// Locked *after* channel_state.
        pending_inbound_payments: Mutex<HashMap<PaymentHash, PendingInboundPayment>>,
 
+       /// Outbound HTLCs which were still pending when we force-closed a channel. The authoritative
+       /// state of these HTLCs now resides in the relevant ChannelMonitors, however we track them
+       /// here to prevent duplicative PaymentFailed events. Specifically, because ChannelMonitor
+       /// events are ultimately handled by us, and we may only generate duplicative user events if we
+       /// haven't been re-serialized since the first one, we have to de-duplicate them here.
+       /// Locked *after* channel_state.
+       outbound_onchain_pending_htlcs: Mutex<HashSet<HTLCSource>>,
+
        our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
 
@@ -893,6 +959,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(HashMap::new()),
+                       outbound_onchain_pending_htlcs: Mutex::new(HashSet::new()),
 
                        our_network_key: keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &keys_manager.get_node_secret()),
@@ -1069,6 +1136,19 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for htlc_source in shutdown_res.outbound_htlcs_failed.drain(..) {
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
                }
+
+               // We shouldn't be holding any locks at this point, so avoid any outbound_onchain_pending_htlcs
+               // lockorder concerns by simply asserting that we aren't holding any locks.
+               #[cfg(debug_assertions)]
+               std::mem::drop(self.outbound_onchain_pending_htlcs.try_lock().unwrap());
+               #[cfg(debug_assertions)]
+               std::mem::drop(self.channel_state.try_lock().unwrap());
+
+               for htlc in shutdown_res.outbound_onchain_pending_htlcs.drain(..) {
+                       if !self.outbound_onchain_pending_htlcs.lock().unwrap().insert(htlc) {
+                               log_error!(self.logger, "Got duplicative pending-onchain-resolution HTLC.");
+                       }
+               }
                if let Some((funding_txo, monitor_update)) = shutdown_res.monitor_update {
                        // There isn't anything we can do if we get an update failure - we're already
                        // force-closing. The monitor update on the required in-memory copy should broadcast
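
The debug-only lockorder assertions above follow a pattern that can be sketched on its own. This is illustrative only, with `Locks`, `pending` and `state` as hypothetical stand-ins for the two ChannelManager mutexes:

```rust
use std::sync::Mutex;

// Two hypothetical locks standing in for outbound_onchain_pending_htlcs and channel_state.
struct Locks {
    pending: Mutex<Vec<u64>>,
    state: Mutex<u64>,
}

impl Locks {
    fn do_work(&self) {
        // Debug-only check that nothing is currently holding either Mutex: try_lock()
        // returns Err when the lock is already taken, so unwrap() panics exactly in the
        // buggy case, and the guard is dropped immediately so nothing stays locked.
        #[cfg(debug_assertions)]
        std::mem::drop(self.pending.try_lock().unwrap());
        #[cfg(debug_assertions)]
        std::mem::drop(self.state.try_lock().unwrap());

        // With no lockorder between the two, take each lock one at a time.
        self.pending.lock().unwrap().push(1);
        *self.state.lock().unwrap() += 1;
    }
}

fn main() {
    let l = Locks { pending: Mutex::new(Vec::new()), state: Mutex::new(0) };
    l.do_work();
    assert_eq!(*l.state.lock().unwrap(), 1);
}
```
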
@@ -2668,6 +2748,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        // don't respond with the funding_signed so the channel can never go on chain).
                                        let shutdown_res = chan.force_shutdown(true);
                                        assert!(shutdown_res.outbound_htlcs_failed.is_empty());
+                                       assert!(shutdown_res.outbound_onchain_pending_htlcs.is_empty());
                                        return Err(MsgHandleErrInternal::send_err_msg_no_close("ChannelMonitor storage failure".to_owned(), funding_msg.channel_id));
                                },
                                ChannelMonitorUpdateErr::TemporaryFailure => {
@@ -3335,12 +3416,28 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        for monitor_event in self.chain_monitor.release_pending_monitor_events() {
                                match monitor_event {
                                        MonitorEvent::HTLCEvent(htlc_update) => {
-                                               if let Some(preimage) = htlc_update.payment_preimage {
-                                                       log_trace!(self.logger, "Claiming HTLC with preimage {} from our monitor", log_bytes!(preimage.0));
-                                                       self.claim_funds_internal(self.channel_state.lock().unwrap(), htlc_update.source, preimage);
-                                               } else {
-                                                       log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
-                                                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
+                                               // We shouldn't be holding any locks at this point, so avoid any outbound_onchain_pending_htlcs
+                                               // lockorder concerns by simply asserting that we aren't holding any locks.
+                                               #[cfg(debug_assertions)]
+                                               std::mem::drop(self.outbound_onchain_pending_htlcs.try_lock().unwrap());
+                                               #[cfg(debug_assertions)]
+                                               std::mem::drop(self.channel_state.try_lock().unwrap());
+
+                                               if {
+                                                       // Take the outbound_onchain_pending_htlcs lock in a scope so we aren't
+                                                       // holding it while we process the HTLC event.
+                                                       self.outbound_onchain_pending_htlcs.lock().unwrap().remove(&htlc_update.source)
+                                               } {
+                                                       #[cfg(debug_assertions)]
+                                                       std::mem::drop(self.outbound_onchain_pending_htlcs.try_lock().unwrap());
+
+                                                       if let Some(preimage) = htlc_update.payment_preimage {
+                                                               log_trace!(self.logger, "Claiming HTLC with preimage {} from our monitor", log_bytes!(preimage.0));
+                                                               self.claim_funds_internal(self.channel_state.lock().unwrap(), htlc_update.source, preimage);
+                                                       } else {
+                                                               log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
+                                                               self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
+                                                       }
                                                }
                                        },
                                        MonitorEvent::CommitmentTxBroadcasted(funding_outpoint) => {
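
The de-duplication itself relies on `HashSet::remove` reporting whether the HTLC was still pending, so a replayed monitor event is surfaced at most once. A minimal sketch of that idea, using hypothetical types rather than the real `MonitorEvent`/`HTLCSource`:

```rust
use std::collections::HashSet;

// Hypothetical replayable monitor event carrying the key of the HTLC it resolves.
#[derive(Clone)]
struct HtlcEvent { source_key: u64 }

// Surface each HTLC resolution at most once: the key was inserted into `pending` when the
// channel was force-closed, and remove() reports whether this is the first delivery.
fn process_events(pending: &mut HashSet<u64>, events: &[HtlcEvent]) -> usize {
    let mut handled = 0;
    for ev in events {
        if pending.remove(&ev.source_key) {
            // First time we see this HTLC resolve: generate the user-facing event here.
            handled += 1;
        }
        // Otherwise the event is a replay (e.g. delivered again after a restart) and is ignored.
    }
    handled
}

fn main() {
    let mut pending: HashSet<u64> = vec![10u64, 11].into_iter().collect();
    let ev = HtlcEvent { source_key: 10 };
    // Delivering the same event twice produces only one user-visible resolution.
    assert_eq!(process_events(&mut pending, &[ev.clone(), ev]), 1);
}
```
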
@@ -4420,6 +4517,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                        pending_payment.write(writer)?;
                }
 
+               let outbound_onchain_pending_htlcs = self.outbound_onchain_pending_htlcs.lock().unwrap();
+               (outbound_onchain_pending_htlcs.len() as u64).write(writer)?;
+               for htlc in outbound_onchain_pending_htlcs.iter() {
+                       htlc.write(writer)?;
+               }
+
                Ok(())
        }
 }
@@ -4555,6 +4658,8 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
 
                let mut failed_htlcs = Vec::new();
 
+               let mut outbound_onchain_pending_htlcs: HashSet<HTLCSource> = HashSet::new();
+
                let channel_count: u64 = Readable::read(reader)?;
                let mut funding_txo_set = HashSet::with_capacity(cmp::min(channel_count as usize, 128));
                let mut by_id = HashMap::with_capacity(cmp::min(channel_count as usize, 128));
@@ -4577,6 +4682,11 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                        // But if the channel is behind of the monitor, close the channel:
                                        let mut shutdown_res = channel.force_shutdown(true);
                                        failed_htlcs.append(&mut shutdown_res.outbound_htlcs_failed);
+                                       for htlc in shutdown_res.outbound_onchain_pending_htlcs.drain(..) {
+                                               if !outbound_onchain_pending_htlcs.insert(htlc) {
+                                                       return Err(DecodeError::InvalidValue);
+                                               }
+                                       }
                                        monitor.broadcast_latest_holder_commitment_txn(&args.tx_broadcaster, &args.logger);
                                } else {
                                        if let Some(short_channel_id) = channel.get_short_channel_id() {
@@ -4659,6 +4769,14 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        }
                }
 
+               let outbound_onchain_pending_htlcs_count: u64 = Readable::read(reader)?;
+               outbound_onchain_pending_htlcs.reserve(cmp::min(outbound_onchain_pending_htlcs_count as usize, MAX_ALLOC_SIZE/mem::size_of::<HTLCSource>()));
+               for _ in 0..outbound_onchain_pending_htlcs_count {
+                       if !outbound_onchain_pending_htlcs.insert(Readable::read(reader)?) {
+                               return Err(DecodeError::InvalidValue);
+                       }
+               }
+
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());
 
@@ -4678,6 +4796,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(pending_inbound_payments),
+                       outbound_onchain_pending_htlcs: Mutex::new(outbound_onchain_pending_htlcs),
 
                        our_network_key: args.keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &args.keys_manager.get_node_secret()),