From b9ca5788f5b1caca660004fe94a64017ae8a7b5e Mon Sep 17 00:00:00 2001 From: Wilmer Paulino Date: Tue, 23 Jan 2024 17:14:14 -0800 Subject: [PATCH] Refactor outgoing channel lookup out from decode_update_add_htlc_onion In the future, we plan to complete remove `decode_update_add_htlc_onion` and replace it with a batched variant. This refactor, while improving readability in its current form, does not feature any functional changes and allows us to reuse most of the logic in the batched variant. --- lightning/src/ln/channelmanager.rs | 143 +++++++++++++++-------------- 1 file changed, 76 insertions(+), 67 deletions(-) diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 60fd6e87e..8c4c68eea 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -3096,7 +3096,7 @@ where fn can_forward_htlc_to_outgoing_channel( &self, chan: &mut Channel, msg: &msgs::UpdateAddHTLC, next_packet: &NextPacketDetails - ) -> Result, (&'static str, u16, Option)> { + ) -> Result<(), (&'static str, u16, Option)> { if !chan.context.should_announce() && !self.default_configuration.accept_forwards_to_priv_channels { // Note that the behavior here should be identical to the above block - we // should NOT reveal the existence or non-existence of a private channel if @@ -3109,7 +3109,6 @@ where // we don't have the channel here. return Err(("Refusing to forward over real channel SCID as our counterparty requested.", 0x4000 | 10, None)); } - let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok(); // Note that we could technically not return an error yet here and just hope // that the connection is reestablished or monitor updated by the time we get @@ -3120,6 +3119,7 @@ where // If the channel_update we're going to return is disabled (i.e. the // peer has been disabled for some time), return `channel_disabled`, // otherwise return `temporary_channel_failure`. + let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok(); if chan_update_opt.as_ref().map(|u| u.contents.flags & 2 == 2).unwrap_or(false) { return Err(("Forwarding channel has been disconnected for some time.", 0x1000 | 20, chan_update_opt)); } else { @@ -3127,13 +3127,79 @@ where } } if next_packet.outgoing_amt_msat < chan.context.get_counterparty_htlc_minimum_msat() { // amount_below_minimum + let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok(); return Err(("HTLC amount was below the htlc_minimum_msat", 0x1000 | 11, chan_update_opt)); } if let Err((err, code)) = chan.htlc_satisfies_config(msg, next_packet.outgoing_amt_msat, next_packet.outgoing_cltv_value) { + let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok(); return Err((err, code, chan_update_opt)); } - Ok(chan_update_opt) + Ok(()) + } + + /// Executes a callback `C` that returns some value `X` on the channel found with the given + /// `scid`. `None` is returned when the channel is not found. + fn do_funded_channel_callback) -> X>( + &self, scid: u64, callback: C, + ) -> Option { + let (counterparty_node_id, channel_id) = match self.short_to_chan_info.read().unwrap().get(&scid).cloned() { + None => return None, + Some((cp_id, id)) => (cp_id, id), + }; + let per_peer_state = self.per_peer_state.read().unwrap(); + let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id); + if peer_state_mutex_opt.is_none() { + return None; + } + let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap(); + let peer_state = &mut *peer_state_lock; + match peer_state.channel_by_id.get_mut(&channel_id).and_then( + |chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None } + ) { + None => None, + Some(chan) => Some(callback(chan)), + } + } + + fn can_forward_htlc( + &self, msg: &msgs::UpdateAddHTLC, next_packet_details: &NextPacketDetails + ) -> Result<(), (&'static str, u16, Option)> { + match self.do_funded_channel_callback(next_packet_details.outgoing_scid, |chan: &mut Channel| { + self.can_forward_htlc_to_outgoing_channel(chan, msg, next_packet_details) + }) { + Some(Ok(())) => {}, + Some(Err(e)) => return Err(e), + None => { + // If we couldn't find the channel info for the scid, it may be a phantom or + // intercept forward. + if (self.default_configuration.accept_intercept_htlcs && + fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) || + fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash) + {} else { + return Err(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None)); + } + } + } + + let cur_height = self.best_block.read().unwrap().height + 1; + if let Err((err_msg, err_code)) = check_incoming_htlc_cltv( + cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry + ) { + let chan_update_opt = self.do_funded_channel_callback(next_packet_details.outgoing_scid, |chan: &mut Channel| { + self.get_channel_update_for_onion(next_packet_details.outgoing_scid, chan).ok() + }).flatten(); + if err_code & 0x1000 != 0 && chan_update_opt.is_none() { + // We really should set `incorrect_cltv_expiry` here but as we're not + // forwarding over a real channel we can't generate a channel_update + // for it. Instead we just return a generic temporary_node_failure. + return Err((err_msg, 0x2000 | 2, None)); + } + let chan_update_opt = if err_code & 0x1000 != 0 { chan_update_opt } else { None }; + return Err((err_msg, err_code, chan_update_opt)); + } + + Ok(()) } fn htlc_failure_from_update_add_err( @@ -3208,71 +3274,14 @@ where // Perform outbound checks here instead of in [`Self::construct_pending_htlc_info`] because we // can't hold the outbound peer state lock at the same time as the inbound peer state lock. - if let Some((err, code, chan_update)) = loop { - let id_option = self.short_to_chan_info.read().unwrap().get(&next_packet_details.outgoing_scid).cloned(); - let forwarding_chan_info_opt = match id_option { - None => { // unknown_next_peer - // Note that this is likely a timing oracle for detecting whether an scid is a - // phantom or an intercept. - if (self.default_configuration.accept_intercept_htlcs && - fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) || - fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash) - { - None - } else { - break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None)); - } - }, - Some((cp_id, id)) => Some((cp_id.clone(), id.clone())), - }; - let chan_update_opt = if let Some((counterparty_node_id, forwarding_id)) = forwarding_chan_info_opt { - let per_peer_state = self.per_peer_state.read().unwrap(); - let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id); - if peer_state_mutex_opt.is_none() { - break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None)); - } - let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap(); - let peer_state = &mut *peer_state_lock; - let chan = match peer_state.channel_by_id.get_mut(&forwarding_id).map( - |chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None } - ).flatten() { - None => { - // Channel was removed. The short_to_chan_info and channel_by_id maps - // have no consistency guarantees. - break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None)); - }, - Some(chan) => chan - }; - match self.can_forward_htlc_to_outgoing_channel(chan, msg, &next_packet_details) { - Ok(chan_update_opt) => chan_update_opt, - Err(e) => break Some(e), - } - } else { - None - }; - - let cur_height = self.best_block.read().unwrap().height + 1; - - if let Err((err_msg, code)) = check_incoming_htlc_cltv( - cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry - ) { - if code & 0x1000 != 0 && chan_update_opt.is_none() { - // We really should set `incorrect_cltv_expiry` here but as we're not - // forwarding over a real channel we can't generate a channel_update - // for it. Instead we just return a generic temporary_node_failure. - break Some((err_msg, 0x2000 | 2, None)) - } - let chan_update_opt = if code & 0x1000 != 0 { chan_update_opt } else { None }; - break Some((err_msg, code, chan_update_opt)); - } + self.can_forward_htlc(&msg, &next_packet_details).map_err(|e| { + let (err_msg, err_code, chan_update_opt) = e; + self.htlc_failure_from_update_add_err( + msg, counterparty_node_id, err_msg, err_code, chan_update_opt, + next_hop.is_intro_node_blinded_forward(), &shared_secret + ) + })?; - break None; - } - { - return Err(self.htlc_failure_from_update_add_err( - msg, counterparty_node_id, err, code, chan_update, next_hop.is_intro_node_blinded_forward(), &shared_secret - )); - } Ok((next_hop, shared_secret, Some(next_packet_details.next_packet_pubkey))) } -- 2.39.5