Refactor outgoing channel lookup out from decode_update_add_htlc_onion
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 60fd6e87eeabb3eb0bb02c20fe478a66dedb682a..8c4c68eea50f9ca167f9bd6c6872c030cf6fef1e 100644 (file)
@@ -3096,7 +3096,7 @@ where
 
        fn can_forward_htlc_to_outgoing_channel(
                &self, chan: &mut Channel<SP>, msg: &msgs::UpdateAddHTLC, next_packet: &NextPacketDetails
-       ) -> Result<Option<msgs::ChannelUpdate>, (&'static str, u16, Option<msgs::ChannelUpdate>)> {
+       ) -> Result<(), (&'static str, u16, Option<msgs::ChannelUpdate>)> {
                if !chan.context.should_announce() && !self.default_configuration.accept_forwards_to_priv_channels {
                        // Note that the behavior here should be identical to the above block - we
                        // should NOT reveal the existence or non-existence of a private channel if
@@ -3109,7 +3109,6 @@ where
                        // we don't have the channel here.
                        return Err(("Refusing to forward over real channel SCID as our counterparty requested.", 0x4000 | 10, None));
                }
-               let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
 
                // Note that we could technically not return an error yet here and just hope
                // that the connection is reestablished or monitor updated by the time we get
@@ -3120,6 +3119,7 @@ where
                        // If the channel_update we're going to return is disabled (i.e. the
                        // peer has been disabled for some time), return `channel_disabled`,
                        // otherwise return `temporary_channel_failure`.
+                       let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
                        if chan_update_opt.as_ref().map(|u| u.contents.flags & 2 == 2).unwrap_or(false) {
                                return Err(("Forwarding channel has been disconnected for some time.", 0x1000 | 20, chan_update_opt));
                        } else {
@@ -3127,13 +3127,79 @@ where
                        }
                }
                if next_packet.outgoing_amt_msat < chan.context.get_counterparty_htlc_minimum_msat() { // amount_below_minimum
+                       let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
                        return Err(("HTLC amount was below the htlc_minimum_msat", 0x1000 | 11, chan_update_opt));
                }
                if let Err((err, code)) = chan.htlc_satisfies_config(msg, next_packet.outgoing_amt_msat, next_packet.outgoing_cltv_value) {
+                       let chan_update_opt = self.get_channel_update_for_onion(next_packet.outgoing_scid, chan).ok();
                        return Err((err, code, chan_update_opt));
                }
 
-               Ok(chan_update_opt)
+               Ok(())
+       }
+
+       /// Executes a callback `C` that returns some value `X` on the channel found with the given
+       /// `scid`. `None` is returned when the channel is not found.
+       fn do_funded_channel_callback<X, C: Fn(&mut Channel<SP>) -> X>(
+               &self, scid: u64, callback: C,
+       ) -> Option<X> {
+               let (counterparty_node_id, channel_id) = match self.short_to_chan_info.read().unwrap().get(&scid).cloned() {
+                       None => return None,
+                       Some((cp_id, id)) => (cp_id, id),
+               };
+               let per_peer_state = self.per_peer_state.read().unwrap();
+               let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
+               if peer_state_mutex_opt.is_none() {
+                       return None;
+               }
+               let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+               let peer_state = &mut *peer_state_lock;
+               match peer_state.channel_by_id.get_mut(&channel_id).and_then(
+                       |chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None }
+               ) {
+                       None => None,
+                       Some(chan) => Some(callback(chan)),
+               }
+       }
+
+       fn can_forward_htlc(
+               &self, msg: &msgs::UpdateAddHTLC, next_packet_details: &NextPacketDetails
+       ) -> Result<(), (&'static str, u16, Option<msgs::ChannelUpdate>)> {
+               match self.do_funded_channel_callback(next_packet_details.outgoing_scid, |chan: &mut Channel<SP>| {
+                       self.can_forward_htlc_to_outgoing_channel(chan, msg, next_packet_details)
+               }) {
+                       Some(Ok(())) => {},
+                       Some(Err(e)) => return Err(e),
+                       None => {
+                               // If we couldn't find the channel info for the scid, it may be a phantom or
+                               // intercept forward.
+                               if (self.default_configuration.accept_intercept_htlcs &&
+                                       fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) ||
+                                       fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)
+                               {} else {
+                                       return Err(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
+                               }
+                       }
+               }
+
+               let cur_height = self.best_block.read().unwrap().height + 1;
+               if let Err((err_msg, err_code)) = check_incoming_htlc_cltv(
+                       cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry
+               ) {
+                       let chan_update_opt = self.do_funded_channel_callback(next_packet_details.outgoing_scid, |chan: &mut Channel<SP>| {
+                               self.get_channel_update_for_onion(next_packet_details.outgoing_scid, chan).ok()
+                       }).flatten();
+                       if err_code & 0x1000 != 0 && chan_update_opt.is_none() {
+                               // We really should set `incorrect_cltv_expiry` here but as we're not
+                               // forwarding over a real channel we can't generate a channel_update
+                               // for it. Instead we just return a generic temporary_node_failure.
+                               return Err((err_msg, 0x2000 | 2, None));
+                       }
+                       let chan_update_opt = if err_code & 0x1000 != 0 { chan_update_opt } else { None };
+                       return Err((err_msg, err_code, chan_update_opt));
+               }
+
+               Ok(())
        }
 
        fn htlc_failure_from_update_add_err(
@@ -3208,71 +3274,14 @@ where
 
                // Perform outbound checks here instead of in [`Self::construct_pending_htlc_info`] because we
                // can't hold the outbound peer state lock at the same time as the inbound peer state lock.
-               if let Some((err, code, chan_update)) = loop {
-                       let id_option = self.short_to_chan_info.read().unwrap().get(&next_packet_details.outgoing_scid).cloned();
-                       let forwarding_chan_info_opt = match id_option {
-                               None => { // unknown_next_peer
-                                       // Note that this is likely a timing oracle for detecting whether an scid is a
-                                       // phantom or an intercept.
-                                       if (self.default_configuration.accept_intercept_htlcs &&
-                                               fake_scid::is_valid_intercept(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)) ||
-                                               fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, next_packet_details.outgoing_scid, &self.chain_hash)
-                                       {
-                                               None
-                                       } else {
-                                               break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
-                                       }
-                               },
-                               Some((cp_id, id)) => Some((cp_id.clone(), id.clone())),
-                       };
-                       let chan_update_opt = if let Some((counterparty_node_id, forwarding_id)) = forwarding_chan_info_opt {
-                               let per_peer_state = self.per_peer_state.read().unwrap();
-                               let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
-                               if peer_state_mutex_opt.is_none() {
-                                       break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
-                               }
-                               let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
-                               let peer_state = &mut *peer_state_lock;
-                               let chan = match peer_state.channel_by_id.get_mut(&forwarding_id).map(
-                                       |chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None }
-                               ).flatten() {
-                                       None => {
-                                               // Channel was removed. The short_to_chan_info and channel_by_id maps
-                                               // have no consistency guarantees.
-                                               break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
-                                       },
-                                       Some(chan) => chan
-                               };
-                               match self.can_forward_htlc_to_outgoing_channel(chan, msg, &next_packet_details) {
-                                       Ok(chan_update_opt) => chan_update_opt,
-                                       Err(e) => break Some(e),
-                               }
-                       } else {
-                               None
-                       };
-
-                       let cur_height = self.best_block.read().unwrap().height + 1;
-
-                       if let Err((err_msg, code)) = check_incoming_htlc_cltv(
-                               cur_height, next_packet_details.outgoing_cltv_value, msg.cltv_expiry
-                       ) {
-                               if code & 0x1000 != 0 && chan_update_opt.is_none() {
-                                       // We really should set `incorrect_cltv_expiry` here but as we're not
-                                       // forwarding over a real channel we can't generate a channel_update
-                                       // for it. Instead we just return a generic temporary_node_failure.
-                                       break Some((err_msg, 0x2000 | 2, None))
-                               }
-                               let chan_update_opt = if code & 0x1000 != 0 { chan_update_opt } else { None };
-                               break Some((err_msg, code, chan_update_opt));
-                       }
+               self.can_forward_htlc(&msg, &next_packet_details).map_err(|e| {
+                       let (err_msg, err_code, chan_update_opt) = e;
+                       self.htlc_failure_from_update_add_err(
+                               msg, counterparty_node_id, err_msg, err_code, chan_update_opt,
+                               next_hop.is_intro_node_blinded_forward(), &shared_secret
+                       )
+               })?;
 
-                       break None;
-               }
-               {
-                       return Err(self.htlc_failure_from_update_add_err(
-                               msg, counterparty_node_id, err, code, chan_update, next_hop.is_intro_node_blinded_forward(), &shared_secret
-                       ));
-               }
                Ok((next_hop, shared_secret, Some(next_packet_details.next_packet_pubkey)))
        }