X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Frouting%2Fscorer.rs;fp=lightning%2Fsrc%2Frouting%2Fscorer.rs;h=d2b167675e083748c813a65928a3b71859292b4b;hb=c34ab42961b9c602adf4235742e4b5b54f3de717;hp=01481f16c297f9218aa05082b3fbdc7f22e886ba;hpb=7a8954e1ca318eecb4a03aa5f729d9d0ee9a904e;p=rust-lightning diff --git a/lightning/src/routing/scorer.rs b/lightning/src/routing/scorer.rs index 01481f16..d2b16767 100644 --- a/lightning/src/routing/scorer.rs +++ b/lightning/src/routing/scorer.rs @@ -19,7 +19,7 @@ //! # //! # use lightning::routing::network_graph::NetworkGraph; //! # use lightning::routing::router::{RouteParameters, find_route}; -//! # use lightning::routing::scorer::Scorer; +//! # use lightning::routing::scorer::{Scorer, ScoringParameters}; //! # use lightning::util::logger::{Logger, Record}; //! # use secp256k1::key::PublicKey; //! # @@ -30,11 +30,15 @@ //! # fn find_scored_route(payer: PublicKey, params: RouteParameters, network_graph: NetworkGraph) { //! # let logger = FakeLogger {}; //! # -//! // Use the default channel penalty. +//! // Use the default channel penalties. //! let scorer = Scorer::default(); //! -//! // Or use a custom channel penalty. -//! let scorer = Scorer::new(1_000); +//! // Or use custom channel penalties. +//! let scorer = Scorer::new(ScoringParameters { +//! base_penalty_msat: 1000, +//! failure_penalty_msat: 2 * 1024 * 1000, +//! ..ScoringParameters::default() +//! }); //! //! let route = find_route(&payer, ¶ms, &network_graph, None, &logger, &scorer); //! # } @@ -48,39 +52,130 @@ use routing::network_graph::NodeId; use routing::router::RouteHop; use prelude::*; +#[cfg(not(feature = "no-std"))] +use core::time::Duration; +#[cfg(not(feature = "no-std"))] +use std::time::Instant; /// [`routing::Score`] implementation that provides reasonable default behavior. /// /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with -/// slightly higher fees are available. +/// slightly higher fees are available. May also further penalize failed channels. /// /// See [module-level documentation] for usage. /// /// [module-level documentation]: crate::routing::scorer pub struct Scorer { - base_penalty_msat: u64, + params: ScoringParameters, + #[cfg(not(feature = "no-std"))] + channel_failures: HashMap, + #[cfg(feature = "no-std")] + channel_failures: HashMap, +} + +/// Parameters for configuring [`Scorer`]. +pub struct ScoringParameters { + /// A fixed penalty in msats to apply to each channel. + pub base_penalty_msat: u64, + + /// A penalty in msats to apply to a channel upon failure. + /// + /// This may be reduced over time based on [`failure_penalty_half_life`]. + /// + /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life + pub failure_penalty_msat: u64, + + /// The time needed before any accumulated channel failure penalties are cut in half. + #[cfg(not(feature = "no-std"))] + pub failure_penalty_half_life: Duration, } impl Scorer { - /// Creates a new scorer using `base_penalty_msat` as the channel penalty. - pub fn new(base_penalty_msat: u64) -> Self { - Self { base_penalty_msat } + /// Creates a new scorer using the given scoring parameters. + pub fn new(params: ScoringParameters) -> Self { + Self { + params, + channel_failures: HashMap::new(), + } + } + + /// Creates a new scorer using `penalty_msat` as a fixed channel penalty. + #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))] + pub fn with_fixed_penalty(penalty_msat: u64) -> Self { + Self::new(ScoringParameters { + base_penalty_msat: penalty_msat, + failure_penalty_msat: 0, + #[cfg(not(feature = "no-std"))] + failure_penalty_half_life: Duration::from_secs(0), + }) + } + + #[cfg(not(feature = "no-std"))] + fn decay_from(&self, penalty_msat: u64, last_failure: &Instant) -> u64 { + decay_from(penalty_msat, last_failure, self.params.failure_penalty_half_life) } } impl Default for Scorer { - /// Creates a new scorer using 500 msat as the channel penalty. fn default() -> Self { - Scorer::new(500) + Scorer::new(ScoringParameters::default()) + } +} + +impl Default for ScoringParameters { + fn default() -> Self { + Self { + base_penalty_msat: 500, + failure_penalty_msat: 1024 * 1000, + #[cfg(not(feature = "no-std"))] + failure_penalty_half_life: Duration::from_secs(3600), + } } } impl routing::Score for Scorer { fn channel_penalty_msat( - &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId + &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId ) -> u64 { - self.base_penalty_msat + #[cfg(not(feature = "no-std"))] + let failure_penalty_msat = match self.channel_failures.get(&short_channel_id) { + Some((penalty_msat, last_failure)) => self.decay_from(*penalty_msat, last_failure), + None => 0, + }; + #[cfg(feature = "no-std")] + let failure_penalty_msat = + self.channel_failures.get(&short_channel_id).copied().unwrap_or(0); + + self.params.base_penalty_msat + failure_penalty_msat } - fn payment_path_failed(&mut self, _path: &Vec, _short_channel_id: u64) {} + fn payment_path_failed(&mut self, _path: &Vec, short_channel_id: u64) { + let failure_penalty_msat = self.params.failure_penalty_msat; + #[cfg(not(feature = "no-std"))] + { + let half_life = self.params.failure_penalty_half_life; + self.channel_failures + .entry(short_channel_id) + .and_modify(|(penalty_msat, last_failure)| { + let decayed_penalty = decay_from(*penalty_msat, last_failure, half_life); + *penalty_msat = decayed_penalty + failure_penalty_msat; + *last_failure = Instant::now(); + }) + .or_insert_with(|| (failure_penalty_msat, Instant::now())); + } + #[cfg(feature = "no-std")] + self.channel_failures + .entry(short_channel_id) + .and_modify(|penalty_msat| *penalty_msat += failure_penalty_msat) + .or_insert(failure_penalty_msat); + } +} + +#[cfg(not(feature = "no-std"))] +fn decay_from(penalty_msat: u64, last_failure: &Instant, half_life: Duration) -> u64 { + let decays = last_failure.elapsed().as_secs().checked_div(half_life.as_secs()); + match decays { + Some(decays) => penalty_msat >> decays, + None => 0, + } }