Merge pull request #1799 from TheBlueMatt/2022-10-heap-nerdsnipe
[rust-lightning] / lightning / src / routing / router.rs
index c15b612d939b16d4b75e79c81b3077cd60cadb09..8543956ac657de2da7a1ac163478aa1fd5ba22b2 100644 (file)
@@ -582,7 +582,6 @@ impl_writeable_tlv_based!(RouteHintHop, {
 #[derive(Eq, PartialEq)]
 struct RouteGraphNode {
        node_id: NodeId,
-       lowest_fee_to_peer_through_node: u64,
        lowest_fee_to_node: u64,
        total_cltv_delta: u32,
        // The maximum value a yet-to-be-constructed payment path might flow through this node.
@@ -603,9 +602,9 @@ struct RouteGraphNode {
 
 impl cmp::Ord for RouteGraphNode {
        fn cmp(&self, other: &RouteGraphNode) -> cmp::Ordering {
-               let other_score = cmp::max(other.lowest_fee_to_peer_through_node, other.path_htlc_minimum_msat)
+               let other_score = cmp::max(other.lowest_fee_to_node, other.path_htlc_minimum_msat)
                        .saturating_add(other.path_penalty_msat);
-               let self_score = cmp::max(self.lowest_fee_to_peer_through_node, self.path_htlc_minimum_msat)
+               let self_score = cmp::max(self.lowest_fee_to_node, self.path_htlc_minimum_msat)
                        .saturating_add(self.path_penalty_msat);
                other_score.cmp(&self_score).then_with(|| other.node_id.cmp(&self.node_id))
        }
@@ -729,8 +728,6 @@ struct PathBuildingHop<'a> {
        candidate: CandidateRouteHop<'a>,
        fee_msat: u64,
 
-       /// Minimal fees required to route to the source node of the current hop via any of its inbound channels.
-       src_lowest_inbound_fees: RoutingFees,
        /// All the fees paid *after* this channel on the way to the destination
        next_hops_fee_msat: u64,
        /// Fee paid for the use of the current channel (see candidate.fees()).
@@ -888,18 +885,20 @@ impl<'a> PaymentPath<'a> {
        }
 }
 
+#[inline(always)]
+/// Calculate the fees required to route the given amount over a channel with the given fees.
 fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option<u64> {
-       let proportional_fee_millions =
-               amount_msat.checked_mul(channel_fees.proportional_millionths as u64);
-       if let Some(new_fee) = proportional_fee_millions.and_then(|part| {
-                       (channel_fees.base_msat as u64).checked_add(part / 1_000_000) }) {
+       amount_msat.checked_mul(channel_fees.proportional_millionths as u64)
+               .and_then(|part| (channel_fees.base_msat as u64).checked_add(part / 1_000_000))
+}
 
-               Some(new_fee)
-       } else {
-               // This function may be (indirectly) called without any verification,
-               // with channel_fees provided by a caller. We should handle it gracefully.
-               None
-       }
+#[inline(always)]
+/// Calculate the fees required to route the given amount over a channel with the given fees,
+/// saturating to [`u64::max_value`].
+fn compute_fees_saturating(amount_msat: u64, channel_fees: RoutingFees) -> u64 {
+       amount_msat.checked_mul(channel_fees.proportional_millionths as u64)
+               .map(|prop| prop / 1_000_000).unwrap_or(u64::max_value())
+               .saturating_add(channel_fees.base_msat as u64)
 }
 
 /// The default `features` we assume for a node in a route, when no `features` are known about that
@@ -1007,9 +1006,8 @@ where L::Target: Logger {
        // 8. If our maximum channel saturation limit caused us to pick two identical paths, combine
        //    them so that we're not sending two HTLCs along the same path.
 
-       // As for the actual search algorithm,
-       // we do a payee-to-payer pseudo-Dijkstra's sorting by each node's distance from the payee
-       // plus the minimum per-HTLC fee to get from it to another node (aka "shitty pseudo-A*").
+       // As for the actual search algorithm, we do a payee-to-payer Dijkstra's sorting by each node's
+       // distance from the payee
        //
        // We are not a faithful Dijkstra's implementation because we can change values which impact
        // earlier nodes while processing later nodes. Specifically, if we reach a channel with a lower
@@ -1044,10 +1042,6 @@ where L::Target: Logger {
        // runtime for little gain. Specifically, the current algorithm rather efficiently explores the
        // graph for candidate paths, calculating the maximum value which can realistically be sent at
        // the same time, remaining generic across different payment values.
-       //
-       // TODO: There are a few tweaks we could do, including possibly pre-calculating more stuff
-       // to use as the A* heuristic beyond just the cost to get one node further than the current
-       // one.
 
        let network_channels = network_graph.channels();
        let network_nodes = network_graph.nodes();
@@ -1097,7 +1091,7 @@ where L::Target: Logger {
                }
        }
 
-       // The main heap containing all candidate next-hops sorted by their score (max(A* fee,
+       // The main heap containing all candidate next-hops sorted by their score (max(fee,
        // htlc_minimum)). Ideally this would be a heap which allowed cheap score reduction instead of
        // adding duplicate entries when we find a better path to a given node.
        let mut targets: BinaryHeap<RouteGraphNode> = BinaryHeap::new();
@@ -1262,10 +1256,10 @@ where L::Target: Logger {
                                                // might violate htlc_minimum_msat on the hops which are next along the
                                                // payment path (upstream to the payee). To avoid that, we recompute
                                                // path fees knowing the final path contribution after constructing it.
-                                               let path_htlc_minimum_msat = compute_fees($next_hops_path_htlc_minimum_msat, $candidate.fees())
-                                                       .and_then(|fee_msat| fee_msat.checked_add($next_hops_path_htlc_minimum_msat))
-                                                       .map(|fee_msat| cmp::max(fee_msat, $candidate.htlc_minimum_msat()))
-                                                       .unwrap_or_else(|| u64::max_value());
+                                               let path_htlc_minimum_msat = cmp::max(
+                                                       compute_fees_saturating($next_hops_path_htlc_minimum_msat, $candidate.fees())
+                                                               .saturating_add($next_hops_path_htlc_minimum_msat),
+                                                       $candidate.htlc_minimum_msat());
                                                let hm_entry = dist.entry($src_node_id);
                                                let old_entry = hm_entry.or_insert_with(|| {
                                                        // If there was previously no known way to access the source node
@@ -1273,20 +1267,10 @@ where L::Target: Logger {
                                                        // semi-dummy record just to compute the fees to reach the source node.
                                                        // This will affect our decision on selecting short_channel_id
                                                        // as a way to reach the $dest_node_id.
-                                                       let mut fee_base_msat = 0;
-                                                       let mut fee_proportional_millionths = 0;
-                                                       if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
-                                                               fee_base_msat = fees.base_msat;
-                                                               fee_proportional_millionths = fees.proportional_millionths;
-                                                       }
                                                        PathBuildingHop {
                                                                node_id: $dest_node_id.clone(),
                                                                candidate: $candidate.clone(),
                                                                fee_msat: 0,
-                                                               src_lowest_inbound_fees: RoutingFees {
-                                                                       base_msat: fee_base_msat,
-                                                                       proportional_millionths: fee_proportional_millionths,
-                                                               },
                                                                next_hops_fee_msat: u64::max_value(),
                                                                hop_use_fee_msat: u64::max_value(),
                                                                total_fee_msat: u64::max_value(),
@@ -1309,38 +1293,15 @@ where L::Target: Logger {
 
                                                if should_process {
                                                        let mut hop_use_fee_msat = 0;
-                                                       let mut total_fee_msat = $next_hops_fee_msat;
+                                                       let mut total_fee_msat: u64 = $next_hops_fee_msat;
 
                                                        // Ignore hop_use_fee_msat for channel-from-us as we assume all channels-from-us
                                                        // will have the same effective-fee
                                                        if $src_node_id != our_node_id {
-                                                               match compute_fees(amount_to_transfer_over_msat, $candidate.fees()) {
-                                                                       // max_value means we'll always fail
-                                                                       // the old_entry.total_fee_msat > total_fee_msat check
-                                                                       None => total_fee_msat = u64::max_value(),
-                                                                       Some(fee_msat) => {
-                                                                               hop_use_fee_msat = fee_msat;
-                                                                               total_fee_msat += hop_use_fee_msat;
-                                                                               // When calculating the lowest inbound fees to a node, we
-                                                                               // calculate fees here not based on the actual value we think
-                                                                               // will flow over this channel, but on the minimum value that
-                                                                               // we'll accept flowing over it. The minimum accepted value
-                                                                               // is a constant through each path collection run, ensuring
-                                                                               // consistent basis. Otherwise we may later find a
-                                                                               // different path to the source node that is more expensive,
-                                                                               // but which we consider to be cheaper because we are capacity
-                                                                               // constrained and the relative fee becomes lower.
-                                                                               match compute_fees(minimal_value_contribution_msat, old_entry.src_lowest_inbound_fees)
-                                                                                               .map(|a| a.checked_add(total_fee_msat)) {
-                                                                                       Some(Some(v)) => {
-                                                                                               total_fee_msat = v;
-                                                                                       },
-                                                                                       _ => {
-                                                                                               total_fee_msat = u64::max_value();
-                                                                                       }
-                                                                               };
-                                                                       }
-                                                               }
+                                                               // Note that `u64::max_value` means we'll always fail the
+                                                               // `old_entry.total_fee_msat > total_fee_msat` check below
+                                                               hop_use_fee_msat = compute_fees_saturating(amount_to_transfer_over_msat, $candidate.fees());
+                                                               total_fee_msat = total_fee_msat.saturating_add(hop_use_fee_msat);
                                                        }
 
                                                        let channel_usage = ChannelUsage {
@@ -1355,8 +1316,7 @@ where L::Target: Logger {
                                                                .saturating_add(channel_penalty_msat);
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
-                                                               lowest_fee_to_peer_through_node: total_fee_msat,
-                                                               lowest_fee_to_node: $next_hops_fee_msat as u64 + hop_use_fee_msat,
+                                                               lowest_fee_to_node: total_fee_msat,
                                                                total_cltv_delta: hop_total_cltv_delta,
                                                                value_contribution_msat,
                                                                path_htlc_minimum_msat,
@@ -5544,9 +5504,9 @@ mod tests {
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let payment_params = PaymentParameters::from_node_id(dst);
                                let amt = seed as u64 % 200_000_000;
                                let params = ProbabilisticScoringParameters::default();
@@ -5582,9 +5542,9 @@ mod tests {
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let payment_params = PaymentParameters::from_node_id(dst).with_features(channelmanager::provided_invoice_features(&config));
                                let amt = seed as u64 % 200_000_000;
                                let params = ProbabilisticScoringParameters::default();
@@ -5639,8 +5599,8 @@ pub(crate) mod bench_utils {
        use std::fs::File;
        /// Tries to open a network graph file, or panics with a URL to fetch it.
        pub(crate) fn get_route_file() -> Result<std::fs::File, &'static str> {
-               let res = File::open("net_graph-2021-05-31.bin") // By default we're run in RL/lightning
-                       .or_else(|_| File::open("lightning/net_graph-2021-05-31.bin")) // We may be run manually in RL/
+               let res = File::open("net_graph-2023-01-18.bin") // By default we're run in RL/lightning
+                       .or_else(|_| File::open("lightning/net_graph-2023-01-18.bin")) // We may be run manually in RL/
                        .or_else(|_| { // Fall back to guessing based on the binary location
                                // path is likely something like .../rust-lightning/target/debug/deps/lightning-...
                                let mut path = std::env::current_exe().unwrap();
@@ -5649,11 +5609,11 @@ pub(crate) mod bench_utils {
                                path.pop(); // debug
                                path.pop(); // target
                                path.push("lightning");
-                               path.push("net_graph-2021-05-31.bin");
+                               path.push("net_graph-2023-01-18.bin");
                                eprintln!("{}", path.to_str().unwrap());
                                File::open(path)
                        })
-               .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin and place it at lightning/net_graph-2021-05-31.bin");
+               .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin and place it at lightning/net_graph-2023-01-18.bin");
                #[cfg(require_route_graph_test)]
                return Ok(res.unwrap());
                #[cfg(not(require_route_graph_test))]
@@ -5782,9 +5742,9 @@ mod benches {
                'load_endpoints: for _ in 0..150 {
                        loop {
                                seed *= 0xdeadbeef;
-                               let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed *= 0xdeadbeef;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let params = PaymentParameters::from_node_id(dst).with_features(features.clone());
                                let first_hop = first_hop(src);
                                let amt = seed as u64 % 1_000_000;