From 5f69c313d1e1f4c90067d8e5770c6bb9689dcc47 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sun, 31 Oct 2021 19:39:02 +0000 Subject: [PATCH] (Bindings Only) Concretize LockableScore to only map Scorer We don't really care about more than this in bindings - calling into a custom `Score` is likely too slow to be practical anyway, so this is also a performance improvement. Works around https://github.com/rust-lang/rust/issues/90448 --- lightning-background-processor/src/lib.rs | 3 +- lightning-invoice/src/payment.rs | 59 ++++++++++------------- lightning-invoice/src/utils.rs | 8 +-- lightning/src/routing/mod.rs | 40 ++++----------- 4 files changed, 41 insertions(+), 69 deletions(-) diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 50743774b..5e43d96b9 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -312,6 +312,7 @@ mod tests { use lightning::ln::features::InitFeatures; use lightning::ln::msgs::{ChannelMessageHandler, Init}; use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler}; + use lightning::routing::LockableScore; use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler}; use lightning::util::config::UserConfig; use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent}; @@ -634,7 +635,7 @@ mod tests { let data_dir = nodes[0].persister.get_data_dir(); let persister = move |node: &ChannelManager, Arc, Arc, Arc, Arc>| FilesystemPersister::persist_manager(data_dir.clone(), node); let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger)); - let scorer = Arc::new(Mutex::new(test_utils::TestScorer::default())); + let scorer = Arc::new(LockableScore::new(test_utils::TestScorer::default())); let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2))); let event_handler = Arc::clone(&invoice_payer); let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 075559bfd..e7b162b38 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -30,7 +30,7 @@ //! # use lightning::ln::{PaymentHash, PaymentSecret}; //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; //! # use lightning::ln::msgs::LightningError; -//! # use lightning::routing; +//! # use lightning::routing::{self, LockableScore}; //! # use lightning::routing::network_graph::NodeId; //! # use lightning::routing::router::{Route, RouteHop, RouteParameters}; //! # use lightning::util::events::{Event, EventHandler, EventsProvider}; @@ -89,7 +89,7 @@ //! }; //! # let payer = FakePayer {}; //! # let router = FakeRouter {}; -//! # let scorer = RefCell::new(FakeScorer {}); +//! # let scorer = LockableScore::new(FakeScorer {}); //! # let logger = FakeLogger {}; //! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); //! @@ -118,7 +118,7 @@ use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; use lightning::ln::msgs::LightningError; use lightning::routing; -use lightning::routing::{LockableScore, Score}; +use lightning::routing::LockableScore; use lightning::routing::router::{Payee, Route, RouteParameters}; use lightning::util::events::{Event, EventHandler}; use lightning::util::logger::Logger; @@ -131,17 +131,14 @@ use std::sync::Mutex; use std::time::{Duration, SystemTime}; /// A utility for paying [`Invoice]`s. -pub struct InvoicePayer +pub struct InvoicePayer, S: routing::Score, Sc: Deref>, L: Deref, E: EventHandler> where P::Target: Payer, - R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, - S::Target: for <'a> routing::LockableScore<'a>, L::Target: Logger, - E: EventHandler, { payer: P, router: R, - scorer: S, + scorer: Sc, logger: L, event_handler: E, payment_cache: Mutex>, @@ -193,20 +190,17 @@ pub enum PaymentError { Sending(PaymentSendFailure), } -impl InvoicePayer +impl, S: routing::Score, Sc: Deref>, L: Deref, E: EventHandler> InvoicePayer where P::Target: Payer, - R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, - S::Target: for <'a> routing::LockableScore<'a>, L::Target: Logger, - E: EventHandler, { /// Creates an invoice payer that retries failed payment paths. /// /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once /// `retry_attempts` has been exceeded for a given [`Invoice`]. pub fn new( - payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts + payer: P, router: R, scorer: Sc, logger: L, event_handler: E, retry_attempts: RetryAttempts ) -> Self { Self { payer, @@ -401,13 +395,10 @@ fn has_expired(params: &RouteParameters) -> bool { } else { false } } -impl EventHandler for InvoicePayer +impl, S: routing::Score, Sc: Deref>, L: Deref, E: EventHandler> EventHandler for InvoicePayer where P::Target: Payer, - R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, - S::Target: for <'a> routing::LockableScore<'a>, L::Target: Logger, - E: EventHandler, { fn handle_event(&self, event: &Event) { match event { @@ -529,7 +520,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); @@ -558,7 +549,7 @@ mod tests { .expect_value_msat(final_value_msat) .expect_value_msat(final_value_msat / 2); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -599,7 +590,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -644,7 +635,7 @@ mod tests { .expect_value_msat(final_value_msat / 2) .expect_value_msat(final_value_msat / 2); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -694,7 +685,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -726,7 +717,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -745,7 +736,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -787,7 +778,7 @@ mod tests { .fails_on_attempt(2) .expect_value_msat(final_value_msat); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -817,7 +808,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -849,7 +840,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); @@ -890,7 +881,7 @@ mod tests { fn fails_paying_invoice_with_routing_errors() { let payer = TestPayer::new(); let router = FailingRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0)); @@ -908,7 +899,7 @@ mod tests { fn fails_paying_invoice_with_sending_errors() { let payer = TestPayer::new().fails_on_attempt(1); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0)); @@ -934,7 +925,7 @@ mod tests { let payer = TestPayer::new().expect_value_msat(final_value_msat); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); @@ -957,7 +948,7 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); @@ -988,7 +979,7 @@ mod tests { // Expect that scorer is given short_channel_id upon handling the event. let payer = TestPayer::new(); let router = TestRouter {}; - let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap())); + let scorer = LockableScore::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap())); let logger = TestLogger::new(); let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); @@ -1277,7 +1268,7 @@ mod tests { router.expect_find_route(Ok(route.clone())); let event_handler = |_: &_| { panic!(); }; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1)); assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager( @@ -1322,7 +1313,7 @@ mod tests { router.expect_find_route(Ok(route.clone())); let event_handler = |_: &_| { panic!(); }; - let scorer = RefCell::new(TestScorer::new()); + let scorer = LockableScore::new(TestScorer::new()); let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1)); assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager( diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index 35a74b6a5..ba66aa058 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -97,20 +97,20 @@ where } /// A [`Router`] implemented using [`find_route`]. -pub struct DefaultRouter where G: Deref, L::Target: Logger { +pub struct DefaultRouter, L: Deref> where L::Target: Logger { network_graph: G, logger: L, } -impl DefaultRouter where G: Deref, L::Target: Logger { +impl, L: Deref> DefaultRouter where L::Target: Logger { /// Creates a new router using the given [`NetworkGraph`] and [`Logger`]. pub fn new(network_graph: G, logger: L) -> Self { Self { network_graph, logger } } } -impl Router for DefaultRouter -where G: Deref, L::Target: Logger { +impl, L: Deref, S: routing::Score> Router for DefaultRouter +where L::Target: Logger { fn find_route( &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, scorer: &S diff --git a/lightning/src/routing/mod.rs b/lightning/src/routing/mod.rs index 3a48ffe93..0282ee93e 100644 --- a/lightning/src/routing/mod.rs +++ b/lightning/src/routing/mod.rs @@ -16,8 +16,6 @@ pub mod scorer; use routing::network_graph::NodeId; use routing::router::RouteHop; -use core::cell::{RefCell, RefMut}; -use core::ops::DerefMut; use sync::{Mutex, MutexGuard}; /// An interface used to score payment channels for path finding. @@ -40,36 +38,18 @@ pub trait Score { /// result in [`Score::channel_penalty_msat`] returning a different value for the same channel. /// /// [`find_route`]: crate::routing::router::find_route -pub trait LockableScore<'a> { - /// The locked [`Score`] type. - type Locked: 'a + Score; - - /// Returns the locked scorer. - fn lock(&'a self) -> Self::Locked; -} - -impl<'a, T: 'a + Score> LockableScore<'a> for Mutex { - type Locked = MutexGuard<'a, T>; - - fn lock(&'a self) -> MutexGuard<'a, T> { - Mutex::lock(self).unwrap() - } +pub struct LockableScore { + scorer: Mutex, } -impl<'a, T: 'a + Score> LockableScore<'a> for RefCell { - type Locked = RefMut<'a, T>; - - fn lock(&'a self) -> RefMut<'a, T> { - self.borrow_mut() - } -} - -impl> Score for T { - fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 { - self.deref().channel_penalty_msat(short_channel_id, source, target) +impl LockableScore { + /// Constructs a new LockableScore from a Score + pub fn new(score: S) -> Self { + Self { scorer: Mutex::new(score) } } - - fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) { - self.deref_mut().payment_path_failed(path, short_channel_id) + /// Returns the locked scorer. + /// (C-not exported) + pub fn lock<'a>(&'a self) -> MutexGuard<'a, S> { + self.scorer.lock().unwrap() } } -- 2.39.5