(Bindings Only) Concretize LockableScore to only map Scorer
authorMatt Corallo <git@bluematt.me>
Sun, 31 Oct 2021 19:39:02 +0000 (19:39 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 3 Nov 2021 17:11:59 +0000 (17:11 +0000)
We don't really care about more than this in bindings - calling
into a custom `Score` is likely too slow to be practical anyway,
so this is also a performance improvement.

Works around https://github.com/rust-lang/rust/issues/90448

lightning-background-processor/src/lib.rs
lightning-invoice/src/payment.rs
lightning-invoice/src/utils.rs
lightning/src/routing/mod.rs

index 50743774b7aa487753f416de66f1c952a7c86a3f..5e43d96b91d4cb90be8cdc757d551fadbb43abb8 100644 (file)
@@ -312,6 +312,7 @@ mod tests {
        use lightning::ln::features::InitFeatures;
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
+       use lightning::routing::LockableScore;
        use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use lightning::util::config::UserConfig;
        use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
@@ -634,7 +635,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = move |node: &ChannelManager<InMemorySigner, Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>>| FilesystemPersister::persist_manager(data_dir.clone(), node);
                let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger));
-               let scorer = Arc::new(Mutex::new(test_utils::TestScorer::default()));
+               let scorer = Arc::new(LockableScore::new(test_utils::TestScorer::default()));
                let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
                let event_handler = Arc::clone(&invoice_payer);
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
index 075559bfd8ed857878e391fcd2a396ae3742093c..e7b162b38f96236472b0de61c7f97446240f8d86 100644 (file)
@@ -30,7 +30,7 @@
 //! # use lightning::ln::{PaymentHash, PaymentSecret};
 //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 //! # use lightning::ln::msgs::LightningError;
-//! # use lightning::routing;
+//! # use lightning::routing::{self, LockableScore};
 //! # use lightning::routing::network_graph::NodeId;
 //! # use lightning::routing::router::{Route, RouteHop, RouteParameters};
 //! # use lightning::util::events::{Event, EventHandler, EventsProvider};
@@ -89,7 +89,7 @@
 //! };
 //! # let payer = FakePayer {};
 //! # let router = FakeRouter {};
-//! # let scorer = RefCell::new(FakeScorer {});
+//! # let scorer = LockableScore::new(FakeScorer {});
 //! # let logger = FakeLogger {};
 //! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 //!
@@ -118,7 +118,7 @@ use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 use lightning::ln::msgs::LightningError;
 use lightning::routing;
-use lightning::routing::{LockableScore, Score};
+use lightning::routing::LockableScore;
 use lightning::routing::router::{Payee, Route, RouteParameters};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
@@ -131,17 +131,14 @@ use std::sync::Mutex;
 use std::time::{Duration, SystemTime};
 
 /// A utility for paying [`Invoice]`s.
-pub struct InvoicePayer<P: Deref, R, S: Deref, L: Deref, E>
+pub struct InvoicePayer<P: Deref, R: Router<S>, S: routing::Score, Sc: Deref<Target=LockableScore<S>>, L: Deref, E: EventHandler>
 where
        P::Target: Payer,
-       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
-       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
-       E: EventHandler,
 {
        payer: P,
        router: R,
-       scorer: S,
+       scorer: Sc,
        logger: L,
        event_handler: E,
        payment_cache: Mutex<HashMap<PaymentHash, usize>>,
@@ -193,20 +190,17 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R, S: Deref, L: Deref, E> InvoicePayer<P, R, S, L, E>
+impl<P: Deref, R: Router<S>, S: routing::Score, Sc: Deref<Target=LockableScore<S>>, L: Deref, E: EventHandler> InvoicePayer<P, R, S, Sc, L, E>
 where
        P::Target: Payer,
-       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
-       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
-       E: EventHandler,
 {
        /// Creates an invoice payer that retries failed payment paths.
        ///
        /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once
        /// `retry_attempts` has been exceeded for a given [`Invoice`].
        pub fn new(
-               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts
+               payer: P, router: R, scorer: Sc, logger: L, event_handler: E, retry_attempts: RetryAttempts
        ) -> Self {
                Self {
                        payer,
@@ -401,13 +395,10 @@ fn has_expired(params: &RouteParameters) -> bool {
        } else { false }
 }
 
-impl<P: Deref, R, S: Deref, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
+impl<P: Deref, R: Router<S>, S: routing::Score, Sc: Deref<Target=LockableScore<S>>, L: Deref, E: EventHandler> EventHandler for InvoicePayer<P, R, S, Sc, L, E>
 where
        P::Target: Payer,
-       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
-       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
-       E: EventHandler,
 {
        fn handle_event(&self, event: &Event) {
                match event {
@@ -529,7 +520,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
@@ -558,7 +549,7 @@ mod tests {
                        .expect_value_msat(final_value_msat)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -599,7 +590,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -644,7 +635,7 @@ mod tests {
                        .expect_value_msat(final_value_msat / 2)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -694,7 +685,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -726,7 +717,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -745,7 +736,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -787,7 +778,7 @@ mod tests {
                        .fails_on_attempt(2)
                        .expect_value_msat(final_value_msat);
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -817,7 +808,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -849,7 +840,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
@@ -890,7 +881,7 @@ mod tests {
        fn fails_paying_invoice_with_routing_errors() {
                let payer = TestPayer::new();
                let router = FailingRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
@@ -908,7 +899,7 @@ mod tests {
        fn fails_paying_invoice_with_sending_errors() {
                let payer = TestPayer::new().fails_on_attempt(1);
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
@@ -934,7 +925,7 @@ mod tests {
 
                let payer = TestPayer::new().expect_value_msat(final_value_msat);
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
@@ -957,7 +948,7 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
@@ -988,7 +979,7 @@ mod tests {
                // Expect that scorer is given short_channel_id upon handling the event.
                let payer = TestPayer::new();
                let router = TestRouter {};
-               let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
+               let scorer = LockableScore::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
                let logger = TestLogger::new();
                let invoice_payer =
                        InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
@@ -1277,7 +1268,7 @@ mod tests {
                router.expect_find_route(Ok(route.clone()));
 
                let event_handler = |_: &_| { panic!(); };
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager(
@@ -1322,7 +1313,7 @@ mod tests {
                router.expect_find_route(Ok(route.clone()));
 
                let event_handler = |_: &_| { panic!(); };
-               let scorer = RefCell::new(TestScorer::new());
+               let scorer = LockableScore::new(TestScorer::new());
                let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
 
                assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager(
index 35a74b6a5ac6a3bf6a41984babc8619551a19622..ba66aa0586d93c27f3b78b494148b17c6761eaac 100644 (file)
@@ -97,20 +97,20 @@ where
 }
 
 /// A [`Router`] implemented using [`find_route`].
-pub struct DefaultRouter<G, L: Deref> where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+pub struct DefaultRouter<G: Deref<Target = NetworkGraph>, L: Deref> where L::Target: Logger {
        network_graph: G,
        logger: L,
 }
 
-impl<G, L: Deref> DefaultRouter<G, L> where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+impl<G: Deref<Target = NetworkGraph>, L: Deref> DefaultRouter<G, L> where L::Target: Logger {
        /// Creates a new router using the given [`NetworkGraph`] and  [`Logger`].
        pub fn new(network_graph: G, logger: L) -> Self {
                Self { network_graph, logger }
        }
 }
 
-impl<G, L: Deref, S: routing::Score> Router<S> for DefaultRouter<G, L>
-where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+impl<G: Deref<Target = NetworkGraph>, L: Deref, S: routing::Score> Router<S> for DefaultRouter<G, L>
+where L::Target: Logger {
        fn find_route(
                &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
                scorer: &S
index 3a48ffe93ddf42c77fb3b3c9b17d727dba14c8f0..0282ee93e40fc0da619e59225466205a4a5c746e 100644 (file)
@@ -16,8 +16,6 @@ pub mod scorer;
 use routing::network_graph::NodeId;
 use routing::router::RouteHop;
 
-use core::cell::{RefCell, RefMut};
-use core::ops::DerefMut;
 use sync::{Mutex, MutexGuard};
 
 /// An interface used to score payment channels for path finding.
@@ -40,36 +38,18 @@ pub trait Score {
 /// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
 ///
 /// [`find_route`]: crate::routing::router::find_route
-pub trait LockableScore<'a> {
-       /// The locked [`Score`] type.
-       type Locked: 'a + Score;
-
-       /// Returns the locked scorer.
-       fn lock(&'a self) -> Self::Locked;
-}
-
-impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
-       type Locked = MutexGuard<'a, T>;
-
-       fn lock(&'a self) -> MutexGuard<'a, T> {
-               Mutex::lock(self).unwrap()
-       }
+pub struct LockableScore<S: Score> {
+       scorer: Mutex<S>,
 }
 
-impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
-       type Locked = RefMut<'a, T>;
-
-       fn lock(&'a self) -> RefMut<'a, T> {
-               self.borrow_mut()
-       }
-}
-
-impl<S: Score, T: DerefMut<Target=S>> Score for T {
-       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
-               self.deref().channel_penalty_msat(short_channel_id, source, target)
+impl<S: Score> LockableScore<S> {
+       /// Constructs a new LockableScore from a Score
+       pub fn new(score: S) -> Self {
+               Self { scorer: Mutex::new(score) }
        }
-
-       fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
-               self.deref_mut().payment_path_failed(path, short_channel_id)
+       /// Returns the locked scorer.
+       /// (C-not exported)
+       pub fn lock<'a>(&'a self) -> MutexGuard<'a, S> {
+               self.scorer.lock().unwrap()
        }
 }