]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Notify scorer of failing payment path and channel
authorJeffrey Czyz <jkczyz@gmail.com>
Thu, 14 Oct 2021 18:04:39 +0000 (13:04 -0500)
committerJeffrey Czyz <jkczyz@gmail.com>
Fri, 29 Oct 2021 19:24:53 +0000 (14:24 -0500)
Upon receiving a PaymentPathFailed event, the failing payment may be
retried on a different path. To avoid using the channel responsible for
the failure, a scorer should be notified of the failure before being
used to find a new route.

Add a payment_path_failed method to routing::Score and call it in
InvoicePayer's event handler. Introduce a LockableScore parameterization
to InvoicePayer so the scorer is locked only once before calling
find_route.

lightning-background-processor/src/lib.rs
lightning-invoice/src/payment.rs
lightning-invoice/src/utils.rs
lightning/src/routing/mod.rs
lightning/src/routing/router.rs
lightning/src/routing/scorer.rs

index e38a4a975b2a8ff4a9c750081c9d43166ffd21af..16c5eb4f9393193ba9136fd645661625946359e6 100644 (file)
@@ -311,6 +311,7 @@ mod tests {
        use lightning::ln::features::InitFeatures;
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
+       use lightning::routing::scorer::Scorer;
        use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use lightning::util::config::UserConfig;
        use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
index 7e931d66f153a4f530d7cbb9af1623dfb1ee4e2c..ba260742bb41af8e6d1a500f58aee29a014a492b 100644 (file)
 //! # use lightning::ln::{PaymentHash, PaymentSecret};
 //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 //! # use lightning::ln::msgs::LightningError;
-//! # use lightning::routing::router::{Route, RouteParameters};
+//! # use lightning::routing;
+//! # use lightning::routing::network_graph::NodeId;
+//! # use lightning::routing::router::{Route, RouteHop, RouteParameters};
 //! # use lightning::util::events::{Event, EventHandler, EventsProvider};
 //! # use lightning::util::logger::{Logger, Record};
 //! # use lightning_invoice::Invoice;
 //! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router};
 //! # use secp256k1::key::PublicKey;
+//! # use std::cell::RefCell;
 //! # use std::ops::Deref;
 //! #
 //! # struct FakeEventProvider {}
 //! # }
 //! #
 //! # struct FakeRouter {};
-//! # impl Router for FakeRouter {
+//! # impl<S: routing::Score> Router<S> for FakeRouter {
 //! #     fn find_route(
 //! #         &self, payer: &PublicKey, params: &RouteParameters,
-//! #         first_hops: Option<&[&ChannelDetails]>
+//! #         first_hops: Option<&[&ChannelDetails]>, scorer: &S
 //! #     ) -> Result<Route, LightningError> { unimplemented!() }
 //! # }
 //! #
+//! # struct FakeScorer {};
+//! # impl routing::Score for FakeScorer {
+//! #     fn channel_penalty_msat(
+//! #         &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+//! #     ) -> u64 { 0 }
+//! #     fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
+//! # }
+//! #
 //! # struct FakeLogger {};
 //! # impl Logger for FakeLogger {
 //! #     fn log(&self, record: &Record) { unimplemented!() }
@@ -78,8 +89,9 @@
 //! };
 //! # let payer = FakePayer {};
 //! # let router = FakeRouter {};
+//! # let scorer = RefCell::new(FakeScorer {});
 //! # let logger = FakeLogger {};
-//! let invoice_payer = InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 //!
 //! let invoice = "...";
 //! let invoice = invoice.parse::<Invoice>().unwrap();
@@ -105,6 +117,8 @@ use bitcoin_hashes::Hash;
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 use lightning::ln::msgs::LightningError;
+use lightning::routing;
+use lightning::routing::{LockableScore, Score};
 use lightning::routing::router::{Payee, Route, RouteParameters};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
@@ -117,15 +131,17 @@ use std::sync::Mutex;
 use std::time::{Duration, SystemTime};
 
 /// A utility for paying [`Invoice]`s.
-pub struct InvoicePayer<P: Deref, R, L: Deref, E>
+pub struct InvoicePayer<P: Deref, R, S: Deref, L: Deref, E>
 where
        P::Target: Payer,
-       R: Router,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
        E: EventHandler,
 {
        payer: P,
        router: R,
+       scorer: S,
        logger: L,
        event_handler: E,
        payment_cache: Mutex<HashMap<PaymentHash, usize>>,
@@ -150,10 +166,11 @@ pub trait Payer {
 }
 
 /// A trait defining behavior for routing an [`Invoice`] payment.
-pub trait Router {
+pub trait Router<S: routing::Score> {
        /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values.
        fn find_route(
-               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>
+               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
        ) -> Result<Route, LightningError>;
 }
 
@@ -172,10 +189,11 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R, L: Deref, E> InvoicePayer<P, R, L, E>
+impl<P: Deref, R, S: Deref, L: Deref, E> InvoicePayer<P, R, S, L, E>
 where
        P::Target: Payer,
-       R: Router,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
        E: EventHandler,
 {
@@ -184,11 +202,12 @@ where
        /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once
        /// `retry_attempts` has been exceeded for a given [`Invoice`].
        pub fn new(
-               payer: P, router: R, logger: L, event_handler: E, retry_attempts: RetryAttempts
+               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts
        ) -> Self {
                Self {
                        payer,
                        router,
+                       scorer,
                        logger,
                        event_handler,
                        payment_cache: Mutex::new(HashMap::new()),
@@ -242,6 +261,7 @@ where
                                        &payer,
                                        &params,
                                        Some(&first_hops.iter().collect::<Vec<_>>()),
+                                       &self.scorer.lock(),
                                ).map_err(|e| PaymentError::Routing(e))?;
 
                                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
@@ -261,7 +281,8 @@ where
                let payer = self.payer.node_id();
                let first_hops = self.payer.first_hops();
                let route = self.router.find_route(
-                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>())
+                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()),
+                       &self.scorer.lock()
                ).map_err(|e| PaymentError::Routing(e))?;
                self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e))
        }
@@ -284,16 +305,23 @@ fn has_expired(params: &RouteParameters) -> bool {
        Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time)
 }
 
-impl<P: Deref, R, L: Deref, E> EventHandler for InvoicePayer<P, R, L, E>
+impl<P: Deref, R, S: Deref, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
 where
        P::Target: Payer,
-       R: Router,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
        E: EventHandler,
 {
        fn handle_event(&self, event: &Event) {
                match event {
-                       Event::PaymentPathFailed { payment_id, payment_hash, rejected_by_dest, retry, .. } => {
+                       Event::PaymentPathFailed {
+                               payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, ..
+                       } => {
+                               if let Some(short_channel_id) = short_channel_id {
+                                       self.scorer.lock().payment_path_failed(path, *short_channel_id);
+                               }
+
                                let mut payment_cache = self.payment_cache.lock().unwrap();
                                let entry = loop {
                                        let entry = payment_cache.entry(*payment_hash);
@@ -354,11 +382,13 @@ mod tests {
        use lightning::ln::PaymentPreimage;
        use lightning::ln::features::{ChannelFeatures, NodeFeatures};
        use lightning::ln::msgs::{ErrorAction, LightningError};
-       use lightning::routing::router::{Route, RouteHop};
+       use lightning::routing::network_graph::NodeId;
+       use lightning::routing::router::{Payee, Route, RouteHop};
        use lightning::util::test_utils::TestLogger;
        use lightning::util::errors::APIError;
        use lightning::util::events::Event;
        use secp256k1::{SecretKey, PublicKey, Secp256k1};
+       use std::cell::RefCell;
        use std::time::{SystemTime, Duration};
 
        fn invoice(payment_preimage: PaymentPreimage) -> Invoice {
@@ -422,9 +452,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -450,9 +481,10 @@ mod tests {
                        .expect_value_msat(final_value_msat)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -490,9 +522,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(PaymentId([1; 32]));
                let event = Event::PaymentPathFailed {
@@ -534,9 +567,10 @@ mod tests {
                        .expect_value_msat(final_value_msat / 2)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -583,9 +617,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -614,9 +649,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = expired_invoice(payment_preimage);
@@ -651,9 +687,10 @@ mod tests {
                        .fails_on_attempt(2)
                        .expect_value_msat(final_value_msat);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -680,9 +717,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -711,9 +749,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -751,9 +790,10 @@ mod tests {
        fn fails_paying_invoice_with_routing_errors() {
                let payer = TestPayer::new();
                let router = FailingRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -768,9 +808,10 @@ mod tests {
        fn fails_paying_invoice_with_sending_errors() {
                let payer = TestPayer::new().fails_on_attempt(1);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -793,9 +834,10 @@ mod tests {
 
                let payer = TestPayer::new().expect_value_msat(final_value_msat);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_id =
                        Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
@@ -815,9 +857,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -830,6 +873,40 @@ mod tests {
                }
        }
 
+       #[test]
+       fn scores_failed_channel() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+               let path = TestRouter::path_for_value(final_value_msat);
+               let short_channel_id = Some(path[0].short_channel_id);
+
+               // Expect that scorer is given short_channel_id upon handling the event.
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash,
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path,
+                       short_channel_id,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+       }
+
        struct TestRouter;
 
        impl TestRouter {
@@ -873,12 +950,13 @@ mod tests {
                }
        }
 
-       impl Router for TestRouter {
+       impl<S: routing::Score> Router<S> for TestRouter {
                fn find_route(
                        &self,
                        _payer: &PublicKey,
                        params: &RouteParameters,
                        _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
                ) -> Result<Route, LightningError> {
                        Ok(Route {
                                payee: Some(params.payee.clone()), ..Self::route_for_value(params.final_value_msat)
@@ -888,17 +966,59 @@ mod tests {
 
        struct FailingRouter;
 
-       impl Router for FailingRouter {
+       impl<S: routing::Score> Router<S> for FailingRouter {
                fn find_route(
                        &self,
                        _payer: &PublicKey,
                        _params: &RouteParameters,
                        _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
                ) -> Result<Route, LightningError> {
                        Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError })
                }
        }
 
+       struct TestScorer {
+               expectations: std::collections::VecDeque<u64>,
+       }
+
+       impl TestScorer {
+               fn new() -> Self {
+                       Self {
+                               expectations: std::collections::VecDeque::new(),
+                       }
+               }
+
+               fn expect_channel_failure(mut self, short_channel_id: u64) -> Self {
+                       self.expectations.push_back(short_channel_id);
+                       self
+               }
+       }
+
+       impl routing::Score for TestScorer {
+               fn channel_penalty_msat(
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               ) -> u64 { 0 }
+
+               fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, short_channel_id: u64) {
+                       if let Some(expected_short_channel_id) = self.expectations.pop_front() {
+                               assert_eq!(short_channel_id, expected_short_channel_id);
+                       }
+               }
+       }
+
+       impl Drop for TestScorer {
+               fn drop(&mut self) {
+                       if std::thread::panicking() {
+                               return;
+                       }
+
+                       if !self.expectations.is_empty() {
+                               panic!("Unsatisfied channel failure expectations: {:?}", self.expectations);
+                       }
+               }
+       }
+
        struct TestPayer {
                expectations: core::cell::RefCell<std::collections::VecDeque<u64>>,
                attempts: core::cell::RefCell<usize>,
index ef885f20381fb626fa3ca34e1a61d3dacdbf1198..e47141161b533fb3008ddd10c2a18e8bfb139f23 100644 (file)
@@ -11,9 +11,9 @@ use lightning::chain::keysinterface::{Sign, KeysInterface};
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY};
 use lightning::ln::msgs::LightningError;
+use lightning::routing;
 use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
 use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route};
-use lightning::routing::scorer::Scorer;
 use lightning::util::logger::Logger;
 use secp256k1::key::PublicKey;
 use std::convert::TryInto;
@@ -109,13 +109,13 @@ impl<G, L: Deref> DefaultRouter<G, L> where G: Deref<Target = NetworkGraph>, L::
        }
 }
 
-impl<G, L: Deref> Router for DefaultRouter<G, L>
+impl<G, L: Deref, S: routing::Score> Router<S> for DefaultRouter<G, L>
 where G: Deref<Target = NetworkGraph>, L::Target: Logger {
        fn find_route(
                &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
        ) -> Result<Route, LightningError> {
-               let scorer = Scorer::default();
-               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, &scorer)
+               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer)
        }
 }
 
index 51ffd91b50483d58bc2016440bdb98d9f82886e5..d6c016468ce79f696483ccafcee03e0269b6267f 100644 (file)
@@ -14,6 +14,12 @@ pub mod router;
 pub mod scorer;
 
 use routing::network_graph::NodeId;
+use routing::router::RouteHop;
+
+use prelude::*;
+use core::cell::{RefCell, RefMut};
+use core::ops::DerefMut;
+use sync::{Mutex, MutexGuard};
 
 /// An interface used to score payment channels for path finding.
 ///
@@ -22,4 +28,49 @@ pub trait Score {
        /// Returns the fee in msats willing to be paid to avoid routing through the given channel
        /// in the direction from `source` to `target`.
        fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
+
+       /// Handles updating channel penalties after failing to route through a channel.
+       fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64);
+}
+
+/// A scorer that is accessed under a lock.
+///
+/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
+/// having shared ownership of a scorer but without requiring internal locking in [`Score`]
+/// implementations. Internal locking would be detrimental to route finding performance and could
+/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
+///
+/// [`find_route`]: crate::routing::router::find_route
+pub trait LockableScore<'a> {
+       /// The locked [`Score`] type.
+       type Locked: 'a + Score;
+
+       /// Returns the locked scorer.
+       fn lock(&'a self) -> Self::Locked;
+}
+
+impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
+       type Locked = MutexGuard<'a, T>;
+
+       fn lock(&'a self) -> MutexGuard<'a, T> {
+               Mutex::lock(self).unwrap()
+       }
+}
+
+impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
+       type Locked = RefMut<'a, T>;
+
+       fn lock(&'a self) -> RefMut<'a, T> {
+               self.borrow_mut()
+       }
+}
+
+impl<S: Score, T: DerefMut<Target=S>> Score for T {
+       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
+               self.deref().channel_penalty_msat(short_channel_id, source, target)
+       }
+
+       fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64) {
+               self.deref_mut().payment_path_failed(path, short_channel_id)
+       }
 }
index 545ff9f24ce386520ec630dab31f13b75b9c08ae..226dbe04ff9e20e1b89aefc0957032bdd18978d5 100644 (file)
@@ -4537,6 +4537,8 @@ mod tests {
                fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
+
+               fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
        }
 
        struct BadNodeScorer {
@@ -4547,6 +4549,8 @@ mod tests {
                fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
+
+               fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
        }
 
        #[test]
index e3f5c8679b68d6b57ca418b70befdc0ff34e35d7..01481f16c297f9218aa05082b3fbdc7f22e886ba 100644 (file)
@@ -45,6 +45,9 @@
 use routing;
 
 use routing::network_graph::NodeId;
+use routing::router::RouteHop;
+
+use prelude::*;
 
 /// [`routing::Score`] implementation that provides reasonable default behavior.
 ///
@@ -78,4 +81,6 @@ impl routing::Score for Scorer {
        ) -> u64 {
                self.base_penalty_msat
        }
+
+       fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
 }