From: Elias Rohrer Date: Sat, 18 Jun 2022 14:24:37 +0000 (+0200) Subject: Allow nodes to be avoided during pathfinding X-Git-Tag: v0.0.109~5^2 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=57d8257a0b8e7ace18d9260810285314902d11aa;p=rust-lightning Allow nodes to be avoided during pathfinding Users may want to - for whatever reasons - prevent payments to be routed over certain nodes. This change therefore allows to add `NodeId`s to a list of banned nodes, which then will be avoided during path finding. --- diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 634282d6..26e00319 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -1878,7 +1878,7 @@ mod tests { use routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE}; - use routing::scoring::{ChannelUsage, Score}; + use routing::scoring::{ChannelUsage, Score, ProbabilisticScorer, ProbabilisticScoringParameters}; use chain::transaction::OutPoint; use chain::keysinterface::KeysInterface; use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures}; @@ -5713,6 +5713,33 @@ mod tests { } } } + + #[test] + fn avoids_banned_nodes() { + let (secp_ctx, network_graph, _, _, logger) = build_line_graph(); + let (_, our_id, _, nodes) = get_nodes(&secp_ctx); + + let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); + let random_seed_bytes = keys_manager.get_secure_random_bytes(); + + let scorer_params = ProbabilisticScoringParameters::default(); + let mut scorer = ProbabilisticScorer::new(scorer_params, Arc::clone(&network_graph), Arc::clone(&logger)); + + // First check we can get a route. + let payment_params = PaymentParameters::from_node_id(nodes[10]); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes); + assert!(route.is_ok()); + + // Then check that we can't get a route if we ban an intermediate node. + scorer.add_banned(&NodeId::from_pubkey(&nodes[3])); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes); + assert!(route.is_err()); + + // Finally make sure we can route again, when we remove the ban. + scorer.remove_banned(&NodeId::from_pubkey(&nodes[3])); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, 42, Arc::clone(&logger), &scorer, &random_seed_bytes); + assert!(route.is_ok()); + } } #[cfg(all(test, not(feature = "no-std")))] diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index f7fe5863..72c27fa2 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -306,7 +306,7 @@ where L::Target: Logger { /// /// Used to configure base, liquidity, and amount penalties, the sum of which comprises the channel /// penalty (i.e., the amount in msats willing to be paid to avoid routing through the channel). -#[derive(Clone, Copy)] +#[derive(Clone)] pub struct ProbabilisticScoringParameters { /// A fixed penalty in msats to apply to each channel. /// @@ -361,6 +361,11 @@ pub struct ProbabilisticScoringParameters { /// /// Default value: 256 msat pub amount_penalty_multiplier_msat: u64, + + /// A list of nodes that won't be considered during path finding. + /// + /// (C-not exported) + pub banned_nodes: HashSet, } /// Accounting for channel liquidity balance uncertainty. @@ -451,6 +456,22 @@ impl>, L: Deref, T: Time> ProbabilisticScorerU } None } + + /// Marks the node with the given `node_id` as banned, i.e., + /// it will be avoided during path finding. + pub fn add_banned(&mut self, node_id: &NodeId) { + self.params.banned_nodes.insert(*node_id); + } + + /// Removes the node with the given `node_id` from the list of nodes to avoid. + pub fn remove_banned(&mut self, node_id: &NodeId) { + self.params.banned_nodes.remove(node_id); + } + + /// Clears the list of nodes that are avoided during path finding. + pub fn clear_banned(&mut self) { + self.params.banned_nodes = HashSet::new(); + } } impl ProbabilisticScoringParameters { @@ -461,6 +482,15 @@ impl ProbabilisticScoringParameters { liquidity_penalty_multiplier_msat: 0, liquidity_offset_half_life: Duration::from_secs(3600), amount_penalty_multiplier_msat: 0, + banned_nodes: HashSet::new(), + } + } + + /// Marks all nodes in the given list as banned, i.e., + /// they will be avoided during path finding. + pub fn add_banned_from_list(&mut self, node_ids: Vec) { + for id in node_ids { + self.banned_nodes.insert(id); } } } @@ -472,6 +502,7 @@ impl Default for ProbabilisticScoringParameters { liquidity_penalty_multiplier_msat: 40_000, liquidity_offset_half_life: Duration::from_secs(3600), amount_penalty_multiplier_msat: 256, + banned_nodes: HashSet::new(), } } } @@ -543,7 +574,7 @@ const AMOUNT_PENALTY_DIVISOR: u64 = 1 << 20; impl, T: Time, U: Deref> DirectedChannelLiquidity { /// Returns a penalty for routing the given HTLC `amount_msat` through the channel in this /// direction. - fn penalty_msat(&self, amount_msat: u64, params: ProbabilisticScoringParameters) -> u64 { + fn penalty_msat(&self, amount_msat: u64, params: &ProbabilisticScoringParameters) -> u64 { let max_liquidity_msat = self.max_liquidity_msat(); let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat); if amount_msat <= min_liquidity_msat { @@ -580,7 +611,7 @@ impl, T: Time, U: Deref> DirectedChannelLiqui #[inline(always)] fn combined_penalty_msat( &self, amount_msat: u64, negative_log10_times_2048: u64, - params: ProbabilisticScoringParameters + params: &ProbabilisticScoringParameters ) -> u64 { let liquidity_penalty_msat = { // Upper bound the liquidity penalty to ensure some channel is selected. @@ -672,6 +703,10 @@ impl>, L: Deref, T: Time> Score for Probabilis fn channel_penalty_msat( &self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage ) -> u64 { + if self.params.banned_nodes.contains(source) || self.params.banned_nodes.contains(target) { + return u64::max_value(); + } + if let EffectiveCapacity::ExactLiquidity { liquidity_msat } = usage.effective_capacity { if usage.amount_msat > liquidity_msat { return u64::max_value(); @@ -688,7 +723,7 @@ impl>, L: Deref, T: Time> Score for Probabilis .get(&short_channel_id) .unwrap_or(&ChannelLiquidity::new()) .as_directed(source, target, capacity_msat, liquidity_offset_half_life) - .penalty_msat(amount_msat, self.params) + .penalty_msat(amount_msat, &self.params) } fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) { @@ -1072,7 +1107,7 @@ impl>, L: Deref, T: Time> Writeable for Probab #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { write_tlv_fields!(w, { - (0, self.channel_liquidities, required) + (0, self.channel_liquidities, required), }); Ok(()) } @@ -1087,7 +1122,7 @@ ReadableArgs<(ProbabilisticScoringParameters, G, L)> for ProbabilisticScorerUsin let (params, network_graph, logger) = args; let mut channel_liquidities = HashMap::new(); read_tlv_fields!(r, { - (0, channel_liquidities, required) + (0, channel_liquidities, required), }); Ok(Self { params, @@ -1862,7 +1897,7 @@ mod tests { liquidity_offset_half_life: Duration::from_secs(10), ..ProbabilisticScoringParameters::zero_penalty() }; - let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger); + let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger); let source = source_node_id(); let target = target_node_id(); let usage = ChannelUsage { @@ -1898,7 +1933,7 @@ mod tests { liquidity_offset_half_life: Duration::from_secs(10), ..ProbabilisticScoringParameters::zero_penalty() }; - let mut scorer = ProbabilisticScorer::new(params, &network_graph, &logger); + let mut scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger); let source = source_node_id(); let target = target_node_id(); let usage = ChannelUsage { @@ -2086,7 +2121,7 @@ mod tests { let logger = TestLogger::new(); let network_graph = network_graph(&logger); let params = ProbabilisticScoringParameters::default(); - let scorer = ProbabilisticScorer::new(params, &network_graph, &logger); + let scorer = ProbabilisticScorer::new(params.clone(), &network_graph, &logger); let source = source_node_id(); let target = target_node_id();