Refactor outgoing channel lookup out from decode_update_add_htlc_onion
authorWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 24 Jan 2024 01:14:14 +0000 (17:14 -0800)
committerWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 27 Mar 2024 21:28:03 +0000 (14:28 -0700)
In the future, we plan to complete remove `decode_update_add_htlc_onion`
and replace it with a batched variant. This refactor, while improving
readability in its current form, does not feature any functional changes
and allows us to reuse most of the logic in the batched variant.

lightning/src/ln/channelmanager.rs

index 60fd6e87eeabb3eb0bb02c20fe478a66dedb682a..8c4c68eea50f9ca167f9bd6c6872c030cf6fef1e 100644 (file)
@@ -3096,7 +3096,7 @@ where
 
        fn can_forward_htlc_to_outgoing_channel(
                &self, chan: &mut Channel<SP>, msg: &msgs::UpdateAddHTLC, next_packet: &NextPacketDetails
-       ) -> Result<Option<msgs::ChannelUpdate>, (&'static str, u16, Option<msgs::ChannelUpdate>)> {
+       ) -> Result<(), (&'static str, u16, Option<msgs::ChannelUpdate>)> {
                if !chan.context.should_announce() && !self.default_configuration.accept_forwards_to_priv_channels {
                        // Note that the behavior here should be identical to the above block - we
                        // should NOT reveal the existence or non-existence of a private channel if
@@ -3109,7 +3109,6 @@ where
                        // we don't have the channel here.
                        return Err(("Refusing to forward over real channel SCID as our counterparty requested.", 0x4000 | 10, None));
                }
-               let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
 
                // Note that we could technically not return an error yet here and just hope
                // that the connection is reestablished or monitor updated by the time we get
@@ -3120,6 +3119,7 @@ where
                        // If the channel_update we're going to return is disabled (i.e. the
                        // peer has been disabled for some time), return `channel_disabled`,
                        // otherwise return `temporary_channel_failure`.
+                       let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
                        if chan_update_opt.as_ref().map(|u| u.contents.flags & 2 == 2).unwrap_or(false) {
                                return Err(("Forwarding channel has been disconnected for some time.", 0x1000 | 20, chan_update_opt));
                        } else {
@@ -3127,13 +3127,79 @@ where
                        }
                }
                if next_packet.outgoing_amt_msat < chan.context.get_counterparty_htlc_minimum_msat() { // amount_below_minimum
+                       let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
                        return Err(("HTLC amount was below the htlc_minimum_msat", 0x1000 | 11, chan_update_opt));
                }
                if let Err((err, code)) = chan.htlc_satisfies_config(msg, next_packet.outgoing_amt_msat, next_packet.outgoing_cltv_value) {
+                       let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
                        return Err((err, code, chan_update_opt));
                }
 
-               Ok(chan_update_opt)
+               Ok(())
+       }
+
+       /// Executes a callback `C` that returns some value `X` on the channel found with the given
+       /// `scid`. `None` is returned when the channel is not found.
+       fn do_funded_channel_callback<X, C: Fn(&mut Channel<SP>) -> X>(
+               &self, scid: u64, callback: C,
+       ) -> Option<X> {
+               let (counterparty_node_id, channel_id) = match self.short_to_chan_info.read().unwrap().get(&scid).cloned() {
+                       None => return None,
+                       Some((cp_id, id)) => (cp_id, id),
+               };
+               let per_peer_state = self.per_peer_state.read().unwrap();
+               let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
+               if peer_state_mutex_opt.is_none() {
+                       return None;
+               }
+               let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+               let peer_state = &mut *peer_state_lock;
+               match peer_state.channel_by_id.get_mut(&channel_id).and_then(
+                       |chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None }
+               ) {
+                       None => None,
+                       Some(chan) => Some(callback(chan)),
+               }
+       }
+
+       fn can_forward_htlc(
+               &self, msg: &msgs::UpdateAddHTLC, next_packet_details: &NextPacketDetails
+       ) -> Result<(), (&'static str, u16, Option<msgs::ChannelUpdate>)> {
+               match self.do_funded_channel_callback(next_packet_details.outgoing_scid, |chan: &mut Channel<SP>| {
+                       self.can_forward_htlc_to_outgoing_channel(chan, msg, next_packet_details)
+               }) {
+                       Some(Ok(())) => {},
+                       Some(Err(e)) => return Err(e),
+                       None => {
+                               // If we couldn't find the channel info for the scid, it may be a phantom or
+                               // intercept forward.
+                               if (self.default_configuration.accept_intercept_htlcs &&
+                                       fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) ||
+                                       fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)
+                               {} else {
+                                       return Err(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
+                               }
+                       }
+               }
+
+               let cur_height = self.best_block.read().unwrap().height + 1;
+               if let Err((err_msg, err_code)) = check_incoming_htlc_cltv(
+                       cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry
+               ) {
+                       let chan_update_opt = self.do_funded_channel_callback(next_packet_details.outgoing_scid, |chan: &mut Channel<SP>| {
+                               self.get_channel_update_for_onion(next_packet_details.outgoing_scid, chan).ok()
+                       }).flatten();
+                       if err_code & 0x1000 != 0 && chan_update_opt.is_none() {
+                               // We really should set `incorrect_cltv_expiry` here but as we're not
+                               // forwarding over a real channel we can't generate a channel_update
+                               // for it. Instead we just return a generic temporary_node_failure.
+                               return Err((err_msg, 0x2000 | 2, None));
+                       }
+                       let chan_update_opt = if err_code & 0x1000 != 0 { chan_update_opt } else { None };
+                       return Err((err_msg, err_code, chan_update_opt));
+               }
+
+               Ok(())
        }
 
        fn htlc_failure_from_update_add_err(
@@ -3208,71 +3274,14 @@ where
 
                // Perform outbound checks here instead of in [`Self::construct_pending_htlc_info`] because we
                // can't hold the outbound peer state lock at the same time as the inbound peer state lock.
-               if let Some((err, code, chan_update)) = loop {
-                       let id_option = self.short_to_chan_info.read().unwrap().get(&next_packet_details.outgoing_scid).cloned();
-                       let forwarding_chan_info_opt = match id_option {
-                               None => { // unknown_next_peer
-                                       // Note that this is likely a timing oracle for detecting whether an scid is a
-                                       // phantom or an intercept.
-                                       if (self.default_configuration.accept_intercept_htlcs &&
-                                               fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) ||
-                                               fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)
-                                       {
-                                               None
-                                       } else {
-                                               break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
-                                       }
-                               },
-                               Some((cp_id, id)) => Some((cp_id.clone(), id.clone())),
-                       };
-                       let chan_update_opt = if let Some((counterparty_node_id, forwarding_id)) = forwarding_chan_info_opt {
-                               let per_peer_state = self.per_peer_state.read().unwrap();
-                               let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
-                               if peer_state_mutex_opt.is_none() {
-                                       break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
-                               }
-                               let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
-                               let peer_state = &mut *peer_state_lock;
-                               let chan = match peer_state.channel_by_id.get_mut(&forwarding_id).map(
-                                       |chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None }
-                               ).flatten() {
-                                       None => {
-                                               // Channel was removed. The short_to_chan_info and channel_by_id maps
-                                               // have no consistency guarantees.
-                                               break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
-                                       },
-                                       Some(chan) => chan
-                               };
-                               match self.can_forward_htlc_to_outgoing_channel(chan, msg, &next_packet_details) {
-                                       Ok(chan_update_opt) => chan_update_opt,
-                                       Err(e) => break Some(e),
-                               }
-                       } else {
-                               None
-                       };
-
-                       let cur_height = self.best_block.read().unwrap().height + 1;
-
-                       if let Err((err_msg, code)) = check_incoming_htlc_cltv(
-                               cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry
-                       ) {
-                               if code & 0x1000 != 0 && chan_update_opt.is_none() {
-                                       // We really should set `incorrect_cltv_expiry` here but as we're not
-                                       // forwarding over a real channel we can't generate a channel_update
-                                       // for it. Instead we just return a generic temporary_node_failure.
-                                       break Some((err_msg, 0x2000 | 2, None))
-                               }
-                               let chan_update_opt = if code & 0x1000 != 0 { chan_update_opt } else { None };
-                               break Some((err_msg, code, chan_update_opt));
-                       }
+               self.can_forward_htlc(&msg, &next_packet_details).map_err(|e| {
+                       let (err_msg, err_code, chan_update_opt) = e;
+                       self.htlc_failure_from_update_add_err(
+                               msg, counterparty_node_id, err_msg, err_code, chan_update_opt,
+                               next_hop.is_intro_node_blinded_forward(), &shared_secret
+                       )
+               })?;
 
-                       break None;
-               }
-               {
-                       return Err(self.htlc_failure_from_update_add_err(
-                               msg, counterparty_node_id, err, code, chan_update, next_hop.is_intro_node_blinded_forward(), &shared_secret
-                       ));
-               }
                Ok((next_hop, shared_secret, Some(next_packet_details.next_packet_pubkey)))
        }