Merge pull request #2795 from TheBlueMatt/2023-11-robuster-chan-to-peer
[rust-lightning] / lightning / src / ln / channel.rs
index 28842b1bf79d2b0f21d2367291338c898c76226f..588995656ec4790cedd6bcb5ac7edc6bc190ae23 100644 (file)
@@ -259,6 +259,11 @@ enum HTLCUpdateAwaitingACK {
                htlc_id: u64,
                err_packet: msgs::OnionErrorPacket,
        },
+       FailMalformedHTLC {
+               htlc_id: u64,
+               failure_code: u16,
+               sha256_of_onion: [u8; 32],
+       },
 }
 
 macro_rules! define_state_flags {
@@ -2430,8 +2435,13 @@ impl<SP: Deref> ChannelContext<SP> where SP::Target: SignerProvider  {
                                        .ok();
 
                                if funding_signed.is_none() {
-                                       log_trace!(logger, "Counterparty commitment signature not available for funding_signed message; setting signer_pending_funding");
-                                       self.signer_pending_funding = true;
+                                       #[cfg(not(async_signing))] {
+                                               panic!("Failed to get signature for funding_signed");
+                                       }
+                                       #[cfg(async_signing)] {
+                                               log_trace!(logger, "Counterparty commitment signature not available for funding_signed message; setting signer_pending_funding");
+                                               self.signer_pending_funding = true;
+                                       }
                                } else if self.signer_pending_funding {
                                        log_trace!(logger, "Counterparty commitment signature available for funding_signed message; clearing signer_pending_funding");
                                        self.signer_pending_funding = false;
@@ -2519,6 +2529,64 @@ struct CommitmentTxInfoCached {
        feerate: u32,
 }
 
+/// Contents of a wire message that fails an HTLC backwards. Useful for [`Channel::fail_htlc`] to
+/// fail with either [`msgs::UpdateFailMalformedHTLC`] or [`msgs::UpdateFailHTLC`] as needed.
+trait FailHTLCContents {
+       type Message: FailHTLCMessageName;
+       fn to_message(self, htlc_id: u64, channel_id: ChannelId) -> Self::Message;
+       fn to_inbound_htlc_state(self) -> InboundHTLCState;
+       fn to_htlc_update_awaiting_ack(self, htlc_id: u64) -> HTLCUpdateAwaitingACK;
+}
+impl FailHTLCContents for msgs::OnionErrorPacket {
+       type Message = msgs::UpdateFailHTLC;
+       fn to_message(self, htlc_id: u64, channel_id: ChannelId) -> Self::Message {
+               msgs::UpdateFailHTLC { htlc_id, channel_id, reason: self }
+       }
+       fn to_inbound_htlc_state(self) -> InboundHTLCState {
+               InboundHTLCState::LocalRemoved(InboundHTLCRemovalReason::FailRelay(self))
+       }
+       fn to_htlc_update_awaiting_ack(self, htlc_id: u64) -> HTLCUpdateAwaitingACK {
+               HTLCUpdateAwaitingACK::FailHTLC { htlc_id, err_packet: self }
+       }
+}
+impl FailHTLCContents for (u16, [u8; 32]) {
+       type Message = msgs::UpdateFailMalformedHTLC; // (failure_code, sha256_of_onion)
+       fn to_message(self, htlc_id: u64, channel_id: ChannelId) -> Self::Message {
+               msgs::UpdateFailMalformedHTLC {
+                       htlc_id,
+                       channel_id,
+                       failure_code: self.0,
+                       sha256_of_onion: self.1
+               }
+       }
+       fn to_inbound_htlc_state(self) -> InboundHTLCState {
+               InboundHTLCState::LocalRemoved(
+                       InboundHTLCRemovalReason::FailMalformed((self.1, self.0))
+               )
+       }
+       fn to_htlc_update_awaiting_ack(self, htlc_id: u64) -> HTLCUpdateAwaitingACK {
+               HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                       htlc_id,
+                       failure_code: self.0,
+                       sha256_of_onion: self.1
+               }
+       }
+}
+
+trait FailHTLCMessageName {
+       fn name() -> &'static str;
+}
+impl FailHTLCMessageName for msgs::UpdateFailHTLC {
+       fn name() -> &'static str {
+               "update_fail_htlc"
+       }
+}
+impl FailHTLCMessageName for msgs::UpdateFailMalformedHTLC {
+       fn name() -> &'static str {
+               "update_fail_malformed_htlc"
+       }
+}
+
 impl<SP: Deref> Channel<SP> where
        SP::Target: SignerProvider,
        <SP::Target as SignerProvider>::EcdsaSigner: WriteableEcdsaChannelSigner
@@ -2721,7 +2789,9 @@ impl<SP: Deref> Channel<SP> where
                                                        return UpdateFulfillFetch::DuplicateClaim {};
                                                }
                                        },
-                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } => {
+                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } |
+                                               &HTLCUpdateAwaitingACK::FailMalformedHTLC { htlc_id, .. } =>
+                                       {
                                                if htlc_id_arg == htlc_id {
                                                        log_warn!(logger, "Have preimage and want to fulfill HTLC with pending failure against channel {}", &self.context.channel_id());
                                                        // TODO: We may actually be able to switch to a fulfill here, though its
@@ -2818,6 +2888,17 @@ impl<SP: Deref> Channel<SP> where
                        .map(|msg_opt| assert!(msg_opt.is_none(), "We forced holding cell?"))
        }
 
+       /// Used for failing back with [`msgs::UpdateFailMalformedHTLC`]. For now, this is used when we
+       /// want to fail blinded HTLCs where we are not the intro node.
+       ///
+       /// See [`Self::queue_fail_htlc`] for more info.
+       pub fn queue_fail_malformed_htlc<L: Deref>(
+               &mut self, htlc_id_arg: u64, failure_code: u16, sha256_of_onion: [u8; 32], logger: &L
+       ) -> Result<(), ChannelError> where L::Target: Logger {
+               self.fail_htlc(htlc_id_arg, (failure_code, sha256_of_onion), true, logger)
+                       .map(|msg_opt| assert!(msg_opt.is_none(), "We forced holding cell?"))
+       }
+
        /// We can only have one resolution per HTLC. In some cases around reconnect, we may fulfill
        /// an HTLC more than once or fulfill once and then attempt to fail after reconnect. We cannot,
        /// however, fail more than once as we wait for an upstream failure to be irrevocably committed
@@ -2826,8 +2907,10 @@ impl<SP: Deref> Channel<SP> where
        /// If we do fail twice, we `debug_assert!(false)` and return `Ok(None)`. Thus, this will always
        /// return `Ok(_)` if preconditions are met. In any case, `Err`s will only be
        /// [`ChannelError::Ignore`].
-       fn fail_htlc<L: Deref>(&mut self, htlc_id_arg: u64, err_packet: msgs::OnionErrorPacket, mut force_holding_cell: bool, logger: &L)
-       -> Result<Option<msgs::UpdateFailHTLC>, ChannelError> where L::Target: Logger {
+       fn fail_htlc<L: Deref, E: FailHTLCContents + Clone>(
+               &mut self, htlc_id_arg: u64, err_packet: E, mut force_holding_cell: bool,
+               logger: &L
+       ) -> Result<Option<E::Message>, ChannelError> where L::Target: Logger {
                if !matches!(self.context.channel_state, ChannelState::ChannelReady(_)) {
                        panic!("Was asked to fail an HTLC when channel was not in an operational state");
                }
@@ -2880,7 +2963,9 @@ impl<SP: Deref> Channel<SP> where
                                                        return Ok(None);
                                                }
                                        },
-                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } => {
+                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } |
+                                               &HTLCUpdateAwaitingACK::FailMalformedHTLC { htlc_id, .. } =>
+                                       {
                                                if htlc_id_arg == htlc_id {
                                                        debug_assert!(false, "Tried to fail an HTLC that was already failed");
                                                        return Err(ChannelError::Ignore("Unable to find a pending HTLC which matched the given HTLC ID".to_owned()));
@@ -2890,24 +2975,18 @@ impl<SP: Deref> Channel<SP> where
                                }
                        }
                        log_trace!(logger, "Placing failure for HTLC ID {} in holding cell in channel {}.", htlc_id_arg, &self.context.channel_id());
-                       self.context.holding_cell_htlc_updates.push(HTLCUpdateAwaitingACK::FailHTLC {
-                               htlc_id: htlc_id_arg,
-                               err_packet,
-                       });
+                       self.context.holding_cell_htlc_updates.push(err_packet.to_htlc_update_awaiting_ack(htlc_id_arg));
                        return Ok(None);
                }
 
-               log_trace!(logger, "Failing HTLC ID {} back with a update_fail_htlc message in channel {}.", htlc_id_arg, &self.context.channel_id());
+               log_trace!(logger, "Failing HTLC ID {} back with {} message in channel {}.", htlc_id_arg,
+                       E::Message::name(), &self.context.channel_id());
                {
                        let htlc = &mut self.context.pending_inbound_htlcs[pending_idx];
-                       htlc.state = InboundHTLCState::LocalRemoved(InboundHTLCRemovalReason::FailRelay(err_packet.clone()));
+                       htlc.state = err_packet.clone().to_inbound_htlc_state();
                }
 
-               Ok(Some(msgs::UpdateFailHTLC {
-                       channel_id: self.context.channel_id(),
-                       htlc_id: htlc_id_arg,
-                       reason: err_packet
-               }))
+               Ok(Some(err_packet.to_message(htlc_id_arg, self.context.channel_id())))
        }
 
        // Message handlers:
@@ -3581,6 +3660,20 @@ impl<SP: Deref> Channel<SP> where
                                                        }
                                                }
                                        },
+                                       &HTLCUpdateAwaitingACK::FailMalformedHTLC { htlc_id, failure_code, sha256_of_onion } => {
+                                               match self.fail_htlc(htlc_id, (failure_code, sha256_of_onion), false, logger) {
+                                                       Ok(update_fail_malformed_opt) => {
+                                                               debug_assert!(update_fail_malformed_opt.is_some()); // See above comment
+                                                               update_fail_count += 1;
+                                                       },
+                                                       Err(e) => {
+                                                               if let ChannelError::Ignore(_) = e {}
+                                                               else {
+                                                                       panic!("Got a non-IgnoreError action trying to fail holding cell HTLC");
+                                                               }
+                                                       }
+                                               }
+                                       },
                                }
                        }
                        if update_add_count == 0 && update_fulfill_count == 0 && update_fail_count == 0 && self.context.holding_cell_update_fee.is_none() {
@@ -4190,7 +4283,7 @@ impl<SP: Deref> Channel<SP> where
 
        /// Indicates that the signer may have some signatures for us, so we should retry if we're
        /// blocked.
-       #[allow(unused)]
+       #[cfg(async_signing)]
        pub fn signer_maybe_unblocked<L: Deref>(&mut self, logger: &L) -> SignerResumeUpdates where L::Target: Logger {
                let commitment_update = if self.context.signer_pending_commitment_update {
                        self.get_last_commitment_update_for_send(logger).ok()
@@ -4294,11 +4387,16 @@ impl<SP: Deref> Channel<SP> where
                        }
                        update
                } else {
-                       if !self.context.signer_pending_commitment_update {
-                               log_trace!(logger, "Commitment update awaiting signer: setting signer_pending_commitment_update");
-                               self.context.signer_pending_commitment_update = true;
+                       #[cfg(not(async_signing))] {
+                               panic!("Failed to get signature for new commitment state");
+                       }
+                       #[cfg(async_signing)] {
+                               if !self.context.signer_pending_commitment_update {
+                                       log_trace!(logger, "Commitment update awaiting signer: setting signer_pending_commitment_update");
+                                       self.context.signer_pending_commitment_update = true;
+                               }
+                               return Err(());
                        }
-                       return Err(());
                };
                Ok(msgs::CommitmentUpdate {
                        update_add_htlcs, update_fulfill_htlcs, update_fail_htlcs, update_fail_malformed_htlcs, update_fee,
@@ -6382,9 +6480,14 @@ impl<SP: Deref> OutboundV1Channel<SP> where SP::Target: SignerProvider {
 
                let funding_created = self.get_funding_created_msg(logger);
                if funding_created.is_none() {
-                       if !self.context.signer_pending_funding {
-                               log_trace!(logger, "funding_created awaiting signer; setting signer_pending_funding");
-                               self.context.signer_pending_funding = true;
+                       #[cfg(not(async_signing))] {
+                               panic!("Failed to get signature for new funding creation");
+                       }
+                       #[cfg(async_signing)] {
+                               if !self.context.signer_pending_funding {
+                                       log_trace!(logger, "funding_created awaiting signer; setting signer_pending_funding");
+                                       self.context.signer_pending_funding = true;
+                               }
                        }
                }
 
@@ -6730,7 +6833,7 @@ impl<SP: Deref> OutboundV1Channel<SP> where SP::Target: SignerProvider {
 
        /// Indicates that the signer may have some signatures for us, so we should retry if we're
        /// blocked.
-       #[allow(unused)]
+       #[cfg(async_signing)]
        pub fn signer_maybe_unblocked<L: Deref>(&mut self, logger: &L) -> Option<msgs::FundingCreated> where L::Target: Logger {
                if self.context.signer_pending_funding && self.context.is_outbound() {
                        log_trace!(logger, "Signer unblocked a funding_created");
@@ -7455,6 +7558,8 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
 
                let mut holding_cell_skimmed_fees: Vec<Option<u64>> = Vec::new();
                let mut holding_cell_blinding_points: Vec<Option<PublicKey>> = Vec::new();
+               // Vec of (htlc_id, failure_code, sha256_of_onion)
+               let mut malformed_htlcs: Vec<(u64, u16, [u8; 32])> = Vec::new();
                (self.context.holding_cell_htlc_updates.len() as u64).write(writer)?;
                for update in self.context.holding_cell_htlc_updates.iter() {
                        match update {
@@ -7482,6 +7587,18 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                                        htlc_id.write(writer)?;
                                        err_packet.write(writer)?;
                                }
+                               &HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                                       htlc_id, failure_code, sha256_of_onion
+                               } => {
+                                       // We don't want to break downgrading by adding a new variant, so write a dummy
+                                       // `::FailHTLC` variant and write the real malformed error as an optional TLV.
+                                       malformed_htlcs.push((htlc_id, failure_code, sha256_of_onion));
+
+                                       let dummy_err_packet = msgs::OnionErrorPacket { data: Vec::new() };
+                                       2u8.write(writer)?;
+                                       htlc_id.write(writer)?;
+                                       dummy_err_packet.write(writer)?;
+                               }
                        }
                }
 
@@ -7642,6 +7759,7 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                        (38, self.context.is_batch_funding, option),
                        (39, pending_outbound_blinding_points, optional_vec),
                        (41, holding_cell_blinding_points, optional_vec),
+                       (43, malformed_htlcs, optional_vec), // Added in 0.0.119
                });
 
                Ok(())
@@ -7932,6 +8050,8 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                let mut pending_outbound_blinding_points_opt: Option<Vec<Option<PublicKey>>> = None;
                let mut holding_cell_blinding_points_opt: Option<Vec<Option<PublicKey>>> = None;
 
+               let mut malformed_htlcs: Option<Vec<(u64, u16, [u8; 32])>> = None;
+
                read_tlv_fields!(reader, {
                        (0, announcement_sigs, option),
                        (1, minimum_depth, option),
@@ -7960,6 +8080,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        (38, is_batch_funding, option),
                        (39, pending_outbound_blinding_points_opt, optional_vec),
                        (41, holding_cell_blinding_points_opt, optional_vec),
+                       (43, malformed_htlcs, optional_vec), // Added in 0.0.119
                });
 
                let (channel_keys_id, holder_signer) = if let Some(channel_keys_id) = channel_keys_id {
@@ -8054,6 +8175,22 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
                }
 
+               if let Some(malformed_htlcs) = malformed_htlcs {
+                       for (malformed_htlc_id, failure_code, sha256_of_onion) in malformed_htlcs {
+                               let htlc_idx = holding_cell_htlc_updates.iter().position(|htlc| {
+                                       if let HTLCUpdateAwaitingACK::FailHTLC { htlc_id, err_packet } = htlc {
+                                               let matches = *htlc_id == malformed_htlc_id;
+                                               if matches { debug_assert!(err_packet.data.is_empty()) }
+                                               matches
+                                       } else { false }
+                               }).ok_or(DecodeError::InvalidValue)?;
+                               let malformed_htlc = HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                                       htlc_id: malformed_htlc_id, failure_code, sha256_of_onion
+                               };
+                               let _ = core::mem::replace(&mut holding_cell_htlc_updates[htlc_idx], malformed_htlc);
+                       }
+               }
+
                Ok(Channel {
                        context: ChannelContext {
                                user_id,
@@ -8188,6 +8325,7 @@ mod tests {
        use bitcoin::blockdata::transaction::{Transaction, TxOut};
        use bitcoin::blockdata::opcodes;
        use bitcoin::network::constants::Network;
+       use crate::ln::onion_utils::INVALID_ONION_BLINDING;
        use crate::ln::{PaymentHash, PaymentPreimage};
        use crate::ln::channel_keys::{RevocationKey, RevocationBasepoint};
        use crate::ln::channelmanager::{self, HTLCSource, PaymentId};
@@ -8724,8 +8862,9 @@ mod tests {
        }
 
        #[test]
-       fn blinding_point_skimmed_fee_ser() {
-               // Ensure that channel blinding points and skimmed fees are (de)serialized properly.
+       fn blinding_point_skimmed_fee_malformed_ser() {
+               // Ensure that channel blinding points, skimmed fees, and malformed HTLCs are (de)serialized
+               // properly.
                let feeest = LowerBoundedFeeEstimator::new(&TestFeeEstimator{fee_est: 15000});
                let secp_ctx = Secp256k1::new();
                let seed = [42; 32];
@@ -8790,13 +8929,19 @@ mod tests {
                        payment_preimage: PaymentPreimage([42; 32]),
                        htlc_id: 0,
                };
-               let mut holding_cell_htlc_updates = Vec::with_capacity(10);
-               for i in 0..10 {
-                       if i % 3 == 0 {
+               let dummy_holding_cell_failed_htlc = |htlc_id| HTLCUpdateAwaitingACK::FailHTLC {
+                       htlc_id, err_packet: msgs::OnionErrorPacket { data: vec![42] }
+               };
+               let dummy_holding_cell_malformed_htlc = |htlc_id| HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                       htlc_id, failure_code: INVALID_ONION_BLINDING, sha256_of_onion: [0; 32],
+               };
+               let mut holding_cell_htlc_updates = Vec::with_capacity(12);
+               for i in 0..12 {
+                       if i % 5 == 0 {
                                holding_cell_htlc_updates.push(dummy_holding_cell_add_htlc.clone());
-                       } else if i % 3 == 1 {
+                       } else if i % 5 == 1 {
                                holding_cell_htlc_updates.push(dummy_holding_cell_claim_htlc.clone());
-                       } else {
+                       } else if i % 5 == 2 {
                                let mut dummy_add = dummy_holding_cell_add_htlc.clone();
                                if let HTLCUpdateAwaitingACK::AddHTLC {
                                        ref mut blinding_point, ref mut skimmed_fee_msat, ..
@@ -8805,6 +8950,10 @@ mod tests {
                                        *skimmed_fee_msat = Some(42);
                                } else { panic!() }
                                holding_cell_htlc_updates.push(dummy_add);
+                       } else if i % 5 == 3 {
+                               holding_cell_htlc_updates.push(dummy_holding_cell_malformed_htlc(i as u64));
+                       } else {
+                               holding_cell_htlc_updates.push(dummy_holding_cell_failed_htlc(i as u64));
                        }
                }
                chan.context.holding_cell_htlc_updates = holding_cell_htlc_updates.clone();