Pass `EffectiveCapacity` through to scorer instead of a `u64`
[rust-lightning] / lightning / src / routing / router.rs
index 677a571620ce31a329f8b046d294a8d99b1c2387..fbdefb9840a80d4c38ebc32007e8b39ab6d31ed8 100644 (file)
@@ -854,7 +854,7 @@ where L::Target: Logger {
                                let short_channel_id = $candidate.short_channel_id();
                                let available_liquidity_msat = bookkept_channels_liquidity_available_msat
                                        .entry(short_channel_id)
-                                       .or_insert_with(|| $candidate.effective_capacity().as_msat());
+                                       .or_insert_with(|| $candidate.effective_capacity());
 
                                // It is tricky to substract $next_hops_fee_msat from available liquidity here.
                                // It may be misleading because we might later choose to reduce the value transferred
@@ -863,7 +863,7 @@ where L::Target: Logger {
                                // fees caused by one expensive channel, but then this channel could have been used
                                // if the amount being transferred over this path is lower.
                                // We do this for now, but this is a subject for removal.
-                               if let Some(available_value_contribution_msat) = available_liquidity_msat.checked_sub($next_hops_fee_msat) {
+                               if let Some(available_value_contribution) = available_liquidity_msat.checked_sub($next_hops_fee_msat) {
 
                                        // Routing Fragmentation Mitigation heuristic:
                                        //
@@ -886,6 +886,7 @@ where L::Target: Logger {
                                                final_value_msat
                                        };
                                        // Verify the liquidity offered by this channel complies to the minimal contribution.
+                                       let available_value_contribution_msat = available_value_contribution.as_msat_without_bounds();
                                        let contributes_sufficient_value = available_value_contribution_msat >= minimal_value_contribution_msat;
 
                                        // Do not consider candidates that exceed the maximum total cltv expiry limit.
@@ -1220,9 +1221,8 @@ where L::Target: Logger {
                                                        short_channel_id: hop.short_channel_id,
                                                })
                                                .unwrap_or_else(|| CandidateRouteHop::PrivateHop { hint: hop });
-                                       let capacity_msat = candidate.effective_capacity().as_msat();
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
-                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, final_value_msat, capacity_msat, &source, &target))
+                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, path_value_msat, candidate.effective_capacity(), &source, &target))
                                                .unwrap_or_else(|| u64::max_value());
 
                                        aggregate_next_hops_cltv_delta = aggregate_next_hops_cltv_delta
@@ -1382,12 +1382,13 @@ where L::Target: Logger {
                                        let mut spent_on_hop_msat = value_contribution_msat;
                                        let next_hops_fee_msat = payment_hop.next_hops_fee_msat;
                                        spent_on_hop_msat += next_hops_fee_msat;
-                                       if spent_on_hop_msat == *channel_liquidity_available_msat {
+                                       if spent_on_hop_msat == channel_liquidity_available_msat.as_msat_without_bounds() {
                                                // If this path used all of this channel's available liquidity, we know
                                                // this path will not be selected again in the next loop iteration.
                                                prevented_redundant_path_selection = true;
                                        }
-                                       *channel_liquidity_available_msat -= spent_on_hop_msat;
+                                       *channel_liquidity_available_msat = channel_liquidity_available_msat
+                                               .checked_sub(spent_on_hop_msat).unwrap_or(EffectiveCapacity::ExactLiquidity { liquidity_msat: 0 });
                                }
                                if !prevented_redundant_path_selection {
                                        // If we weren't capped by hitting a liquidity limit on a channel in the path,
@@ -1396,7 +1397,7 @@ where L::Target: Logger {
                                        let victim_scid = payment_path.hops[(payment_path.hops.len() - 1) / 2].0.candidate.short_channel_id();
                                        log_trace!(logger, "Disabling channel {} for future path building iterations to avoid duplicates.", victim_scid);
                                        let victim_liquidity = bookkept_channels_liquidity_available_msat.get_mut(&victim_scid).unwrap();
-                                       *victim_liquidity = 0;
+                                       *victim_liquidity = EffectiveCapacity::ExactLiquidity { liquidity_msat: 0 };
                                }
 
                                // Track the total amount all our collected paths allow to send so that we:
@@ -1664,7 +1665,7 @@ fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters,
 
 #[cfg(test)]
 mod tests {
-       use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
+       use routing::network_graph::{EffectiveCapacity, NetworkGraph, NetGraphMsgHandler, NodeId};
        use routing::router::{get_route, add_random_cltv_offset, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA};
        use routing::scoring::Score;
        use chain::transaction::OutPoint;
@@ -5001,7 +5002,7 @@ mod tests {
                fn write<W: Writer>(&self, _w: &mut W) -> Result<(), ::io::Error> { unimplemented!() }
        }
        impl Score for BadChannelScorer {
-               fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _capacity_msat: u64, _source: &NodeId, _target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _capacity: EffectiveCapacity, _source: &NodeId, _target: &NodeId) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
 
@@ -5019,7 +5020,7 @@ mod tests {
        }
 
        impl Score for BadNodeScorer {
-               fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _capacity_msat: u64, _source: &NodeId, target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _capacity: EffectiveCapacity, _source: &NodeId, target: &NodeId) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }