Support failing blinded non-intro HTLCs after RAA processing.
[rust-lightning] / lightning / src / ln / channel.rs
index 582f67878e60d8e59d4a36b567e12a8fcded0798..bac45076891aefc9ab2df35187e2847002e20c27 100644 (file)
@@ -259,6 +259,11 @@ enum HTLCUpdateAwaitingACK {
                htlc_id: u64,
                err_packet: msgs::OnionErrorPacket,
        },
+       FailMalformedHTLC {
+               htlc_id: u64,
+               failure_code: u16,
+               sha256_of_onion: [u8; 32],
+       },
 }
 
 macro_rules! define_state_flags {
@@ -2518,6 +2523,64 @@ struct CommitmentTxInfoCached {
        feerate: u32,
 }
 
+/// Contents of a wire message that fails an HTLC backwards. Useful for [`Channel::fail_htlc`] to
+/// fail with either [`msgs::UpdateFailMalformedHTLC`] or [`msgs::UpdateFailHTLC`] as needed.
+trait FailHTLCContents {
+       type Message: FailHTLCMessageName;
+       fn to_message(self, htlc_id: u64, channel_id: ChannelId) -> Self::Message;
+       fn to_inbound_htlc_state(self) -> InboundHTLCState;
+       fn to_htlc_update_awaiting_ack(self, htlc_id: u64) -> HTLCUpdateAwaitingACK;
+}
+impl FailHTLCContents for msgs::OnionErrorPacket {
+       type Message = msgs::UpdateFailHTLC;
+       fn to_message(self, htlc_id: u64, channel_id: ChannelId) -> Self::Message {
+               msgs::UpdateFailHTLC { htlc_id, channel_id, reason: self }
+       }
+       fn to_inbound_htlc_state(self) -> InboundHTLCState {
+               InboundHTLCState::LocalRemoved(InboundHTLCRemovalReason::FailRelay(self))
+       }
+       fn to_htlc_update_awaiting_ack(self, htlc_id: u64) -> HTLCUpdateAwaitingACK {
+               HTLCUpdateAwaitingACK::FailHTLC { htlc_id, err_packet: self }
+       }
+}
+impl FailHTLCContents for (u16, [u8; 32]) {
+       type Message = msgs::UpdateFailMalformedHTLC; // (failure_code, sha256_of_onion)
+       fn to_message(self, htlc_id: u64, channel_id: ChannelId) -> Self::Message {
+               msgs::UpdateFailMalformedHTLC {
+                       htlc_id,
+                       channel_id,
+                       failure_code: self.0,
+                       sha256_of_onion: self.1
+               }
+       }
+       fn to_inbound_htlc_state(self) -> InboundHTLCState {
+               InboundHTLCState::LocalRemoved(
+                       InboundHTLCRemovalReason::FailMalformed((self.1, self.0))
+               )
+       }
+       fn to_htlc_update_awaiting_ack(self, htlc_id: u64) -> HTLCUpdateAwaitingACK {
+               HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                       htlc_id,
+                       failure_code: self.0,
+                       sha256_of_onion: self.1
+               }
+       }
+}
+
+trait FailHTLCMessageName {
+       fn name() -> &'static str;
+}
+impl FailHTLCMessageName for msgs::UpdateFailHTLC {
+       fn name() -> &'static str {
+               "update_fail_htlc"
+       }
+}
+impl FailHTLCMessageName for msgs::UpdateFailMalformedHTLC {
+       fn name() -> &'static str {
+               "update_fail_malformed_htlc"
+       }
+}
+
 impl<SP: Deref> Channel<SP> where
        SP::Target: SignerProvider,
        <SP::Target as SignerProvider>::EcdsaSigner: WriteableEcdsaChannelSigner
@@ -2719,7 +2782,9 @@ impl<SP: Deref> Channel<SP> where
                                                        return UpdateFulfillFetch::DuplicateClaim {};
                                                }
                                        },
-                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } => {
+                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } |
+                                               &HTLCUpdateAwaitingACK::FailMalformedHTLC { htlc_id, .. } =>
+                                       {
                                                if htlc_id_arg == htlc_id {
                                                        log_warn!(logger, "Have preimage and want to fulfill HTLC with pending failure against channel {}", &self.context.channel_id());
                                                        // TODO: We may actually be able to switch to a fulfill here, though its
@@ -2816,6 +2881,17 @@ impl<SP: Deref> Channel<SP> where
                        .map(|msg_opt| assert!(msg_opt.is_none(), "We forced holding cell?"))
        }
 
+       /// Used for failing back with [`msgs::UpdateFailMalformedHTLC`]. For now, this is used when we
+       /// want to fail blinded HTLCs where we are not the intro node.
+       ///
+       /// See [`Self::queue_fail_htlc`] for more info.
+       pub fn queue_fail_malformed_htlc<L: Deref>(
+               &mut self, htlc_id_arg: u64, failure_code: u16, sha256_of_onion: [u8; 32], logger: &L
+       ) -> Result<(), ChannelError> where L::Target: Logger {
+               self.fail_htlc(htlc_id_arg, (failure_code, sha256_of_onion), true, logger)
+                       .map(|msg_opt| assert!(msg_opt.is_none(), "We forced holding cell?"))
+       }
+
        /// We can only have one resolution per HTLC. In some cases around reconnect, we may fulfill
        /// an HTLC more than once or fulfill once and then attempt to fail after reconnect. We cannot,
        /// however, fail more than once as we wait for an upstream failure to be irrevocably committed
@@ -2824,8 +2900,10 @@ impl<SP: Deref> Channel<SP> where
        /// If we do fail twice, we `debug_assert!(false)` and return `Ok(None)`. Thus, this will always
        /// return `Ok(_)` if preconditions are met. In any case, `Err`s will only be
        /// [`ChannelError::Ignore`].
-       fn fail_htlc<L: Deref>(&mut self, htlc_id_arg: u64, err_packet: msgs::OnionErrorPacket, mut force_holding_cell: bool, logger: &L)
-       -> Result<Option<msgs::UpdateFailHTLC>, ChannelError> where L::Target: Logger {
+       fn fail_htlc<L: Deref, E: FailHTLCContents + Clone>(
+               &mut self, htlc_id_arg: u64, err_packet: E, mut force_holding_cell: bool,
+               logger: &L
+       ) -> Result<Option<E::Message>, ChannelError> where L::Target: Logger {
                if !matches!(self.context.channel_state, ChannelState::ChannelReady(_)) {
                        panic!("Was asked to fail an HTLC when channel was not in an operational state");
                }
@@ -2878,7 +2956,9 @@ impl<SP: Deref> Channel<SP> where
                                                        return Ok(None);
                                                }
                                        },
-                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } => {
+                                       &HTLCUpdateAwaitingACK::FailHTLC { htlc_id, .. } |
+                                               &HTLCUpdateAwaitingACK::FailMalformedHTLC { htlc_id, .. } =>
+                                       {
                                                if htlc_id_arg == htlc_id {
                                                        debug_assert!(false, "Tried to fail an HTLC that was already failed");
                                                        return Err(ChannelError::Ignore("Unable to find a pending HTLC which matched the given HTLC ID".to_owned()));
@@ -2888,24 +2968,18 @@ impl<SP: Deref> Channel<SP> where
                                }
                        }
                        log_trace!(logger, "Placing failure for HTLC ID {} in holding cell in channel {}.", htlc_id_arg, &self.context.channel_id());
-                       self.context.holding_cell_htlc_updates.push(HTLCUpdateAwaitingACK::FailHTLC {
-                               htlc_id: htlc_id_arg,
-                               err_packet,
-                       });
+                       self.context.holding_cell_htlc_updates.push(err_packet.to_htlc_update_awaiting_ack(htlc_id_arg));
                        return Ok(None);
                }
 
-               log_trace!(logger, "Failing HTLC ID {} back with a update_fail_htlc message in channel {}.", htlc_id_arg, &self.context.channel_id());
+               log_trace!(logger, "Failing HTLC ID {} back with {} message in channel {}.", htlc_id_arg,
+                       E::Message::name(), &self.context.channel_id());
                {
                        let htlc = &mut self.context.pending_inbound_htlcs[pending_idx];
-                       htlc.state = InboundHTLCState::LocalRemoved(InboundHTLCRemovalReason::FailRelay(err_packet.clone()));
+                       htlc.state = err_packet.clone().to_inbound_htlc_state();
                }
 
-               Ok(Some(msgs::UpdateFailHTLC {
-                       channel_id: self.context.channel_id(),
-                       htlc_id: htlc_id_arg,
-                       reason: err_packet
-               }))
+               Ok(Some(err_packet.to_message(htlc_id_arg, self.context.channel_id())))
        }
 
        // Message handlers:
@@ -3563,6 +3637,20 @@ impl<SP: Deref> Channel<SP> where
                                                        }
                                                }
                                        },
+                                       &HTLCUpdateAwaitingACK::FailMalformedHTLC { htlc_id, failure_code, sha256_of_onion } => {
+                                               match self.fail_htlc(htlc_id, (failure_code, sha256_of_onion), false, logger) {
+                                                       Ok(update_fail_malformed_opt) => {
+                                                               debug_assert!(update_fail_malformed_opt.is_some()); // See above comment
+                                                               update_fail_count += 1;
+                                                       },
+                                                       Err(e) => {
+                                                               if let ChannelError::Ignore(_) = e {}
+                                                               else {
+                                                                       panic!("Got a non-IgnoreError action trying to fail holding cell HTLC");
+                                                               }
+                                                       }
+                                               }
+                                       },
                                }
                        }
                        if update_add_count == 0 && update_fulfill_count == 0 && update_fail_count == 0 && self.context.holding_cell_update_fee.is_none() {
@@ -7433,6 +7521,8 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
 
                let mut holding_cell_skimmed_fees: Vec<Option<u64>> = Vec::new();
                let mut holding_cell_blinding_points: Vec<Option<PublicKey>> = Vec::new();
+               // Vec of (htlc_id, failure_code, sha256_of_onion)
+               let mut malformed_htlcs: Vec<(u64, u16, [u8; 32])> = Vec::new();
                (self.context.holding_cell_htlc_updates.len() as u64).write(writer)?;
                for update in self.context.holding_cell_htlc_updates.iter() {
                        match update {
@@ -7460,6 +7550,18 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                                        htlc_id.write(writer)?;
                                        err_packet.write(writer)?;
                                }
+                               &HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                                       htlc_id, failure_code, sha256_of_onion
+                               } => {
+                                       // We don't want to break downgrading by adding a new variant, so write a dummy
+                                       // `::FailHTLC` variant and write the real malformed error as an optional TLV.
+                                       malformed_htlcs.push((htlc_id, failure_code, sha256_of_onion));
+
+                                       let dummy_err_packet = msgs::OnionErrorPacket { data: Vec::new() };
+                                       2u8.write(writer)?;
+                                       htlc_id.write(writer)?;
+                                       dummy_err_packet.write(writer)?;
+                               }
                        }
                }
 
@@ -7620,6 +7722,7 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                        (38, self.context.is_batch_funding, option),
                        (39, pending_outbound_blinding_points, optional_vec),
                        (41, holding_cell_blinding_points, optional_vec),
+                       (43, malformed_htlcs, optional_vec), // Added in 0.0.119
                });
 
                Ok(())
@@ -7910,6 +8013,8 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                let mut pending_outbound_blinding_points_opt: Option<Vec<Option<PublicKey>>> = None;
                let mut holding_cell_blinding_points_opt: Option<Vec<Option<PublicKey>>> = None;
 
+               let mut malformed_htlcs: Option<Vec<(u64, u16, [u8; 32])>> = None;
+
                read_tlv_fields!(reader, {
                        (0, announcement_sigs, option),
                        (1, minimum_depth, option),
@@ -7938,6 +8043,7 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        (38, is_batch_funding, option),
                        (39, pending_outbound_blinding_points_opt, optional_vec),
                        (41, holding_cell_blinding_points_opt, optional_vec),
+                       (43, malformed_htlcs, optional_vec), // Added in 0.0.119
                });
 
                let (channel_keys_id, holder_signer) = if let Some(channel_keys_id) = channel_keys_id {
@@ -8032,6 +8138,22 @@ impl<'a, 'b, 'c, ES: Deref, SP: Deref> ReadableArgs<(&'a ES, &'b SP, u32, &'c Ch
                        if iter.next().is_some() { return Err(DecodeError::InvalidValue) }
                }
 
+               if let Some(malformed_htlcs) = malformed_htlcs {
+                       for (malformed_htlc_id, failure_code, sha256_of_onion) in malformed_htlcs {
+                               let htlc_idx = holding_cell_htlc_updates.iter().position(|htlc| {
+                                       if let HTLCUpdateAwaitingACK::FailHTLC { htlc_id, err_packet } = htlc {
+                                               let matches = *htlc_id == malformed_htlc_id;
+                                               if matches { debug_assert!(err_packet.data.is_empty()) }
+                                               matches
+                                       } else { false }
+                               }).ok_or(DecodeError::InvalidValue)?;
+                               let malformed_htlc = HTLCUpdateAwaitingACK::FailMalformedHTLC {
+                                       htlc_id: malformed_htlc_id, failure_code, sha256_of_onion
+                               };
+                               let _ = core::mem::replace(&mut holding_cell_htlc_updates[htlc_idx], malformed_htlc);
+                       }
+               }
+
                Ok(Channel {
                        context: ChannelContext {
                                user_id,