]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Parameterize InvoicePayer by routing::Score
authorJeffrey Czyz <jkczyz@gmail.com>
Thu, 14 Oct 2021 18:04:39 +0000 (13:04 -0500)
committerJeffrey Czyz <jkczyz@gmail.com>
Wed, 27 Oct 2021 17:06:49 +0000 (12:06 -0500)
lightning-background-processor/src/lib.rs
lightning-invoice/src/payment.rs
lightning-invoice/src/utils.rs

index e38a4a975b2a8ff4a9c750081c9d43166ffd21af..16c5eb4f9393193ba9136fd645661625946359e6 100644 (file)
@@ -311,6 +311,7 @@ mod tests {
        use lightning::ln::features::InitFeatures;
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
+       use lightning::routing::scorer::Scorer;
        use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use lightning::util::config::UserConfig;
        use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
index 7e931d66f153a4f530d7cbb9af1623dfb1ee4e2c..f8fd8fcfabe8af03668726a1da04750af7d69875 100644 (file)
@@ -30,6 +30,8 @@
 //! # use lightning::ln::{PaymentHash, PaymentSecret};
 //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 //! # use lightning::ln::msgs::LightningError;
+//! # use lightning::routing;
+//! # use lightning::routing::network_graph::NodeId;
 //! # use lightning::routing::router::{Route, RouteParameters};
 //! # use lightning::util::events::{Event, EventHandler, EventsProvider};
 //! # use lightning::util::logger::{Logger, Record};
 //! #
 //! # struct FakeRouter {};
 //! # impl Router for FakeRouter {
-//! #     fn find_route(
+//! #     fn find_route<S: routing::Score>(
 //! #         &self, payer: &PublicKey, params: &RouteParameters,
-//! #         first_hops: Option<&[&ChannelDetails]>
+//! #         first_hops: Option<&[&ChannelDetails]>, scorer: &S
 //! #     ) -> Result<Route, LightningError> { unimplemented!() }
 //! # }
 //! #
+//! # struct FakeScorer {};
+//! # impl routing::Score for FakeScorer {
+//! #     fn channel_penalty_msat(
+//! #         &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+//! #     ) -> u64 { 0 }
+//! # }
+//! #
 //! # struct FakeLogger {};
 //! # impl Logger for FakeLogger {
 //! #     fn log(&self, record: &Record) { unimplemented!() }
@@ -78,8 +87,9 @@
 //! };
 //! # let payer = FakePayer {};
 //! # let router = FakeRouter {};
+//! # let scorer = FakeScorer {};
 //! # let logger = FakeLogger {};
-//! let invoice_payer = InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+//! let invoice_payer = InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 //!
 //! let invoice = "...";
 //! let invoice = invoice.parse::<Invoice>().unwrap();
@@ -105,6 +115,7 @@ use bitcoin_hashes::Hash;
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 use lightning::ln::msgs::LightningError;
+use lightning::routing;
 use lightning::routing::router::{Payee, Route, RouteParameters};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
@@ -117,15 +128,17 @@ use std::sync::Mutex;
 use std::time::{Duration, SystemTime};
 
 /// A utility for paying [`Invoice]`s.
-pub struct InvoicePayer<P: Deref, R, L: Deref, E>
+pub struct InvoicePayer<P: Deref, R, S, L: Deref, E>
 where
        P::Target: Payer,
        R: Router,
+       S: routing::Score,
        L::Target: Logger,
        E: EventHandler,
 {
        payer: P,
        router: R,
+       scorer: S,
        logger: L,
        event_handler: E,
        payment_cache: Mutex<HashMap<PaymentHash, usize>>,
@@ -152,8 +165,9 @@ pub trait Payer {
 /// A trait defining behavior for routing an [`Invoice`] payment.
 pub trait Router {
        /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values.
-       fn find_route(
-               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>
+       fn find_route<S: routing::Score>(
+               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
        ) -> Result<Route, LightningError>;
 }
 
@@ -172,10 +186,11 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R, L: Deref, E> InvoicePayer<P, R, L, E>
+impl<P: Deref, R, S, L: Deref, E> InvoicePayer<P, R, S, L, E>
 where
        P::Target: Payer,
        R: Router,
+       S: routing::Score,
        L::Target: Logger,
        E: EventHandler,
 {
@@ -184,11 +199,12 @@ where
        /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once
        /// `retry_attempts` has been exceeded for a given [`Invoice`].
        pub fn new(
-               payer: P, router: R, logger: L, event_handler: E, retry_attempts: RetryAttempts
+               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts
        ) -> Self {
                Self {
                        payer,
                        router,
+                       scorer,
                        logger,
                        event_handler,
                        payment_cache: Mutex::new(HashMap::new()),
@@ -242,6 +258,7 @@ where
                                        &payer,
                                        &params,
                                        Some(&first_hops.iter().collect::<Vec<_>>()),
+                                       &self.scorer,
                                ).map_err(|e| PaymentError::Routing(e))?;
 
                                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
@@ -261,7 +278,7 @@ where
                let payer = self.payer.node_id();
                let first_hops = self.payer.first_hops();
                let route = self.router.find_route(
-                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>())
+                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()), &self.scorer
                ).map_err(|e| PaymentError::Routing(e))?;
                self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e))
        }
@@ -284,10 +301,11 @@ fn has_expired(params: &RouteParameters) -> bool {
        Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time)
 }
 
-impl<P: Deref, R, L: Deref, E> EventHandler for InvoicePayer<P, R, L, E>
+impl<P: Deref, R, S, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
 where
        P::Target: Payer,
        R: Router,
+       S: routing::Score,
        L::Target: Logger,
        E: EventHandler,
 {
@@ -354,7 +372,8 @@ mod tests {
        use lightning::ln::PaymentPreimage;
        use lightning::ln::features::{ChannelFeatures, NodeFeatures};
        use lightning::ln::msgs::{ErrorAction, LightningError};
-       use lightning::routing::router::{Route, RouteHop};
+       use lightning::routing::network_graph::NodeId;
+       use lightning::routing::router::{Payee, Route, RouteHop};
        use lightning::util::test_utils::TestLogger;
        use lightning::util::errors::APIError;
        use lightning::util::events::Event;
@@ -422,9 +441,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -450,9 +470,10 @@ mod tests {
                        .expect_value_msat(final_value_msat)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -490,9 +511,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(PaymentId([1; 32]));
                let event = Event::PaymentPathFailed {
@@ -534,9 +556,10 @@ mod tests {
                        .expect_value_msat(final_value_msat / 2)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -583,9 +606,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -614,9 +638,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = expired_invoice(payment_preimage);
@@ -651,9 +676,10 @@ mod tests {
                        .fails_on_attempt(2)
                        .expect_value_msat(final_value_msat);
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -680,9 +706,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -711,9 +738,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -751,9 +779,10 @@ mod tests {
        fn fails_paying_invoice_with_routing_errors() {
                let payer = TestPayer::new();
                let router = FailingRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -768,9 +797,10 @@ mod tests {
        fn fails_paying_invoice_with_sending_errors() {
                let payer = TestPayer::new().fails_on_attempt(1);
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -793,9 +823,10 @@ mod tests {
 
                let payer = TestPayer::new().expect_value_msat(final_value_msat);
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_id =
                        Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
@@ -815,9 +846,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = TestScorer::new();
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -874,11 +906,12 @@ mod tests {
        }
 
        impl Router for TestRouter {
-               fn find_route(
+               fn find_route<S: routing::Score>(
                        &self,
                        _payer: &PublicKey,
                        params: &RouteParameters,
                        _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
                ) -> Result<Route, LightningError> {
                        Ok(Route {
                                payee: Some(params.payee.clone()), ..Self::route_for_value(params.final_value_msat)
@@ -889,16 +922,29 @@ mod tests {
        struct FailingRouter;
 
        impl Router for FailingRouter {
-               fn find_route(
+               fn find_route<S: routing::Score>(
                        &self,
                        _payer: &PublicKey,
                        _params: &RouteParameters,
                        _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
                ) -> Result<Route, LightningError> {
                        Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError })
                }
        }
 
+       struct TestScorer;
+
+       impl TestScorer {
+               fn new() -> Self { Self {} }
+       }
+
+       impl routing::Score for TestScorer {
+               fn channel_penalty_msat(
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               ) -> u64 { 0 }
+       }
+
        struct TestPayer {
                expectations: core::cell::RefCell<std::collections::VecDeque<u64>>,
                attempts: core::cell::RefCell<usize>,
index ef885f20381fb626fa3ca34e1a61d3dacdbf1198..1c1a021f7400ab6cf99c2ae08ec7f18d2d061eba 100644 (file)
@@ -11,9 +11,9 @@ use lightning::chain::keysinterface::{Sign, KeysInterface};
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY};
 use lightning::ln::msgs::LightningError;
+use lightning::routing;
 use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
 use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route};
-use lightning::routing::scorer::Scorer;
 use lightning::util::logger::Logger;
 use secp256k1::key::PublicKey;
 use std::convert::TryInto;
@@ -111,11 +111,11 @@ impl<G, L: Deref> DefaultRouter<G, L> where G: Deref<Target = NetworkGraph>, L::
 
 impl<G, L: Deref> Router for DefaultRouter<G, L>
 where G: Deref<Target = NetworkGraph>, L::Target: Logger {
-       fn find_route(
+       fn find_route<S: routing::Score>(
                &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
        ) -> Result<Route, LightningError> {
-               let scorer = Scorer::default();
-               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, &scorer)
+               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer)
        }
 }