Refactor HTLCFailureMsg generation out from decode_update_add_htlc_onion
authorWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 24 Jan 2024 00:55:40 +0000 (16:55 -0800)
committerWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 27 Mar 2024 21:28:02 +0000 (14:28 -0700)
In the future, we plan to complete remove `decode_update_add_htlc_onion`
and replace it with a batched variant. This refactor, while improving
readability in its current form, does not feature any functional changes
and allows us to reuse most of the logic in the batched variant.

lightning/src/ln/channelmanager.rs

index 685caba75c2024469178b04e367db625fd4fbfce..85a9a40412b43c41088051f191581b0fc7d92703 100644 (file)
@@ -3094,6 +3094,61 @@ where
                }
        }
 
+       fn htlc_failure_from_update_add_err(
+               &self, msg: &msgs::UpdateAddHTLC, counterparty_node_id: &PublicKey, err_msg: &'static str,
+               mut err_code: u16, chan_update: Option<msgs::ChannelUpdate>, is_intro_node_blinded_forward: bool,
+               shared_secret: &[u8; 32]
+       ) -> HTLCFailureMsg {
+               let mut res = VecWriter(Vec::with_capacity(chan_update.serialized_length() + 2 + 8 + 2));
+               if let Some(chan_update) = chan_update {
+                       if err_code == 0x1000 | 11 || err_code == 0x1000 | 12 {
+                               msg.amount_msat.write(&mut res).expect("Writes cannot fail");
+                       }
+                       else if err_code == 0x1000 | 13 {
+                               msg.cltv_expiry.write(&mut res).expect("Writes cannot fail");
+                       }
+                       else if err_code == 0x1000 | 20 {
+                               // TODO: underspecified, follow https://github.com/lightning/bolts/issues/791
+                               0u16.write(&mut res).expect("Writes cannot fail");
+                       }
+                       (chan_update.serialized_length() as u16 + 2).write(&mut res).expect("Writes cannot fail");
+                       msgs::ChannelUpdate::TYPE.write(&mut res).expect("Writes cannot fail");
+                       chan_update.write(&mut res).expect("Writes cannot fail");
+               } else if err_code & 0x1000 == 0x1000 {
+                       // If we're trying to return an error that requires a `channel_update` but
+                       // we're forwarding to a phantom or intercept "channel" (i.e. cannot
+                       // generate an update), just use the generic "temporary_node_failure"
+                       // instead.
+                       err_code = 0x2000 | 2;
+               }
+
+               log_info!(
+                       WithContext::from(&self.logger, Some(*counterparty_node_id), Some(msg.channel_id)),
+                       "Failed to accept/forward incoming HTLC: {}", err_msg
+               );
+               // If `msg.blinding_point` is set, we must always fail with malformed.
+               if msg.blinding_point.is_some() {
+                       return HTLCFailureMsg::Malformed(msgs::UpdateFailMalformedHTLC {
+                               channel_id: msg.channel_id,
+                               htlc_id: msg.htlc_id,
+                               sha256_of_onion: [0; 32],
+                               failure_code: INVALID_ONION_BLINDING,
+                       });
+               }
+
+               let (err_code, err_data) = if is_intro_node_blinded_forward {
+                       (INVALID_ONION_BLINDING, &[0; 32][..])
+               } else {
+                       (err_code, &res.0[..])
+               };
+               HTLCFailureMsg::Relay(msgs::UpdateFailHTLC {
+                       channel_id: msg.channel_id,
+                       htlc_id: msg.htlc_id,
+                       reason: HTLCFailReason::reason(err_code, err_data.to_vec())
+                               .get_encrypted_failure_packet(shared_secret, &None),
+               })
+       }
+
        fn decode_update_add_htlc_onion(
                &self, msg: &msgs::UpdateAddHTLC, counterparty_node_id: &PublicKey,
        ) -> Result<
@@ -3103,36 +3158,6 @@ where
                        msg, &self.node_signer, &self.logger, &self.secp_ctx
                )?;
 
-               macro_rules! return_err {
-                       ($msg: expr, $err_code: expr, $data: expr) => {
-                               {
-                                       log_info!(
-                                               WithContext::from(&self.logger, Some(*counterparty_node_id), Some(msg.channel_id)),
-                                               "Failed to accept/forward incoming HTLC: {}", $msg
-                                       );
-                                       // If `msg.blinding_point` is set, we must always fail with malformed.
-                                       if msg.blinding_point.is_some() {
-                                               return Err(HTLCFailureMsg::Malformed(msgs::UpdateFailMalformedHTLC {
-                                                       channel_id: msg.channel_id,
-                                                       htlc_id: msg.htlc_id,
-                                                       sha256_of_onion: [0; 32],
-                                                       failure_code: INVALID_ONION_BLINDING,
-                                               }));
-                                       }
-
-                                       let (err_code, err_data) = if next_hop.is_intro_node_blinded_forward() {
-                                               (INVALID_ONION_BLINDING, &[0; 32][..])
-                                       } else { ($err_code, $data) };
-                                       return Err(HTLCFailureMsg::Relay(msgs::UpdateFailHTLC {
-                                               channel_id: msg.channel_id,
-                                               htlc_id: msg.htlc_id,
-                                               reason: HTLCFailReason::reason(err_code, err_data.to_vec())
-                                                       .get_encrypted_failure_packet(&shared_secret, &None),
-                                       }));
-                               }
-                       }
-               }
-
                let NextPacketDetails {
                        next_packet_pubkey, outgoing_amt_msat, outgoing_scid, outgoing_cltv_value
                } = match next_packet_details_opt {
@@ -3143,7 +3168,7 @@ where
 
                // Perform outbound checks here instead of in [`Self::construct_pending_htlc_info`] because we
                // can't hold the outbound peer state lock at the same time as the inbound peer state lock.
-               if let Some((err, mut code, chan_update)) = loop {
+               if let Some((err, code, chan_update)) = loop {
                        let id_option = self.short_to_chan_info.read().unwrap().get(&outgoing_scid).cloned();
                        let forwarding_chan_info_opt = match id_option {
                                None => { // unknown_next_peer
@@ -3236,29 +3261,9 @@ where
                        break None;
                }
                {
-                       let mut res = VecWriter(Vec::with_capacity(chan_update.serialized_length() + 2 + 8 + 2));
-                       if let Some(chan_update) = chan_update {
-                               if code == 0x1000 | 11 || code == 0x1000 | 12 {
-                                       msg.amount_msat.write(&mut res).expect("Writes cannot fail");
-                               }
-                               else if code == 0x1000 | 13 {
-                                       msg.cltv_expiry.write(&mut res).expect("Writes cannot fail");
-                               }
-                               else if code == 0x1000 | 20 {
-                                       // TODO: underspecified, follow https://github.com/lightning/bolts/issues/791
-                                       0u16.write(&mut res).expect("Writes cannot fail");
-                               }
-                               (chan_update.serialized_length() as u16 + 2).write(&mut res).expect("Writes cannot fail");
-                               msgs::ChannelUpdate::TYPE.write(&mut res).expect("Writes cannot fail");
-                               chan_update.write(&mut res).expect("Writes cannot fail");
-                       } else if code & 0x1000 == 0x1000 {
-                               // If we're trying to return an error that requires a `channel_update` but
-                               // we're forwarding to a phantom or intercept "channel" (i.e. cannot
-                               // generate an update), just use the generic "temporary_node_failure"
-                               // instead.
-                               code = 0x2000 | 2;
-                       }
-                       return_err!(err, code, &res.0[..]);
+                       return Err(self.htlc_failure_from_update_add_err(
+                               msg, counterparty_node_id, err, code, chan_update, next_hop.is_intro_node_blinded_forward(), &shared_secret
+                       ));
                }
                Ok((next_hop, shared_secret, Some(next_packet_pubkey)))
        }