From 5f12682d72ddc1d8774b14bf44172ea39e52849d Mon Sep 17 00:00:00 2001 From: Jeffrey Czyz Date: Thu, 14 Oct 2021 13:04:39 -0500 Subject: [PATCH] Parameterize InvoicePayer by routing::Score --- lightning-background-processor/src/lib.rs | 1 + lightning-invoice/src/payment.rs | 98 +++++++++++++++++------ lightning-invoice/src/utils.rs | 8 +- 3 files changed, 77 insertions(+), 30 deletions(-) diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index e38a4a975..16c5eb4f9 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -311,6 +311,7 @@ mod tests { use lightning::ln::features::InitFeatures; use lightning::ln::msgs::{ChannelMessageHandler, Init}; use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler}; + use lightning::routing::scorer::Scorer; use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler}; use lightning::util::config::UserConfig; use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent}; diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 7e931d66f..f8fd8fcfa 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -30,6 +30,8 @@ //! # use lightning::ln::{PaymentHash, PaymentSecret}; //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; //! # use lightning::ln::msgs::LightningError; +//! # use lightning::routing; +//! # use lightning::routing::network_graph::NodeId; //! # use lightning::routing::router::{Route, RouteParameters}; //! # use lightning::util::events::{Event, EventHandler, EventsProvider}; //! # use lightning::util::logger::{Logger, Record}; @@ -57,12 +59,19 @@ //! # //! # struct FakeRouter {}; //! # impl Router for FakeRouter { -//! # fn find_route( +//! # fn find_route( //! # &self, payer: &PublicKey, params: &RouteParameters, -//! # first_hops: Option<&[&ChannelDetails]> +//! # first_hops: Option<&[&ChannelDetails]>, scorer: &S //! # ) -> Result { unimplemented!() } //! # } //! # +//! # struct FakeScorer {}; +//! # impl routing::Score for FakeScorer { +//! # fn channel_penalty_msat( +//! # &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId +//! # ) -> u64 { 0 } +//! # } +//! # //! # struct FakeLogger {}; //! # impl Logger for FakeLogger { //! # fn log(&self, record: &Record) { unimplemented!() } @@ -78,8 +87,9 @@ //! }; //! # let payer = FakePayer {}; //! # let router = FakeRouter {}; +//! # let scorer = FakeScorer {}; //! # let logger = FakeLogger {}; -//! let invoice_payer = InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); +//! let invoice_payer = InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); //! //! let invoice = "..."; //! let invoice = invoice.parse::().unwrap(); @@ -105,6 +115,7 @@ use bitcoin_hashes::Hash; use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; use lightning::ln::msgs::LightningError; +use lightning::routing; use lightning::routing::router::{Payee, Route, RouteParameters}; use lightning::util::events::{Event, EventHandler}; use lightning::util::logger::Logger; @@ -117,15 +128,17 @@ use std::sync::Mutex; use std::time::{Duration, SystemTime}; /// A utility for paying [`Invoice]`s. -pub struct InvoicePayer +pub struct InvoicePayer where P::Target: Payer, R: Router, + S: routing::Score, L::Target: Logger, E: EventHandler, { payer: P, router: R, + scorer: S, logger: L, event_handler: E, payment_cache: Mutex>, @@ -152,8 +165,9 @@ pub trait Payer { /// A trait defining behavior for routing an [`Invoice`] payment. pub trait Router { /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. - fn find_route( - &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]> + fn find_route( + &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, + scorer: &S ) -> Result; } @@ -172,10 +186,11 @@ pub enum PaymentError { Sending(PaymentSendFailure), } -impl InvoicePayer +impl InvoicePayer where P::Target: Payer, R: Router, + S: routing::Score, L::Target: Logger, E: EventHandler, { @@ -184,11 +199,12 @@ where /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once /// `retry_attempts` has been exceeded for a given [`Invoice`]. pub fn new( - payer: P, router: R, logger: L, event_handler: E, retry_attempts: RetryAttempts + payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts ) -> Self { Self { payer, router, + scorer, logger, event_handler, payment_cache: Mutex::new(HashMap::new()), @@ -242,6 +258,7 @@ where &payer, ¶ms, Some(&first_hops.iter().collect::>()), + &self.scorer, ).map_err(|e| PaymentError::Routing(e))?; let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner()); @@ -261,7 +278,7 @@ where let payer = self.payer.node_id(); let first_hops = self.payer.first_hops(); let route = self.router.find_route( - &payer, ¶ms, Some(&first_hops.iter().collect::>()) + &payer, ¶ms, Some(&first_hops.iter().collect::>()), &self.scorer ).map_err(|e| PaymentError::Routing(e))?; self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e)) } @@ -284,10 +301,11 @@ fn has_expired(params: &RouteParameters) -> bool { Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time) } -impl EventHandler for InvoicePayer +impl EventHandler for InvoicePayer where P::Target: Payer, R: Router, + S: routing::Score, L::Target: Logger, E: EventHandler, { @@ -354,7 +372,8 @@ mod tests { use lightning::ln::PaymentPreimage; use lightning::ln::features::{ChannelFeatures, NodeFeatures}; use lightning::ln::msgs::{ErrorAction, LightningError}; - use lightning::routing::router::{Route, RouteHop}; + use lightning::routing::network_graph::NodeId; + use lightning::routing::router::{Payee, Route, RouteHop}; use lightning::util::test_utils::TestLogger; use lightning::util::errors::APIError; use lightning::util::events::Event; @@ -422,9 +441,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -450,9 +470,10 @@ mod tests { .expect_value_msat(final_value_msat) .expect_value_msat(final_value_msat / 2); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -490,9 +511,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(PaymentId([1; 32])); let event = Event::PaymentPathFailed { @@ -534,9 +556,10 @@ mod tests { .expect_value_msat(final_value_msat / 2) .expect_value_msat(final_value_msat / 2); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -583,9 +606,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -614,9 +638,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = expired_invoice(payment_preimage); @@ -651,9 +676,10 @@ mod tests { .fails_on_attempt(2) .expect_value_msat(final_value_msat); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -680,9 +706,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -711,9 +738,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -751,9 +779,10 @@ mod tests { fn fails_paying_invoice_with_routing_errors() { let payer = TestPayer::new(); let router = FailingRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0)); + InvoicePayer::new(&payer, router, scorer, &logger, |_: &_| {}, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -768,9 +797,10 @@ mod tests { fn fails_paying_invoice_with_sending_errors() { let payer = TestPayer::new().fails_on_attempt(1); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0)); + InvoicePayer::new(&payer, router, scorer, &logger, |_: &_| {}, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -793,9 +823,10 @@ mod tests { let payer = TestPayer::new().expect_value_msat(final_value_msat); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0)); let payment_id = Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap()); @@ -815,9 +846,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = TestScorer::new(); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -874,11 +906,12 @@ mod tests { } impl Router for TestRouter { - fn find_route( + fn find_route( &self, _payer: &PublicKey, params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, + _scorer: &S, ) -> Result { Ok(Route { payee: Some(params.payee.clone()), ..Self::route_for_value(params.final_value_msat) @@ -889,16 +922,29 @@ mod tests { struct FailingRouter; impl Router for FailingRouter { - fn find_route( + fn find_route( &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, + _scorer: &S, ) -> Result { Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError }) } } + struct TestScorer; + + impl TestScorer { + fn new() -> Self { Self {} } + } + + impl routing::Score for TestScorer { + fn channel_penalty_msat( + &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId + ) -> u64 { 0 } + } + struct TestPayer { expectations: core::cell::RefCell>, attempts: core::cell::RefCell, diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index ef885f203..1c1a021f7 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -11,9 +11,9 @@ use lightning::chain::keysinterface::{Sign, KeysInterface}; use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY}; use lightning::ln::msgs::LightningError; +use lightning::routing; use lightning::routing::network_graph::{NetworkGraph, RoutingFees}; use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route}; -use lightning::routing::scorer::Scorer; use lightning::util::logger::Logger; use secp256k1::key::PublicKey; use std::convert::TryInto; @@ -111,11 +111,11 @@ impl DefaultRouter where G: Deref, L:: impl Router for DefaultRouter where G: Deref, L::Target: Logger { - fn find_route( + fn find_route( &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, + scorer: &S ) -> Result { - let scorer = Scorer::default(); - find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, &scorer) + find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer) } } -- 2.39.5