Refactor outgoing HTLC checks out from decode_update_add_htlc_onion
authorWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 24 Jan 2024 01:08:17 +0000 (17:08 -0800)
committerWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 27 Mar 2024 21:28:02 +0000 (14:28 -0700)
In the future, we plan to complete remove `decode_update_add_htlc_onion`
and replace it with a batched variant. This refactor, while improving
readability in its current form, does not feature any functional changes
and allows us to reuse most of the logic in the batched variant.

lightning/src/ln/channelmanager.rs

index 85a9a40412b43c41088051f191581b0fc7d92703..60fd6e87eeabb3eb0bb02c20fe478a66dedb682a 100644 (file)
@@ -3094,6 +3094,48 @@ where
                }
        }
 
+       fn can_forward_htlc_to_outgoing_channel(
+               &self, chan: &mut Channel<SP>, msg: &msgs::UpdateAddHTLC, next_packet: &NextPacketDetails
+       ) -> Result<Option<msgs::ChannelUpdate>, (&'static str, u16, Option<msgs::ChannelUpdate>)> {
+               if !chan.context.should_announce() && !self.default_configuration.accept_forwards_to_priv_channels {
+                       // Note that the behavior here should be identical to the above block - we
+                       // should NOT reveal the existence or non-existence of a private channel if
+                       // we don't allow forwards outbound over them.
+                       return Err(("Refusing to forward to a private channel based on our config.", 0x4000 | 10, None));
+               }
+               if chan.context.get_channel_type().supports_scid_privacy() && next_packet.outgoing_scid != chan.context.outbound_scid_alias() {
+                       // `option_scid_alias` (referred to in LDK as `scid_privacy`) means
+                       // "refuse to forward unless the SCID alias was used", so we pretend
+                       // we don't have the channel here.
+                       return Err(("Refusing to forward over real channel SCID as our counterparty requested.", 0x4000 | 10, None));
+               }
+               let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
+
+               // Note that we could technically not return an error yet here and just hope
+               // that the connection is reestablished or monitor updated by the time we get
+               // around to doing the actual forward, but better to fail early if we can and
+               // hopefully an attacker trying to path-trace payments cannot make this occur
+               // on a small/per-node/per-channel scale.
+               if !chan.context.is_live() { // channel_disabled
+                       // If the channel_update we're going to return is disabled (i.e. the
+                       // peer has been disabled for some time), return `channel_disabled`,
+                       // otherwise return `temporary_channel_failure`.
+                       if chan_update_opt.as_ref().map(|u| u.contents.flags & 2 == 2).unwrap_or(false) {
+                               return Err(("Forwarding channel has been disconnected for some time.", 0x1000 | 20, chan_update_opt));
+                       } else {
+                               return Err(("Forwarding channel is not in a ready state.", 0x1000 | 7, chan_update_opt));
+                       }
+               }
+               if next_packet.outgoing_amt_msat < chan.context.get_counterparty_htlc_minimum_msat() { // amount_below_minimum
+                       return Err(("HTLC amount was below the htlc_minimum_msat", 0x1000 | 11, chan_update_opt));
+               }
+               if let Err((err, code)) = chan.htlc_satisfies_config(msg, next_packet.outgoing_amt_msat, next_packet.outgoing_cltv_value) {
+                       return Err((err, code, chan_update_opt));
+               }
+
+               Ok(chan_update_opt)
+       }
+
        fn htlc_failure_from_update_add_err(
                &self, msg: &msgs::UpdateAddHTLC, counterparty_node_id: &PublicKey, err_msg: &'static str,
                mut err_code: u16, chan_update: Option<msgs::ChannelUpdate>, is_intro_node_blinded_forward: bool,
@@ -3158,9 +3200,7 @@ where
                        msg, &self.node_signer, &self.logger, &self.secp_ctx
                )?;
 
-               let NextPacketDetails {
-                       next_packet_pubkey, outgoing_amt_msat, outgoing_scid, outgoing_cltv_value
-               } = match next_packet_details_opt {
+               let next_packet_details = match next_packet_details_opt {
                        Some(next_packet_details) => next_packet_details,
                        // it is a receive, so no need for outbound checks
                        None => return Ok((next_hop, shared_secret, None)),
@@ -3169,14 +3209,14 @@ where
                // Perform outbound checks here instead of in [`Self::construct_pending_htlc_info`] because we
                // can't hold the outbound peer state lock at the same time as the inbound peer state lock.
                if let Some((err, code, chan_update)) = loop {
-                       let id_option = self.short_to_chan_info.read().unwrap().get(&outgoing_scid).cloned();
+                       let id_option = self.short_to_chan_info.read().unwrap().get(&next_packet_details.outgoing_scid).cloned();
                        let forwarding_chan_info_opt = match id_option {
                                None => { // unknown_next_peer
                                        // Note that this is likely a timing oracle for detecting whether an scid is a
                                        // phantom or an intercept.
                                        if (self.default_configuration.accept_intercept_htlcs &&
-                                               fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, outgoing_scid, &self.chain_hash)) ||
-                                               fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, outgoing_scid, &self.chain_hash)
+                                               fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) ||
+                                               fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)
                                        {
                                                None
                                        } else {
@@ -3203,42 +3243,10 @@ where
                                        },
                                        Some(chan) => chan
                                };
-                               if !chan.context.should_announce() && !self.default_configuration.accept_forwards_to_priv_channels {
-                                       // Note that the behavior here should be identical to the above block - we
-                                       // should NOT reveal the existence or non-existence of a private channel if
-                                       // we don't allow forwards outbound over them.
-                                       break Some(("Refusing to forward to a private channel based on our config.", 0x4000 | 10, None));
-                               }
-                               if chan.context.get_channel_type().supports_scid_privacy() && outgoing_scid != chan.context.outbound_scid_alias() {
-                                       // `option_scid_alias` (referred to in LDK as `scid_privacy`) means
-                                       // "refuse to forward unless the SCID alias was used", so we pretend
-                                       // we don't have the channel here.
-                                       break Some(("Refusing to forward over real channel SCID as our counterparty requested.", 0x4000 | 10, None));
-                               }
-                               let chan_update_opt = self.get_channel_update_for_onion(outgoing_scid, chan).ok();
-
-                               // Note that we could technically not return an error yet here and just hope
-                               // that the connection is reestablished or monitor updated by the time we get
-                               // around to doing the actual forward, but better to fail early if we can and
-                               // hopefully an attacker trying to path-trace payments cannot make this occur
-                               // on a small/per-node/per-channel scale.
-                               if !chan.context.is_live() { // channel_disabled
-                                       // If the channel_update we're going to return is disabled (i.e. the
-                                       // peer has been disabled for some time), return `channel_disabled`,
-                                       // otherwise return `temporary_channel_failure`.
-                                       if chan_update_opt.as_ref().map(|u| u.contents.flags & 2 == 2).unwrap_or(false) {
-                                               break Some(("Forwarding channel has been disconnected for some time.", 0x1000 | 20, chan_update_opt));
-                                       } else {
-                                               break Some(("Forwarding channel is not in a ready state.", 0x1000 | 7, chan_update_opt));
-                                       }
-                               }
-                               if outgoing_amt_msat < chan.context.get_counterparty_htlc_minimum_msat() { // amount_below_minimum
-                                       break Some(("HTLC amount was below the htlc_minimum_msat", 0x1000 | 11, chan_update_opt));
-                               }
-                               if let Err((err, code)) = chan.htlc_satisfies_config(&msg, outgoing_amt_msat, outgoing_cltv_value) {
-                                       break Some((err, code, chan_update_opt));
+                               match self.can_forward_htlc_to_outgoing_channel(chan, msg, &next_packet_details) {
+                                       Ok(chan_update_opt) => chan_update_opt,
+                                       Err(e) => break Some(e),
                                }
-                               chan_update_opt
                        } else {
                                None
                        };
@@ -3246,7 +3254,7 @@ where
                        let cur_height = self.best_block.read().unwrap().height + 1;
 
                        if let Err((err_msg, code)) = check_incoming_htlc_cltv(
-                               cur_height, outgoing_cltv_value, msg.cltv_expiry
+                               cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry
                        ) {
                                if code & 0x1000 != 0 && chan_update_opt.is_none() {
                                        // We really should set `incorrect_cltv_expiry` here but as we're not
@@ -3265,7 +3273,7 @@ where
                                msg, counterparty_node_id, err, code, chan_update, next_hop.is_intro_node_blinded_forward(), &shared_secret
                        ));
                }
-               Ok((next_hop, shared_secret, Some(next_packet_pubkey)))
+               Ok((next_hop, shared_secret, Some(next_packet_details.next_packet_pubkey)))
        }
 
        fn construct_pending_htlc_status<'a>(