From: Matt Corallo Date: Fri, 12 Nov 2021 15:52:59 +0000 (+0000) Subject: Move `Score` into a `scoring` module instead of a top-level module X-Git-Tag: v0.0.104~34^2 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=42ebf774155632b5656fd5820eb8c28d0003d9b6;p=rust-lightning Move `Score` into a `scoring` module instead of a top-level module Traits in top-level modules is somewhat confusing - generally top-level modules are just organizational modules and don't contain things themselves, instead placing traits and structs in sub-modules. Further, its incredibly awkward to have a `scorer` sub-module, but only have a single struct in it, with the relevant trait it is the only implementation of somewhere else. Not having `Score` in the `scorer` sub-module is further confusing because it's the only module anywhere that references scoring at all. --- diff --git a/fuzz/src/full_stack.rs b/fuzz/src/full_stack.rs index 829ef20b0..81408b85b 100644 --- a/fuzz/src/full_stack.rs +++ b/fuzz/src/full_stack.rs @@ -39,7 +39,7 @@ use lightning::ln::msgs::DecodeError; use lightning::ln::script::ShutdownScript; use lightning::routing::network_graph::{NetGraphMsgHandler, NetworkGraph}; use lightning::routing::router::{find_route, Payee, RouteParameters}; -use lightning::routing::scorer::Scorer; +use lightning::routing::scoring::Scorer; use lightning::util::config::UserConfig; use lightning::util::errors::APIError; use lightning::util::events::Event; diff --git a/fuzz/src/router.rs b/fuzz/src/router.rs index 8c9b4b7d8..149d40134 100644 --- a/fuzz/src/router.rs +++ b/fuzz/src/router.rs @@ -17,7 +17,7 @@ use lightning::ln::channelmanager::{ChannelDetails, ChannelCounterparty}; use lightning::ln::features::InitFeatures; use lightning::ln::msgs; use lightning::routing::router::{find_route, Payee, RouteHint, RouteHintHop, RouteParameters}; -use lightning::routing::scorer::Scorer; +use lightning::routing::scoring::Scorer; use lightning::util::logger::Logger; use lightning::util::ser::Readable; use lightning::routing::network_graph::{NetworkGraph, RoutingFees}; diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 4099afbaa..a480d40e9 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -31,7 +31,7 @@ //! # use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret}; //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; //! # use lightning::ln::msgs::LightningError; -//! # use lightning::routing; +//! # use lightning::routing::scoring::Score; //! # use lightning::routing::network_graph::NodeId; //! # use lightning::routing::router::{Route, RouteHop, RouteParameters}; //! # use lightning::util::events::{Event, EventHandler, EventsProvider}; @@ -63,7 +63,7 @@ //! # } //! # //! # struct FakeRouter {}; -//! # impl Router for FakeRouter { +//! # impl Router for FakeRouter { //! # fn find_route( //! # &self, payer: &PublicKey, params: &RouteParameters, payment_hash: &PaymentHash, //! # first_hops: Option<&[&ChannelDetails]>, scorer: &S @@ -71,7 +71,7 @@ //! # } //! # //! # struct FakeScorer {}; -//! # impl routing::Score for FakeScorer { +//! # impl Score for FakeScorer { //! # fn channel_penalty_msat( //! # &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, _target: &NodeId //! # ) -> u64 { 0 } @@ -122,8 +122,7 @@ use bitcoin_hashes::sha256::Hash as Sha256; use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; use lightning::ln::msgs::LightningError; -use lightning::routing; -use lightning::routing::{LockableScore, Score}; +use lightning::routing::scoring::{LockableScore, Score}; use lightning::routing::router::{Payee, Route, RouteParameters}; use lightning::util::events::{Event, EventHandler}; use lightning::util::logger::Logger; @@ -139,8 +138,8 @@ use std::time::{Duration, SystemTime}; pub struct InvoicePayer where P::Target: Payer, - R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, - S::Target: for <'a> routing::LockableScore<'a>, + R: for <'a> Router<<::Target as LockableScore<'a>>::Locked>, + S::Target: for <'a> LockableScore<'a>, L::Target: Logger, E: EventHandler, { @@ -177,7 +176,7 @@ pub trait Payer { } /// A trait defining behavior for routing an [`Invoice`] payment. -pub trait Router { +pub trait Router { /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. fn find_route( &self, payer: &PublicKey, params: &RouteParameters, payment_hash: &PaymentHash, @@ -207,8 +206,8 @@ pub enum PaymentError { impl InvoicePayer where P::Target: Payer, - R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, - S::Target: for <'a> routing::LockableScore<'a>, + R: for <'a> Router<<::Target as LockableScore<'a>>::Locked>, + S::Target: for <'a> LockableScore<'a>, L::Target: Logger, E: EventHandler, { @@ -441,8 +440,8 @@ fn has_expired(params: &RouteParameters) -> bool { impl EventHandler for InvoicePayer where P::Target: Payer, - R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, - S::Target: for <'a> routing::LockableScore<'a>, + R: for <'a> Router<<::Target as LockableScore<'a>>::Locked>, + S::Target: for <'a> LockableScore<'a>, L::Target: Logger, E: EventHandler, { @@ -1186,7 +1185,7 @@ mod tests { } } - impl Router for TestRouter { + impl Router for TestRouter { fn find_route( &self, _payer: &PublicKey, params: &RouteParameters, _payment_hash: &PaymentHash, _first_hops: Option<&[&ChannelDetails]>, _scorer: &S @@ -1199,7 +1198,7 @@ mod tests { struct FailingRouter; - impl Router for FailingRouter { + impl Router for FailingRouter { fn find_route( &self, _payer: &PublicKey, _params: &RouteParameters, _payment_hash: &PaymentHash, _first_hops: Option<&[&ChannelDetails]>, _scorer: &S @@ -1225,7 +1224,7 @@ mod tests { } } - impl routing::Score for TestScorer { + impl Score for TestScorer { fn channel_penalty_msat( &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, _target: &NodeId ) -> u64 { 0 } @@ -1364,7 +1363,7 @@ mod tests { // *** Full Featured Functional Tests with a Real ChannelManager *** struct ManualRouter(RefCell>>); - impl Router for ManualRouter { + impl Router for ManualRouter { fn find_route( &self, _payer: &PublicKey, _params: &RouteParameters, _payment_hash: &PaymentHash, _first_hops: Option<&[&ChannelDetails]>, _scorer: &S diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index ad94b0804..b7fdb73f2 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -11,7 +11,7 @@ use lightning::chain::keysinterface::{Sign, KeysInterface}; use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY}; use lightning::ln::msgs::LightningError; -use lightning::routing; +use lightning::routing::scoring::Score; use lightning::routing::network_graph::{NetworkGraph, RoutingFees}; use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route}; use lightning::util::logger::Logger; @@ -109,7 +109,7 @@ impl DefaultRouter where G: Deref, L:: } } -impl Router for DefaultRouter +impl Router for DefaultRouter where G: Deref, L::Target: Logger { fn find_route( &self, payer: &PublicKey, params: &RouteParameters, _payment_hash: &PaymentHash, diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 631edfe6c..03522e24c 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -6552,7 +6552,7 @@ pub mod bench { use ln::msgs::{ChannelMessageHandler, Init}; use routing::network_graph::NetworkGraph; use routing::router::{Payee, get_route}; - use routing::scorer::Scorer; + use routing::scoring::Scorer; use util::test_utils; use util::config::UserConfig; use util::events::{Event, MessageSendEvent, MessageSendEventsProvider, PaymentPurpose}; diff --git a/lightning/src/routing/mod.rs b/lightning/src/routing/mod.rs index 91478bafc..a3ab6c0c1 100644 --- a/lightning/src/routing/mod.rs +++ b/lightning/src/routing/mod.rs @@ -11,75 +11,4 @@ pub mod network_graph; pub mod router; -pub mod scorer; - -use routing::network_graph::NodeId; -use routing::router::RouteHop; - -use core::cell::{RefCell, RefMut}; -use core::ops::DerefMut; -use sync::{Mutex, MutexGuard}; - -/// An interface used to score payment channels for path finding. -/// -/// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel. -pub trait Score { - /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the - /// given channel in the direction from `source` to `target`. - /// - /// The channel's capacity (less any other MPP parts which are also being considered for use in - /// the same payment) is given by `channel_capacity_msat`. It may be guessed from various - /// sources or assumed from no data at all. - /// - /// For hints provided in the invoice, we assume the channel has sufficient capacity to accept - /// the invoice's full amount, and provide a `channel_capacity_msat` of `None`. In all other - /// cases it is set to `Some`, even if we're guessing at the channel value. - /// - /// Your code should be overflow-safe through a `channel_capacity_msat` of 21 million BTC. - fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option, source: &NodeId, target: &NodeId) -> u64; - - /// Handles updating channel penalties after failing to route through a channel. - fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64); -} - -/// A scorer that is accessed under a lock. -/// -/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while -/// having shared ownership of a scorer but without requiring internal locking in [`Score`] -/// implementations. Internal locking would be detrimental to route finding performance and could -/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel. -/// -/// [`find_route`]: crate::routing::router::find_route -pub trait LockableScore<'a> { - /// The locked [`Score`] type. - type Locked: 'a + Score; - - /// Returns the locked scorer. - fn lock(&'a self) -> Self::Locked; -} - -impl<'a, T: 'a + Score> LockableScore<'a> for Mutex { - type Locked = MutexGuard<'a, T>; - - fn lock(&'a self) -> MutexGuard<'a, T> { - Mutex::lock(self).unwrap() - } -} - -impl<'a, T: 'a + Score> LockableScore<'a> for RefCell { - type Locked = RefMut<'a, T>; - - fn lock(&'a self) -> RefMut<'a, T> { - self.borrow_mut() - } -} - -impl> Score for T { - fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option, source: &NodeId, target: &NodeId) -> u64 { - self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, channel_capacity_msat, source, target) - } - - fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) { - self.deref_mut().payment_path_failed(path, short_channel_id) - } -} +pub mod scoring; diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index a98fb9912..90c01ec0d 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -17,7 +17,7 @@ use bitcoin::secp256k1::key::PublicKey; use ln::channelmanager::ChannelDetails; use ln::features::{ChannelFeatures, InvoiceFeatures, NodeFeatures}; use ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT}; -use routing; +use routing::scoring::Score; use routing::network_graph::{NetworkGraph, NodeId, RoutingFees}; use util::ser::{Writeable, Readable}; use util::logger::{Level, Logger}; @@ -529,7 +529,7 @@ fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option { /// /// [`ChannelManager::list_usable_channels`]: crate::ln::channelmanager::ChannelManager::list_usable_channels /// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed -pub fn find_route( +pub fn find_route( our_node_pubkey: &PublicKey, params: &RouteParameters, network: &NetworkGraph, first_hops: Option<&[&ChannelDetails]>, logger: L, scorer: &S ) -> Result @@ -540,7 +540,7 @@ where L::Target: Logger { ) } -pub(crate) fn get_route( +pub(crate) fn get_route( our_node_pubkey: &PublicKey, payee: &Payee, network: &NetworkGraph, first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, final_cltv_expiry_delta: u32, logger: L, scorer: &S @@ -1472,7 +1472,7 @@ where L::Target: Logger { #[cfg(test)] mod tests { - use routing; + use routing::scoring::Score; use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId}; use routing::router::{get_route, Payee, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees}; use chain::transaction::OutPoint; @@ -4549,7 +4549,7 @@ mod tests { short_channel_id: u64, } - impl routing::Score for BadChannelScorer { + impl Score for BadChannelScorer { fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, _target: &NodeId) -> u64 { if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 } } @@ -4561,7 +4561,7 @@ mod tests { node_id: NodeId, } - impl routing::Score for BadNodeScorer { + impl Score for BadNodeScorer { fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, target: &NodeId) -> u64 { if *target == self.node_id { u64::max_value() } else { 0 } } @@ -4787,7 +4787,7 @@ pub(crate) mod test_utils { #[cfg(all(test, feature = "unstable", not(feature = "no-std")))] mod benches { use super::*; - use routing::scorer::Scorer; + use routing::scoring::Scorer; use util::logger::{Logger, Record}; use test::Bencher; diff --git a/lightning/src/routing/scorer.rs b/lightning/src/routing/scorer.rs deleted file mode 100644 index 382eae25b..000000000 --- a/lightning/src/routing/scorer.rs +++ /dev/null @@ -1,624 +0,0 @@ -// This file is Copyright its original authors, visible in version control -// history. -// -// This file is licensed under the Apache License, Version 2.0 or the MIT license -// , at your option. -// You may not use this file except in accordance with one or both of these -// licenses. - -//! Utilities for scoring payment channels. -//! -//! [`Scorer`] may be given to [`find_route`] to score payment channels during path finding when a -//! custom [`routing::Score`] implementation is not needed. -//! -//! # Example -//! -//! ``` -//! # extern crate secp256k1; -//! # -//! # use lightning::routing::network_graph::NetworkGraph; -//! # use lightning::routing::router::{RouteParameters, find_route}; -//! # use lightning::routing::scorer::{Scorer, ScoringParameters}; -//! # use lightning::util::logger::{Logger, Record}; -//! # use secp256k1::key::PublicKey; -//! # -//! # struct FakeLogger {}; -//! # impl Logger for FakeLogger { -//! # fn log(&self, record: &Record) { unimplemented!() } -//! # } -//! # fn find_scored_route(payer: PublicKey, params: RouteParameters, network_graph: NetworkGraph) { -//! # let logger = FakeLogger {}; -//! # -//! // Use the default channel penalties. -//! let scorer = Scorer::default(); -//! -//! // Or use custom channel penalties. -//! let scorer = Scorer::new(ScoringParameters { -//! base_penalty_msat: 1000, -//! failure_penalty_msat: 2 * 1024 * 1000, -//! ..ScoringParameters::default() -//! }); -//! -//! let route = find_route(&payer, ¶ms, &network_graph, None, &logger, &scorer); -//! # } -//! ``` -//! -//! # Note -//! -//! If persisting [`Scorer`], it must be restored using the same [`Time`] parameterization. Using a -//! different type results in undefined behavior. Specifically, persisting when built with feature -//! `no-std` and restoring without it, or vice versa, uses different types and thus is undefined. -//! -//! [`find_route`]: crate::routing::router::find_route - -use routing; - -use ln::msgs::DecodeError; -use routing::network_graph::NodeId; -use routing::router::RouteHop; -use util::ser::{Readable, Writeable, Writer}; - -use prelude::*; -use core::ops::Sub; -use core::time::Duration; -use io::{self, Read}; - -/// [`routing::Score`] implementation that provides reasonable default behavior. -/// -/// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with -/// slightly higher fees are available. Will further penalize channels that fail to relay payments. -/// -/// See [module-level documentation] for usage. -/// -/// [module-level documentation]: crate::routing::scorer -pub type Scorer = ScorerUsingTime::; - -/// Time used by [`Scorer`]. -#[cfg(not(feature = "no-std"))] -pub type DefaultTime = std::time::Instant; - -/// Time used by [`Scorer`]. -#[cfg(feature = "no-std")] -pub type DefaultTime = Eternity; - -/// [`routing::Score`] implementation parameterized by [`Time`]. -/// -/// See [`Scorer`] for details. -/// -/// # Note -/// -/// Mixing [`Time`] types between serialization and deserialization results in undefined behavior. -pub struct ScorerUsingTime { - params: ScoringParameters, - // TODO: Remove entries of closed channels. - channel_failures: HashMap>, -} - -/// Parameters for configuring [`Scorer`]. -pub struct ScoringParameters { - /// A fixed penalty in msats to apply to each channel. - /// - /// Default value: 500 msat - pub base_penalty_msat: u64, - - /// A penalty in msats to apply to a channel upon failing to relay a payment. - /// - /// This accumulates for each failure but may be reduced over time based on - /// [`failure_penalty_half_life`]. - /// - /// Default value: 1,024,000 msat - /// - /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life - pub failure_penalty_msat: u64, - - /// When the amount being sent over a channel is this many 1024ths of the total channel - /// capacity, we begin applying [`overuse_penalty_msat_per_1024th`]. - /// - /// Default value: 128 1024ths (i.e. begin penalizing when an HTLC uses 1/8th of a channel) - /// - /// [`overuse_penalty_msat_per_1024th`]: Self::overuse_penalty_msat_per_1024th - pub overuse_penalty_start_1024th: u16, - - /// A penalty applied, per whole 1024ths of the channel capacity which the amount being sent - /// over the channel exceeds [`overuse_penalty_start_1024th`] by. - /// - /// Default value: 20 msat (i.e. 2560 msat penalty to use 1/4th of a channel, 7680 msat penalty - /// to use half a channel, and 12,560 msat penalty to use 3/4ths of a channel) - /// - /// [`overuse_penalty_start_1024th`]: Self::overuse_penalty_start_1024th - pub overuse_penalty_msat_per_1024th: u64, - - /// The time required to elapse before any accumulated [`failure_penalty_msat`] penalties are - /// cut in half. - /// - /// # Note - /// - /// When time is an [`Eternity`], as is default when enabling feature `no-std`, it will never - /// elapse. Therefore, this penalty will never decay. - /// - /// [`failure_penalty_msat`]: Self::failure_penalty_msat - pub failure_penalty_half_life: Duration, -} - -impl_writeable_tlv_based!(ScoringParameters, { - (0, base_penalty_msat, required), - (1, overuse_penalty_start_1024th, (default_value, 128)), - (2, failure_penalty_msat, required), - (3, overuse_penalty_msat_per_1024th, (default_value, 20)), - (4, failure_penalty_half_life, required), -}); - -/// Accounting for penalties against a channel for failing to relay any payments. -/// -/// Penalties decay over time, though accumulate as more failures occur. -struct ChannelFailure { - /// Accumulated penalty in msats for the channel as of `last_failed`. - undecayed_penalty_msat: u64, - - /// Last time the channel failed. Used to decay `undecayed_penalty_msat`. - last_failed: T, -} - -/// A measurement of time. -pub trait Time: Sub where Self: Sized { - /// Returns an instance corresponding to the current moment. - fn now() -> Self; - - /// Returns the amount of time elapsed since `self` was created. - fn elapsed(&self) -> Duration; - - /// Returns the amount of time passed since the beginning of [`Time`]. - /// - /// Used during (de-)serialization. - fn duration_since_epoch() -> Duration; -} - -impl ScorerUsingTime { - /// Creates a new scorer using the given scoring parameters. - pub fn new(params: ScoringParameters) -> Self { - Self { - params, - channel_failures: HashMap::new(), - } - } - - /// Creates a new scorer using `penalty_msat` as a fixed channel penalty. - #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))] - pub fn with_fixed_penalty(penalty_msat: u64) -> Self { - Self::new(ScoringParameters { - base_penalty_msat: penalty_msat, - failure_penalty_msat: 0, - failure_penalty_half_life: Duration::from_secs(0), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }) - } -} - -impl ChannelFailure { - fn new(failure_penalty_msat: u64) -> Self { - Self { - undecayed_penalty_msat: failure_penalty_msat, - last_failed: T::now(), - } - } - - fn add_penalty(&mut self, failure_penalty_msat: u64, half_life: Duration) { - self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) + failure_penalty_msat; - self.last_failed = T::now(); - } - - fn decayed_penalty_msat(&self, half_life: Duration) -> u64 { - let decays = self.last_failed.elapsed().as_secs().checked_div(half_life.as_secs()); - match decays { - Some(decays) => self.undecayed_penalty_msat >> decays, - None => 0, - } - } -} - -impl Default for ScorerUsingTime { - fn default() -> Self { - Self::new(ScoringParameters::default()) - } -} - -impl Default for ScoringParameters { - fn default() -> Self { - Self { - base_penalty_msat: 500, - failure_penalty_msat: 1024 * 1000, - failure_penalty_half_life: Duration::from_secs(3600), - overuse_penalty_start_1024th: 1024 / 8, - overuse_penalty_msat_per_1024th: 20, - } - } -} - -impl routing::Score for ScorerUsingTime { - fn channel_penalty_msat( - &self, short_channel_id: u64, send_amt_msat: u64, chan_capacity_opt: Option, _source: &NodeId, _target: &NodeId - ) -> u64 { - let failure_penalty_msat = self.channel_failures - .get(&short_channel_id) - .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life)); - - let mut penalty_msat = self.params.base_penalty_msat + failure_penalty_msat; - - if let Some(chan_capacity_msat) = chan_capacity_opt { - let send_1024ths = send_amt_msat.checked_mul(1024).unwrap_or(u64::max_value()) / chan_capacity_msat; - - if send_1024ths > self.params.overuse_penalty_start_1024th as u64 { - penalty_msat = penalty_msat.checked_add( - (send_1024ths - self.params.overuse_penalty_start_1024th as u64) - .checked_mul(self.params.overuse_penalty_msat_per_1024th).unwrap_or(u64::max_value())) - .unwrap_or(u64::max_value()); - } - } - - penalty_msat - } - - fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) { - let failure_penalty_msat = self.params.failure_penalty_msat; - let half_life = self.params.failure_penalty_half_life; - self.channel_failures - .entry(short_channel_id) - .and_modify(|failure| failure.add_penalty(failure_penalty_msat, half_life)) - .or_insert_with(|| ChannelFailure::new(failure_penalty_msat)); - } -} - -#[cfg(not(feature = "no-std"))] -impl Time for std::time::Instant { - fn now() -> Self { - std::time::Instant::now() - } - - fn duration_since_epoch() -> Duration { - use std::time::SystemTime; - SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap() - } - - fn elapsed(&self) -> Duration { - std::time::Instant::elapsed(self) - } -} - -/// A state in which time has no meaning. -#[derive(Debug, PartialEq, Eq)] -pub struct Eternity; - -impl Time for Eternity { - fn now() -> Self { - Self - } - - fn duration_since_epoch() -> Duration { - Duration::from_secs(0) - } - - fn elapsed(&self) -> Duration { - Duration::from_secs(0) - } -} - -impl Sub for Eternity { - type Output = Self; - - fn sub(self, _other: Duration) -> Self { - self - } -} - -impl Writeable for ScorerUsingTime { - #[inline] - fn write(&self, w: &mut W) -> Result<(), io::Error> { - self.params.write(w)?; - self.channel_failures.write(w)?; - write_tlv_fields!(w, {}); - Ok(()) - } -} - -impl Readable for ScorerUsingTime { - #[inline] - fn read(r: &mut R) -> Result { - let res = Ok(Self { - params: Readable::read(r)?, - channel_failures: Readable::read(r)?, - }); - read_tlv_fields!(r, {}); - res - } -} - -impl Writeable for ChannelFailure { - #[inline] - fn write(&self, w: &mut W) -> Result<(), io::Error> { - let duration_since_epoch = T::duration_since_epoch() - self.last_failed.elapsed(); - write_tlv_fields!(w, { - (0, self.undecayed_penalty_msat, required), - (2, duration_since_epoch, required), - }); - Ok(()) - } -} - -impl Readable for ChannelFailure { - #[inline] - fn read(r: &mut R) -> Result { - let mut undecayed_penalty_msat = 0; - let mut duration_since_epoch = Duration::from_secs(0); - read_tlv_fields!(r, { - (0, undecayed_penalty_msat, required), - (2, duration_since_epoch, required), - }); - Ok(Self { - undecayed_penalty_msat, - last_failed: T::now() - (T::duration_since_epoch() - duration_since_epoch), - }) - } -} - -#[cfg(test)] -mod tests { - use super::{Eternity, ScoringParameters, ScorerUsingTime, Time}; - - use routing::Score; - use routing::network_graph::NodeId; - use util::ser::{Readable, Writeable}; - - use bitcoin::secp256k1::PublicKey; - use core::cell::Cell; - use core::ops::Sub; - use core::time::Duration; - use io; - - /// Time that can be advanced manually in tests. - #[derive(Debug, PartialEq, Eq)] - struct SinceEpoch(Duration); - - impl SinceEpoch { - thread_local! { - static ELAPSED: Cell = core::cell::Cell::new(Duration::from_secs(0)); - } - - fn advance(duration: Duration) { - Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration)) - } - } - - impl Time for SinceEpoch { - fn now() -> Self { - Self(Self::duration_since_epoch()) - } - - fn duration_since_epoch() -> Duration { - Self::ELAPSED.with(|elapsed| elapsed.get()) - } - - fn elapsed(&self) -> Duration { - Self::duration_since_epoch() - self.0 - } - } - - impl Sub for SinceEpoch { - type Output = Self; - - fn sub(self, other: Duration) -> Self { - Self(self.0 - other) - } - } - - #[test] - fn time_passes_when_advanced() { - let now = SinceEpoch::now(); - assert_eq!(now.elapsed(), Duration::from_secs(0)); - - SinceEpoch::advance(Duration::from_secs(1)); - SinceEpoch::advance(Duration::from_secs(1)); - - let elapsed = now.elapsed(); - let later = SinceEpoch::now(); - - assert_eq!(elapsed, Duration::from_secs(2)); - assert_eq!(later - elapsed, now); - } - - #[test] - fn time_never_passes_in_an_eternity() { - let now = Eternity::now(); - let elapsed = now.elapsed(); - let later = Eternity::now(); - - assert_eq!(now.elapsed(), Duration::from_secs(0)); - assert_eq!(later - elapsed, now); - } - - /// A scorer for testing with time that can be manually advanced. - type Scorer = ScorerUsingTime::; - - fn source_node_id() -> NodeId { - NodeId::from_pubkey(&PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap()) - } - - fn target_node_id() -> NodeId { - NodeId::from_pubkey(&PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap()) - } - - #[test] - fn penalizes_without_channel_failures() { - let scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 1_000, - failure_penalty_msat: 512, - failure_penalty_half_life: Duration::from_secs(1), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }); - let source = source_node_id(); - let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - - SinceEpoch::advance(Duration::from_secs(1)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - } - - #[test] - fn accumulates_channel_failure_penalties() { - let mut scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 1_000, - failure_penalty_msat: 64, - failure_penalty_half_life: Duration::from_secs(10), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }); - let source = source_node_id(); - let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_064); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_192); - } - - #[test] - fn decays_channel_failure_penalties_over_time() { - let mut scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 1_000, - failure_penalty_msat: 512, - failure_penalty_half_life: Duration::from_secs(10), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }); - let source = source_node_id(); - let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); - - SinceEpoch::advance(Duration::from_secs(9)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); - - SinceEpoch::advance(Duration::from_secs(1)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); - - SinceEpoch::advance(Duration::from_secs(10 * 8)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_001); - - SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - - SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - } - - #[test] - fn accumulates_channel_failure_penalties_after_decay() { - let mut scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 1_000, - failure_penalty_msat: 512, - failure_penalty_half_life: Duration::from_secs(10), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }); - let source = source_node_id(); - let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); - - SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_768); - - SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_384); - } - - #[test] - fn restores_persisted_channel_failure_penalties() { - let mut scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 1_000, - failure_penalty_msat: 512, - failure_penalty_half_life: Duration::from_secs(10), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }); - let source = source_node_id(); - let target = target_node_id(); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); - - SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); - - scorer.payment_path_failed(&[], 43); - assert_eq!(scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512); - - let mut serialized_scorer = Vec::new(); - scorer.write(&mut serialized_scorer).unwrap(); - - let deserialized_scorer = ::read(&mut io::Cursor::new(&serialized_scorer)).unwrap(); - assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); - assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512); - } - - #[test] - fn decays_persisted_channel_failure_penalties() { - let mut scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 1_000, - failure_penalty_msat: 512, - failure_penalty_half_life: Duration::from_secs(10), - overuse_penalty_start_1024th: 1024, - overuse_penalty_msat_per_1024th: 0, - }); - let source = source_node_id(); - let target = target_node_id(); - - scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); - - let mut serialized_scorer = Vec::new(); - scorer.write(&mut serialized_scorer).unwrap(); - - SinceEpoch::advance(Duration::from_secs(10)); - - let deserialized_scorer = ::read(&mut io::Cursor::new(&serialized_scorer)).unwrap(); - assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); - - SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128); - } - - #[test] - fn charges_per_1024th_penalty() { - let scorer = Scorer::new(ScoringParameters { - base_penalty_msat: 0, - failure_penalty_msat: 0, - failure_penalty_half_life: Duration::from_secs(0), - overuse_penalty_start_1024th: 256, - overuse_penalty_msat_per_1024th: 100, - }); - let source = source_node_id(); - let target = target_node_id(); - - assert_eq!(scorer.channel_penalty_msat(42, 1_000, None, &source, &target), 0); - assert_eq!(scorer.channel_penalty_msat(42, 1_000, Some(1_024_000), &source, &target), 0); - assert_eq!(scorer.channel_penalty_msat(42, 256_999, Some(1_024_000), &source, &target), 0); - assert_eq!(scorer.channel_penalty_msat(42, 257_000, Some(1_024_000), &source, &target), 100); - assert_eq!(scorer.channel_penalty_msat(42, 258_000, Some(1_024_000), &source, &target), 200); - assert_eq!(scorer.channel_penalty_msat(42, 512_000, Some(1_024_000), &source, &target), 256 * 100); - } -} diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs new file mode 100644 index 000000000..a2d314665 --- /dev/null +++ b/lightning/src/routing/scoring.rs @@ -0,0 +1,687 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! Utilities for scoring payment channels. +//! +//! [`Scorer`] may be given to [`find_route`] to score payment channels during path finding when a +//! custom [`Score`] implementation is not needed. +//! +//! # Example +//! +//! ``` +//! # extern crate secp256k1; +//! # +//! # use lightning::routing::network_graph::NetworkGraph; +//! # use lightning::routing::router::{RouteParameters, find_route}; +//! # use lightning::routing::scoring::{Scorer, ScoringParameters}; +//! # use lightning::util::logger::{Logger, Record}; +//! # use secp256k1::key::PublicKey; +//! # +//! # struct FakeLogger {}; +//! # impl Logger for FakeLogger { +//! # fn log(&self, record: &Record) { unimplemented!() } +//! # } +//! # fn find_scored_route(payer: PublicKey, params: RouteParameters, network_graph: NetworkGraph) { +//! # let logger = FakeLogger {}; +//! # +//! // Use the default channel penalties. +//! let scorer = Scorer::default(); +//! +//! // Or use custom channel penalties. +//! let scorer = Scorer::new(ScoringParameters { +//! base_penalty_msat: 1000, +//! failure_penalty_msat: 2 * 1024 * 1000, +//! ..ScoringParameters::default() +//! }); +//! +//! let route = find_route(&payer, ¶ms, &network_graph, None, &logger, &scorer); +//! # } +//! ``` +//! +//! # Note +//! +//! If persisting [`Scorer`], it must be restored using the same [`Time`] parameterization. Using a +//! different type results in undefined behavior. Specifically, persisting when built with feature +//! `no-std` and restoring without it, or vice versa, uses different types and thus is undefined. +//! +//! [`find_route`]: crate::routing::router::find_route + +use ln::msgs::DecodeError; +use routing::network_graph::NodeId; +use routing::router::RouteHop; +use util::ser::{Readable, Writeable, Writer}; + +use prelude::*; +use core::cell::{RefCell, RefMut}; +use core::ops::{DerefMut, Sub}; +use core::time::Duration; +use io::{self, Read}; use sync::{Mutex, MutexGuard}; + +/// An interface used to score payment channels for path finding. +/// +/// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel. +pub trait Score { + /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the + /// given channel in the direction from `source` to `target`. + /// + /// The channel's capacity (less any other MPP parts which are also being considered for use in + /// the same payment) is given by `channel_capacity_msat`. It may be guessed from various + /// sources or assumed from no data at all. + /// + /// For hints provided in the invoice, we assume the channel has sufficient capacity to accept + /// the invoice's full amount, and provide a `channel_capacity_msat` of `None`. In all other + /// cases it is set to `Some`, even if we're guessing at the channel value. + /// + /// Your code should be overflow-safe through a `channel_capacity_msat` of 21 million BTC. + fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option, source: &NodeId, target: &NodeId) -> u64; + + /// Handles updating channel penalties after failing to route through a channel. + fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64); +} + +/// A scorer that is accessed under a lock. +/// +/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while +/// having shared ownership of a scorer but without requiring internal locking in [`Score`] +/// implementations. Internal locking would be detrimental to route finding performance and could +/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel. +/// +/// [`find_route`]: crate::routing::router::find_route +pub trait LockableScore<'a> { + /// The locked [`Score`] type. + type Locked: 'a + Score; + + /// Returns the locked scorer. + fn lock(&'a self) -> Self::Locked; +} + +impl<'a, T: 'a + Score> LockableScore<'a> for Mutex { + type Locked = MutexGuard<'a, T>; + + fn lock(&'a self) -> MutexGuard<'a, T> { + Mutex::lock(self).unwrap() + } +} + +impl<'a, T: 'a + Score> LockableScore<'a> for RefCell { + type Locked = RefMut<'a, T>; + + fn lock(&'a self) -> RefMut<'a, T> { + self.borrow_mut() + } +} + +impl> Score for T { + fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option, source: &NodeId, target: &NodeId) -> u64 { + self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, channel_capacity_msat, source, target) + } + + fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) { + self.deref_mut().payment_path_failed(path, short_channel_id) + } +} + +/// [`Score`] implementation that provides reasonable default behavior. +/// +/// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with +/// slightly higher fees are available. Will further penalize channels that fail to relay payments. +/// +/// See [module-level documentation] for usage. +/// +/// [module-level documentation]: crate::routing::scoring +pub type Scorer = ScorerUsingTime::; + +/// Time used by [`Scorer`]. +#[cfg(not(feature = "no-std"))] +pub type DefaultTime = std::time::Instant; + +/// Time used by [`Scorer`]. +#[cfg(feature = "no-std")] +pub type DefaultTime = Eternity; + +/// [`Score`] implementation parameterized by [`Time`]. +/// +/// See [`Scorer`] for details. +/// +/// # Note +/// +/// Mixing [`Time`] types between serialization and deserialization results in undefined behavior. +pub struct ScorerUsingTime { + params: ScoringParameters, + // TODO: Remove entries of closed channels. + channel_failures: HashMap>, +} + +/// Parameters for configuring [`Scorer`]. +pub struct ScoringParameters { + /// A fixed penalty in msats to apply to each channel. + /// + /// Default value: 500 msat + pub base_penalty_msat: u64, + + /// A penalty in msats to apply to a channel upon failing to relay a payment. + /// + /// This accumulates for each failure but may be reduced over time based on + /// [`failure_penalty_half_life`]. + /// + /// Default value: 1,024,000 msat + /// + /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life + pub failure_penalty_msat: u64, + + /// When the amount being sent over a channel is this many 1024ths of the total channel + /// capacity, we begin applying [`overuse_penalty_msat_per_1024th`]. + /// + /// Default value: 128 1024ths (i.e. begin penalizing when an HTLC uses 1/8th of a channel) + /// + /// [`overuse_penalty_msat_per_1024th`]: Self::overuse_penalty_msat_per_1024th + pub overuse_penalty_start_1024th: u16, + + /// A penalty applied, per whole 1024ths of the channel capacity which the amount being sent + /// over the channel exceeds [`overuse_penalty_start_1024th`] by. + /// + /// Default value: 20 msat (i.e. 2560 msat penalty to use 1/4th of a channel, 7680 msat penalty + /// to use half a channel, and 12,560 msat penalty to use 3/4ths of a channel) + /// + /// [`overuse_penalty_start_1024th`]: Self::overuse_penalty_start_1024th + pub overuse_penalty_msat_per_1024th: u64, + + /// The time required to elapse before any accumulated [`failure_penalty_msat`] penalties are + /// cut in half. + /// + /// # Note + /// + /// When time is an [`Eternity`], as is default when enabling feature `no-std`, it will never + /// elapse. Therefore, this penalty will never decay. + /// + /// [`failure_penalty_msat`]: Self::failure_penalty_msat + pub failure_penalty_half_life: Duration, +} + +impl_writeable_tlv_based!(ScoringParameters, { + (0, base_penalty_msat, required), + (1, overuse_penalty_start_1024th, (default_value, 128)), + (2, failure_penalty_msat, required), + (3, overuse_penalty_msat_per_1024th, (default_value, 20)), + (4, failure_penalty_half_life, required), +}); + +/// Accounting for penalties against a channel for failing to relay any payments. +/// +/// Penalties decay over time, though accumulate as more failures occur. +struct ChannelFailure { + /// Accumulated penalty in msats for the channel as of `last_failed`. + undecayed_penalty_msat: u64, + + /// Last time the channel failed. Used to decay `undecayed_penalty_msat`. + last_failed: T, +} + +/// A measurement of time. +pub trait Time: Sub where Self: Sized { + /// Returns an instance corresponding to the current moment. + fn now() -> Self; + + /// Returns the amount of time elapsed since `self` was created. + fn elapsed(&self) -> Duration; + + /// Returns the amount of time passed since the beginning of [`Time`]. + /// + /// Used during (de-)serialization. + fn duration_since_epoch() -> Duration; +} + +impl ScorerUsingTime { + /// Creates a new scorer using the given scoring parameters. + pub fn new(params: ScoringParameters) -> Self { + Self { + params, + channel_failures: HashMap::new(), + } + } + + /// Creates a new scorer using `penalty_msat` as a fixed channel penalty. + #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))] + pub fn with_fixed_penalty(penalty_msat: u64) -> Self { + Self::new(ScoringParameters { + base_penalty_msat: penalty_msat, + failure_penalty_msat: 0, + failure_penalty_half_life: Duration::from_secs(0), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }) + } +} + +impl ChannelFailure { + fn new(failure_penalty_msat: u64) -> Self { + Self { + undecayed_penalty_msat: failure_penalty_msat, + last_failed: T::now(), + } + } + + fn add_penalty(&mut self, failure_penalty_msat: u64, half_life: Duration) { + self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) + failure_penalty_msat; + self.last_failed = T::now(); + } + + fn decayed_penalty_msat(&self, half_life: Duration) -> u64 { + let decays = self.last_failed.elapsed().as_secs().checked_div(half_life.as_secs()); + match decays { + Some(decays) => self.undecayed_penalty_msat >> decays, + None => 0, + } + } +} + +impl Default for ScorerUsingTime { + fn default() -> Self { + Self::new(ScoringParameters::default()) + } +} + +impl Default for ScoringParameters { + fn default() -> Self { + Self { + base_penalty_msat: 500, + failure_penalty_msat: 1024 * 1000, + failure_penalty_half_life: Duration::from_secs(3600), + overuse_penalty_start_1024th: 1024 / 8, + overuse_penalty_msat_per_1024th: 20, + } + } +} + +impl Score for ScorerUsingTime { + fn channel_penalty_msat( + &self, short_channel_id: u64, send_amt_msat: u64, chan_capacity_opt: Option, _source: &NodeId, _target: &NodeId + ) -> u64 { + let failure_penalty_msat = self.channel_failures + .get(&short_channel_id) + .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life)); + + let mut penalty_msat = self.params.base_penalty_msat + failure_penalty_msat; + + if let Some(chan_capacity_msat) = chan_capacity_opt { + let send_1024ths = send_amt_msat.checked_mul(1024).unwrap_or(u64::max_value()) / chan_capacity_msat; + + if send_1024ths > self.params.overuse_penalty_start_1024th as u64 { + penalty_msat = penalty_msat.checked_add( + (send_1024ths - self.params.overuse_penalty_start_1024th as u64) + .checked_mul(self.params.overuse_penalty_msat_per_1024th).unwrap_or(u64::max_value())) + .unwrap_or(u64::max_value()); + } + } + + penalty_msat + } + + fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) { + let failure_penalty_msat = self.params.failure_penalty_msat; + let half_life = self.params.failure_penalty_half_life; + self.channel_failures + .entry(short_channel_id) + .and_modify(|failure| failure.add_penalty(failure_penalty_msat, half_life)) + .or_insert_with(|| ChannelFailure::new(failure_penalty_msat)); + } +} + +#[cfg(not(feature = "no-std"))] +impl Time for std::time::Instant { + fn now() -> Self { + std::time::Instant::now() + } + + fn duration_since_epoch() -> Duration { + use std::time::SystemTime; + SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap() + } + + fn elapsed(&self) -> Duration { + std::time::Instant::elapsed(self) + } +} + +/// A state in which time has no meaning. +#[derive(Debug, PartialEq, Eq)] +pub struct Eternity; + +impl Time for Eternity { + fn now() -> Self { + Self + } + + fn duration_since_epoch() -> Duration { + Duration::from_secs(0) + } + + fn elapsed(&self) -> Duration { + Duration::from_secs(0) + } +} + +impl Sub for Eternity { + type Output = Self; + + fn sub(self, _other: Duration) -> Self { + self + } +} + +impl Writeable for ScorerUsingTime { + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + self.params.write(w)?; + self.channel_failures.write(w)?; + write_tlv_fields!(w, {}); + Ok(()) + } +} + +impl Readable for ScorerUsingTime { + #[inline] + fn read(r: &mut R) -> Result { + let res = Ok(Self { + params: Readable::read(r)?, + channel_failures: Readable::read(r)?, + }); + read_tlv_fields!(r, {}); + res + } +} + +impl Writeable for ChannelFailure { + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + let duration_since_epoch = T::duration_since_epoch() - self.last_failed.elapsed(); + write_tlv_fields!(w, { + (0, self.undecayed_penalty_msat, required), + (2, duration_since_epoch, required), + }); + Ok(()) + } +} + +impl Readable for ChannelFailure { + #[inline] + fn read(r: &mut R) -> Result { + let mut undecayed_penalty_msat = 0; + let mut duration_since_epoch = Duration::from_secs(0); + read_tlv_fields!(r, { + (0, undecayed_penalty_msat, required), + (2, duration_since_epoch, required), + }); + Ok(Self { + undecayed_penalty_msat, + last_failed: T::now() - (T::duration_since_epoch() - duration_since_epoch), + }) + } +} + +#[cfg(test)] +mod tests { + use super::{Eternity, ScoringParameters, ScorerUsingTime, Time}; + + use routing::scoring::Score; + use routing::network_graph::NodeId; + use util::ser::{Readable, Writeable}; + + use bitcoin::secp256k1::PublicKey; + use core::cell::Cell; + use core::ops::Sub; + use core::time::Duration; + use io; + + /// Time that can be advanced manually in tests. + #[derive(Debug, PartialEq, Eq)] + struct SinceEpoch(Duration); + + impl SinceEpoch { + thread_local! { + static ELAPSED: Cell = core::cell::Cell::new(Duration::from_secs(0)); + } + + fn advance(duration: Duration) { + Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration)) + } + } + + impl Time for SinceEpoch { + fn now() -> Self { + Self(Self::duration_since_epoch()) + } + + fn duration_since_epoch() -> Duration { + Self::ELAPSED.with(|elapsed| elapsed.get()) + } + + fn elapsed(&self) -> Duration { + Self::duration_since_epoch() - self.0 + } + } + + impl Sub for SinceEpoch { + type Output = Self; + + fn sub(self, other: Duration) -> Self { + Self(self.0 - other) + } + } + + #[test] + fn time_passes_when_advanced() { + let now = SinceEpoch::now(); + assert_eq!(now.elapsed(), Duration::from_secs(0)); + + SinceEpoch::advance(Duration::from_secs(1)); + SinceEpoch::advance(Duration::from_secs(1)); + + let elapsed = now.elapsed(); + let later = SinceEpoch::now(); + + assert_eq!(elapsed, Duration::from_secs(2)); + assert_eq!(later - elapsed, now); + } + + #[test] + fn time_never_passes_in_an_eternity() { + let now = Eternity::now(); + let elapsed = now.elapsed(); + let later = Eternity::now(); + + assert_eq!(now.elapsed(), Duration::from_secs(0)); + assert_eq!(later - elapsed, now); + } + + /// A scorer for testing with time that can be manually advanced. + type Scorer = ScorerUsingTime::; + + fn source_node_id() -> NodeId { + NodeId::from_pubkey(&PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap()) + } + + fn target_node_id() -> NodeId { + NodeId::from_pubkey(&PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap()) + } + + #[test] + fn penalizes_without_channel_failures() { + let scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 1_000, + failure_penalty_msat: 512, + failure_penalty_half_life: Duration::from_secs(1), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }); + let source = source_node_id(); + let target = target_node_id(); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + + SinceEpoch::advance(Duration::from_secs(1)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + } + + #[test] + fn accumulates_channel_failure_penalties() { + let mut scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 1_000, + failure_penalty_msat: 64, + failure_penalty_half_life: Duration::from_secs(10), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }); + let source = source_node_id(); + let target = target_node_id(); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_064); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_192); + } + + #[test] + fn decays_channel_failure_penalties_over_time() { + let mut scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 1_000, + failure_penalty_msat: 512, + failure_penalty_half_life: Duration::from_secs(10), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }); + let source = source_node_id(); + let target = target_node_id(); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); + + SinceEpoch::advance(Duration::from_secs(9)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); + + SinceEpoch::advance(Duration::from_secs(1)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); + + SinceEpoch::advance(Duration::from_secs(10 * 8)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_001); + + SinceEpoch::advance(Duration::from_secs(10)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + + SinceEpoch::advance(Duration::from_secs(10)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + } + + #[test] + fn accumulates_channel_failure_penalties_after_decay() { + let mut scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 1_000, + failure_penalty_msat: 512, + failure_penalty_half_life: Duration::from_secs(10), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }); + let source = source_node_id(); + let target = target_node_id(); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); + + SinceEpoch::advance(Duration::from_secs(10)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_768); + + SinceEpoch::advance(Duration::from_secs(10)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_384); + } + + #[test] + fn restores_persisted_channel_failure_penalties() { + let mut scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 1_000, + failure_penalty_msat: 512, + failure_penalty_half_life: Duration::from_secs(10), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }); + let source = source_node_id(); + let target = target_node_id(); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); + + SinceEpoch::advance(Duration::from_secs(10)); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); + + scorer.payment_path_failed(&[], 43); + assert_eq!(scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512); + + let mut serialized_scorer = Vec::new(); + scorer.write(&mut serialized_scorer).unwrap(); + + let deserialized_scorer = ::read(&mut io::Cursor::new(&serialized_scorer)).unwrap(); + assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); + assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512); + } + + #[test] + fn decays_persisted_channel_failure_penalties() { + let mut scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 1_000, + failure_penalty_msat: 512, + failure_penalty_half_life: Duration::from_secs(10), + overuse_penalty_start_1024th: 1024, + overuse_penalty_msat_per_1024th: 0, + }); + let source = source_node_id(); + let target = target_node_id(); + + scorer.payment_path_failed(&[], 42); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); + + let mut serialized_scorer = Vec::new(); + scorer.write(&mut serialized_scorer).unwrap(); + + SinceEpoch::advance(Duration::from_secs(10)); + + let deserialized_scorer = ::read(&mut io::Cursor::new(&serialized_scorer)).unwrap(); + assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); + + SinceEpoch::advance(Duration::from_secs(10)); + assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128); + } + + #[test] + fn charges_per_1024th_penalty() { + let scorer = Scorer::new(ScoringParameters { + base_penalty_msat: 0, + failure_penalty_msat: 0, + failure_penalty_half_life: Duration::from_secs(0), + overuse_penalty_start_1024th: 256, + overuse_penalty_msat_per_1024th: 100, + }); + let source = source_node_id(); + let target = target_node_id(); + + assert_eq!(scorer.channel_penalty_msat(42, 1_000, None, &source, &target), 0); + assert_eq!(scorer.channel_penalty_msat(42, 1_000, Some(1_024_000), &source, &target), 0); + assert_eq!(scorer.channel_penalty_msat(42, 256_999, Some(1_024_000), &source, &target), 0); + assert_eq!(scorer.channel_penalty_msat(42, 257_000, Some(1_024_000), &source, &target), 100); + assert_eq!(scorer.channel_penalty_msat(42, 258_000, Some(1_024_000), &source, &target), 200); + assert_eq!(scorer.channel_penalty_msat(42, 512_000, Some(1_024_000), &source, &target), 256 * 100); + } +} diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 4734a1cb4..45f21e0b3 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -21,7 +21,7 @@ use ln::features::{ChannelFeatures, InitFeatures}; use ln::msgs; use ln::msgs::OptionalField; use ln::script::ShutdownScript; -use routing::scorer::{Eternity, ScorerUsingTime}; +use routing::scoring::{Eternity, ScorerUsingTime}; use util::enforcing_trait_impls::{EnforcingSigner, EnforcementState}; use util::events; use util::logger::{Logger, Level, Record};