Merge pull request #2762 from TheBlueMatt/2023-11-htlc-docs
[rust-lightning] / lightning / src / routing / router.rs
index 361d5793edcc5e27980b34b93737ce608d6f2103..b4cbf3cb234c9eaae6c0757131b596f88c5e7c56 100644 (file)
@@ -130,18 +130,27 @@ impl<'a, S: Deref> ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: Scor
 
 impl<'a, S: Deref> ScoreLookUp for ScorerAccountingForInFlightHtlcs<'a, S> where S::Target: ScoreLookUp {
        type ScoreParams = <S::Target as ScoreLookUp>::ScoreParams;
-       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
+       fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
+               let target = match candidate.target() {
+                       Some(target) => target,
+                       None => return self.scorer.channel_penalty_msat(candidate, usage, score_params),
+               };
+               let short_channel_id = match candidate.short_channel_id() {
+                       Some(short_channel_id) => short_channel_id,
+                       None => return self.scorer.channel_penalty_msat(candidate, usage, score_params),
+               };
+               let source = candidate.source();
                if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
-                       source, target, short_channel_id
+                       &source, &target, short_channel_id
                ) {
                        let usage = ChannelUsage {
                                inflight_htlc_msat: usage.inflight_htlc_msat.saturating_add(used_liquidity),
                                ..usage
                        };
 
-                       self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
+                       self.scorer.channel_penalty_msat(candidate, usage, score_params)
                } else {
-                       self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
+                       self.scorer.channel_penalty_msat(candidate, usage, score_params)
                }
        }
 }
@@ -1062,9 +1071,13 @@ pub enum CandidateRouteHop<'a> {
 }
 
 impl<'a> CandidateRouteHop<'a> {
-       fn short_channel_id(&self) -> Option<u64> {
+       /// Returns short_channel_id if known.
+       /// For `FirstHop` we assume [`ChannelDetails::get_outbound_payment_scid`] is always set, this assumption is checked in
+       /// [`find_route`] method.
+       /// For `Blinded` and `OneHopBlinded` we return `None` because next hop is not known.
+       pub fn short_channel_id(&self) -> Option<u64> {
                match self {
-                       CandidateRouteHop::FirstHop { details, .. } => Some(details.get_outbound_payment_scid().unwrap()),
+                       CandidateRouteHop::FirstHop { details, .. } => details.get_outbound_payment_scid(),
                        CandidateRouteHop::PublicHop { short_channel_id, .. } => Some(*short_channel_id),
                        CandidateRouteHop::PrivateHop { hint, .. } => Some(hint.short_channel_id),
                        CandidateRouteHop::Blinded { .. } => None,
@@ -1083,7 +1096,8 @@ impl<'a> CandidateRouteHop<'a> {
                }
        }
 
-       fn cltv_expiry_delta(&self) -> u32 {
+       /// Returns cltv_expiry_delta for this hop.
+       pub fn cltv_expiry_delta(&self) -> u32 {
                match self {
                        CandidateRouteHop::FirstHop { .. } => 0,
                        CandidateRouteHop::PublicHop { info, .. } => info.direction().cltv_expiry_delta as u32,
@@ -1093,7 +1107,8 @@ impl<'a> CandidateRouteHop<'a> {
                }
        }
 
-       fn htlc_minimum_msat(&self) -> u64 {
+       /// Returns the htlc_minimum_msat for this hop.
+       pub fn htlc_minimum_msat(&self) -> u64 {
                match self {
                        CandidateRouteHop::FirstHop { details, .. } => details.next_outbound_htlc_minimum_msat,
                        CandidateRouteHop::PublicHop { info, .. } => info.direction().htlc_minimum_msat,
@@ -1103,7 +1118,8 @@ impl<'a> CandidateRouteHop<'a> {
                }
        }
 
-       fn fees(&self) -> RoutingFees {
+       /// Returns the fees for this hop.
+       pub fn fees(&self) -> RoutingFees {
                match self {
                        CandidateRouteHop::FirstHop { .. } => RoutingFees {
                                base_msat: 0, proportional_millionths: 0,
@@ -1166,7 +1182,7 @@ impl<'a> CandidateRouteHop<'a> {
                        CandidateRouteHop::PublicHop { info, .. } => *info.source(),
                        CandidateRouteHop::PrivateHop { hint, .. } => hint.src_node_id.into(),
                        CandidateRouteHop::Blinded { hint, .. } => hint.1.introduction_node_id.into(),
-                       CandidateRouteHop::OneHopBlinded { hint, .. } => hint.1.introduction_node_id.into()
+                       CandidateRouteHop::OneHopBlinded { hint, .. } => hint.1.introduction_node_id.into(),
                }
        }
        /// Returns the target node id of this hop, if known.
@@ -1231,9 +1247,6 @@ fn iter_equal<I1: Iterator, I2: Iterator>(mut iter_a: I1, mut iter_b: I2)
 /// These fee values are useful to choose hops as we traverse the graph "payee-to-payer".
 #[derive(Clone)]
 struct PathBuildingHop<'a> {
-       // Note that this should be dropped in favor of loading it from CandidateRouteHop, but doing so
-       // is a larger refactor and will require careful performance analysis.
-       node_id: NodeId,
        candidate: CandidateRouteHop<'a>,
        fee_msat: u64,
 
@@ -1271,7 +1284,7 @@ impl<'a> core::fmt::Debug for PathBuildingHop<'a> {
        fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
                let mut debug_struct = f.debug_struct("PathBuildingHop");
                debug_struct
-                       .field("node_id", &self.node_id)
+                       .field("node_id", &self.candidate.target())
                        .field("short_channel_id", &self.candidate.short_channel_id())
                        .field("total_fee_msat", &self.total_fee_msat)
                        .field("next_hops_fee_msat", &self.next_hops_fee_msat)
@@ -1808,8 +1821,7 @@ where L::Target: Logger {
                        // - for regular channels at channel announcement (TODO)
                        // - for first and last hops early in get_route
                        let src_node_id = $candidate.source();
-                       let dest_node_id = $candidate.target().unwrap_or(maybe_dummy_payee_node_id);
-                       if src_node_id != dest_node_id {
+                       if Some(src_node_id) != $candidate.target() {
                                let scid_opt = $candidate.short_channel_id();
                                let effective_capacity = $candidate.effective_capacity();
                                let htlc_maximum_msat = max_htlc_from_capacity(effective_capacity, channel_saturation_pow_half);
@@ -1948,7 +1960,6 @@ where L::Target: Logger {
                                                        // This will affect our decision on selecting short_channel_id
                                                        // as a way to reach the $candidate.target() node.
                                                        PathBuildingHop {
-                                                               node_id: dest_node_id.clone(),
                                                                candidate: $candidate.clone(),
                                                                fee_msat: 0,
                                                                next_hops_fee_msat: u64::max_value(),
@@ -2004,9 +2015,10 @@ where L::Target: Logger {
                                                                        inflight_htlc_msat: used_liquidity_msat,
                                                                        effective_capacity,
                                                                };
-                                                               let channel_penalty_msat = scid_opt.map_or(0,
-                                                                       |scid| scorer.channel_penalty_msat(scid, &src_node_id, &dest_node_id,
-                                                                               channel_usage, score_params));
+                                                               let channel_penalty_msat =
+                                                                       scorer.channel_penalty_msat($candidate,
+                                                                               channel_usage,
+                                                                               score_params);
                                                                let path_penalty_msat = $next_hops_path_penalty_msat
                                                                        .saturating_add(channel_penalty_msat);
                                                                let new_graph_node = RouteGraphNode {
@@ -2046,7 +2058,6 @@ where L::Target: Logger {
                                                                        old_entry.next_hops_fee_msat = $next_hops_fee_msat;
                                                                        old_entry.hop_use_fee_msat = hop_use_fee_msat;
                                                                        old_entry.total_fee_msat = total_fee_msat;
-                                                                       old_entry.node_id = dest_node_id;
                                                                        old_entry.candidate = $candidate.clone();
                                                                        old_entry.fee_msat = 0; // This value will be later filled with hop_use_fee_msat of the following channel
                                                                        old_entry.path_htlc_minimum_msat = path_htlc_minimum_msat;
@@ -2317,7 +2328,7 @@ where L::Target: Logger {
                                                effective_capacity: candidate.effective_capacity(),
                                        };
                                        let channel_penalty_msat = scorer.channel_penalty_msat(
-                                               hop.short_channel_id, &source, &target, channel_usage, score_params
+                                               &candidate, channel_usage, score_params
                                        );
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
                                                .saturating_add(channel_penalty_msat);
@@ -2418,7 +2429,8 @@ where L::Target: Logger {
 
                                'path_walk: loop {
                                        let mut features_set = false;
-                                       if let Some(first_channels) = first_hop_targets.get(&ordered_hops.last().unwrap().0.node_id) {
+                                       let target = ordered_hops.last().unwrap().0.candidate.target().unwrap_or(maybe_dummy_payee_node_id);
+                                       if let Some(first_channels) = first_hop_targets.get(&target) {
                                                for details in first_channels {
                                                        if let Some(scid) = ordered_hops.last().unwrap().0.candidate.short_channel_id() {
                                                                if details.get_outbound_payment_scid().unwrap() == scid {
@@ -2430,7 +2442,7 @@ where L::Target: Logger {
                                                }
                                        }
                                        if !features_set {
-                                               if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.node_id) {
+                                               if let Some(node) = network_nodes.get(&target) {
                                                        if let Some(node_info) = node.announcement_info.as_ref() {
                                                                ordered_hops.last_mut().unwrap().1 = node_info.features.clone();
                                                        } else {
@@ -2447,11 +2459,11 @@ where L::Target: Logger {
                                        // save this path for the payment route. Also, update the liquidity
                                        // remaining on the used hops, so that we take them into account
                                        // while looking for more paths.
-                                       if ordered_hops.last().unwrap().0.node_id == maybe_dummy_payee_node_id {
+                                       if target == maybe_dummy_payee_node_id {
                                                break 'path_walk;
                                        }
 
-                                       new_entry = match dist.remove(&ordered_hops.last().unwrap().0.node_id) {
+                                       new_entry = match dist.remove(&target) {
                                                Some(payment_hop) => payment_hop,
                                                // We can't arrive at None because, if we ever add an entry to targets,
                                                // we also fill in the entry in dist (see add_entry!).
@@ -2670,8 +2682,8 @@ where L::Target: Logger {
        });
        for idx in 0..(selected_route.len() - 1) {
                if idx + 1 >= selected_route.len() { break; }
-               if iter_equal(selected_route[idx].hops.iter().map(|h| (h.0.candidate.id(), h.0.node_id)),
-                                                                       selected_route[idx + 1].hops.iter().map(|h| (h.0.candidate.id(), h.0.node_id))) {
+               if iter_equal(selected_route[idx].hops.iter().map(|h| (h.0.candidate.id(), h.0.candidate.target())),
+                                                                       selected_route[idx + 1].hops.iter().map(|h| (h.0.candidate.id(), h.0.candidate.target()))) {
                        let new_value = selected_route[idx].get_value_msat() + selected_route[idx + 1].get_value_msat();
                        selected_route[idx].update_value_and_recompute_fees(new_value);
                        selected_route.remove(idx + 1);
@@ -2684,6 +2696,7 @@ where L::Target: Logger {
                for (hop, node_features) in payment_path.hops.iter()
                        .filter(|(h, _)| h.candidate.short_channel_id().is_some())
                {
+                       let target = hop.candidate.target().expect("target is defined when short_channel_id is defined");
                        let maybe_announced_channel = if let CandidateRouteHop::PublicHop { .. } = hop.candidate {
                                // If we sourced the hop from the graph we're sure the target node is announced.
                                true
@@ -2695,14 +2708,14 @@ where L::Target: Logger {
                                // there are announced channels between the endpoints. If so, the hop might be
                                // referring to any of the announced channels, as its `short_channel_id` might be
                                // an alias, in which case we don't take any chances here.
-                               network_graph.node(&hop.node_id).map_or(false, |hop_node|
+                               network_graph.node(&target).map_or(false, |hop_node|
                                        hop_node.channels.iter().any(|scid| network_graph.channel(*scid)
                                                        .map_or(false, |c| c.as_directed_from(&hop.candidate.source()).is_some()))
                                )
                        };
 
                        hops.push(RouteHop {
-                               pubkey: PublicKey::from_slice(hop.node_id.as_slice()).map_err(|_| LightningError{err: format!("Public key {:?} is invalid", &hop.node_id), action: ErrorAction::IgnoreAndLog(Level::Trace)})?,
+                               pubkey: PublicKey::from_slice(target.as_slice()).map_err(|_| LightningError{err: format!("Public key {:?} is invalid", &target), action: ErrorAction::IgnoreAndLog(Level::Trace)})?,
                                node_features: node_features.clone(),
                                short_channel_id: hop.candidate.short_channel_id().unwrap(),
                                channel_features: hop.candidate.features(),
@@ -2872,13 +2885,13 @@ fn build_route_from_hops_internal<L: Deref>(
 
        impl ScoreLookUp for HopScorer {
                type ScoreParams = ();
-               fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
+               fn channel_penalty_msat(&self, candidate: &CandidateRouteHop,
                        _usage: ChannelUsage, _score_params: &Self::ScoreParams) -> u64
                {
                        let mut cur_id = self.our_node_id;
                        for i in 0..self.hop_ids.len() {
                                if let Some(next_id) = self.hop_ids[i] {
-                                       if cur_id == *source && next_id == *target {
+                                       if cur_id == candidate.source() && Some(next_id) == candidate.target() {
                                                return 0;
                                        }
                                        cur_id = next_id;
@@ -2919,7 +2932,7 @@ mod tests {
        use crate::routing::utxo::UtxoResult;
        use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
                BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
-               DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE, RouteParameters};
+               DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE, RouteParameters, CandidateRouteHop};
        use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, ScoreLookUp, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
        use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
        use crate::chain::transaction::OutPoint;
@@ -6224,8 +6237,8 @@ mod tests {
        }
        impl ScoreLookUp for BadChannelScorer {
                type ScoreParams = ();
-               fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
-                       if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
+               fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
+                       if candidate.short_channel_id() == Some(self.short_channel_id) { u64::max_value()  } else { 0  }
                }
        }
 
@@ -6240,8 +6253,8 @@ mod tests {
 
        impl ScoreLookUp for BadNodeScorer {
                type ScoreParams = ();
-               fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
-                       if *target == self.node_id { u64::max_value() } else { 0 }
+               fn channel_penalty_msat(&self, candidate: &CandidateRouteHop, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
+                       if candidate.target() == Some(self.node_id) { u64::max_value() } else { 0 }
                }
        }
 
@@ -6729,26 +6742,32 @@ mod tests {
                };
                scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123);
                scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456);
-               assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage, &scorer_params), 456);
+               let network_graph = network_graph.read_only();
+               let channels = network_graph.channels();
+               let channel = channels.get(&5).unwrap();
+               let info = channel.as_directed_from(&NodeId::from_pubkey(&nodes[3])).unwrap();
+               let candidate: CandidateRouteHop = CandidateRouteHop::PublicHop {
+                       info: info.0,
+                       short_channel_id: 5,
+               };
+               assert_eq!(scorer.channel_penalty_msat(&candidate, usage, &scorer_params), 456);
 
                // Then check we can get a normal route
                let payment_params = PaymentParameters::from_node_id(nodes[10], 42);
                let route_params = RouteParameters::from_payment_params_and_value(
                        payment_params, 100);
-               let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
+               let route = get_route(&our_id, &route_params, &network_graph, None,
                        Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
                assert!(route.is_ok());
 
                // Then check that we can't get a route if we ban an intermediate node.
                scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3]));
-               let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
-                       Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
+               let route = get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes);
                assert!(route.is_err());
 
                // Finally make sure we can route again, when we remove the ban.
                scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3]));
-               let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
-                       Arc::clone(&logger), &scorer, &scorer_params, &random_seed_bytes);
+               let route = get_route(&our_id, &route_params, &network_graph, None, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes);
                assert!(route.is_ok());
        }