From b11dcf9ba8c75b50c5915402ad8c29b972b0f8a0 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Thu, 24 Mar 2022 09:12:26 -0600 Subject: [PATCH] Randomize candidate paths during route selection. --- lightning/src/routing/router.rs | 40 +++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 1b52af88..8a2aecb5 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -495,6 +495,10 @@ impl<'a> PaymentPath<'a> { self.hops.last().unwrap().0.fee_msat } + fn get_path_penalty_msat(&self) -> u64 { + self.hops.first().map(|h| h.0.path_penalty_msat).unwrap_or(u64::max_value()) + } + fn get_total_fee_paid_msat(&self) -> u64 { if self.hops.len() < 1 { return 0; @@ -645,7 +649,7 @@ where L::Target: Logger { pub(crate) fn get_route( our_node_pubkey: &PublicKey, payment_params: &PaymentParameters, network_graph: &ReadOnlyNetworkGraph, first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, final_cltv_expiry_delta: u32, - logger: L, scorer: &S, _random_seed_bytes: &[u8; 32] + logger: L, scorer: &S, random_seed_bytes: &[u8; 32] ) -> Result where L::Target: Logger { let payee_node_id = NodeId::from_pubkey(&payment_params.payee_pubkey); @@ -1449,17 +1453,24 @@ where L::Target: Logger { // Draw multiple sufficient routes by randomly combining the selected paths. let mut drawn_routes = Vec::new(); - for i in 0..payment_paths.len() { + let mut prng = ChaCha20::new(random_seed_bytes, &[0u8; 12]); + let mut random_index_bytes = [0u8; ::core::mem::size_of::()]; + + let num_permutations = payment_paths.len(); + for _ in 0..num_permutations { let mut cur_route = Vec::::new(); let mut aggregate_route_value_msat = 0; // Step (6). - // TODO: real random shuffle - // Currently just starts with i_th and goes up to i-1_th in a looped way. - let cur_payment_paths = [&payment_paths[i..], &payment_paths[..i]].concat(); + // Do a Fisher-Yates shuffle to create a random permutation of the payment paths + for cur_index in (1..payment_paths.len()).rev() { + prng.process_in_place(&mut random_index_bytes); + let random_index = usize::from_be_bytes(random_index_bytes).wrapping_rem(cur_index+1); + payment_paths.swap(cur_index, random_index); + } // Step (7). - for payment_path in cur_payment_paths { + for payment_path in &payment_paths { cur_route.push(payment_path.clone()); aggregate_route_value_msat += payment_path.get_value_msat(); if aggregate_route_value_msat > final_value_msat { @@ -1469,12 +1480,17 @@ where L::Target: Logger { // also makes routing more reliable. let mut overpaid_value_msat = aggregate_route_value_msat - final_value_msat; - // First, drop some expensive low-value paths entirely if possible. - // Sort by value so that we drop many really-low values first, since - // fewer paths is better: the payment is less likely to fail. - // TODO: this could also be optimized by also sorting by feerate_per_sat_routed, - // so that the sender pays less fees overall. And also htlc_minimum_msat. - cur_route.sort_by_key(|path| path.get_value_msat()); + // First, we drop some expensive low-value paths entirely if possible, since fewer + // paths is better: the payment is less likely to fail. In order to do so, we sort + // by value and fall back to total fees paid, i.e., in case of equal values we + // prefer lower cost paths. + cur_route.sort_unstable_by(|a, b| { + a.get_value_msat().cmp(&b.get_value_msat()) + // Reverse ordering for fees, so we drop higher-fee paths first + .then_with(|| b.get_total_fee_paid_msat().saturating_add(b.get_path_penalty_msat()) + .cmp(&a.get_total_fee_paid_msat().saturating_add(a.get_path_penalty_msat()))) + }); + // We should make sure that at least 1 path left. let mut paths_left = cur_route.len(); cur_route.retain(|path| { -- 2.30.2