]> git.bitcoin.ninja Git - rust-lightning/commitdiff
WIP: LockableScore
authorJeffrey Czyz <jkczyz@gmail.com>
Thu, 28 Oct 2021 04:12:44 +0000 (23:12 -0500)
committerJeffrey Czyz <jkczyz@gmail.com>
Thu, 28 Oct 2021 04:12:44 +0000 (23:12 -0500)
lightning-invoice/src/payment.rs
lightning/src/routing/mod.rs

index c273a5ac2c3d90c707234b0691365fb0aa7f7e9d..fe6a25f1aafdb7bd1e6579eb490c1d0b8bb855b1 100644 (file)
@@ -117,33 +117,36 @@ use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 use lightning::ln::msgs::LightningError;
 use lightning::routing;
+use lightning::routing::Score;
 use lightning::routing::router::{Payee, Route, RouteParameters};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
+//use lightning::util::ser::{Writeable, Writer};
 
 use secp256k1::key::PublicKey;
 
 use std::collections::hash_map::{self, HashMap};
 use std::ops::Deref;
-use std::sync::{Mutex, RwLock};
+use std::sync::Mutex;
 use std::time::{Duration, SystemTime};
 
 /// A utility for paying [`Invoice]`s.
-pub struct InvoicePayer<P: Deref, R, S, L: Deref, E>
+pub struct InvoicePayer<'a: 'b, 'b, P: Deref, R, S, L: Deref, E>
 where
        P::Target: Payer,
        R: Router,
-       S: routing::Score,
+       S: routing::LockableScore<'a, 'b>,
        L::Target: Logger,
        E: EventHandler,
 {
        payer: P,
        router: R,
-       scorer: RwLock<S>,
+       scorer: S,
        logger: L,
        event_handler: E,
        payment_cache: Mutex<HashMap<PaymentHash, usize>>,
        retry_attempts: RetryAttempts,
+       phantom: std::marker::PhantomData<(&'a (), &'b ())>,
 }
 
 /// A trait defining behavior of an [`Invoice`] payer.
@@ -187,22 +190,11 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-/// A read-only version of the scorer.
-pub struct ReadOnlyScorer<'a, S: routing::Score>(std::sync::RwLockReadGuard<'a, S>);
-
-impl<'a, S: routing::Score> Deref for ReadOnlyScorer<'a, S> {
-       type Target = S;
-
-       fn deref(&self) -> &Self::Target {
-               &*self.0
-       }
-}
-
-impl<P: Deref, R, S, L: Deref, E> InvoicePayer<P, R, S, L, E>
+impl<'a: 'b, 'b, P: Deref, R, S, L: Deref, E> InvoicePayer<'a, 'b, P, R, S, L, E>
 where
        P::Target: Payer,
        R: Router,
-       S: routing::Score,
+       S: routing::LockableScore<'a, 'b>,
        L::Target: Logger,
        E: EventHandler,
 {
@@ -216,22 +208,15 @@ where
                Self {
                        payer,
                        router,
-                       scorer: RwLock::new(scorer),
+                       scorer,
                        logger,
                        event_handler,
                        payment_cache: Mutex::new(HashMap::new()),
                        retry_attempts,
+                       phantom: std::marker::PhantomData,
                }
        }
 
-       /// Returns a read-only reference to the parameterized [`routing::Score`].
-       ///
-       /// Useful if the scorer needs to be persisted. Be sure to drop the returned guard immediately
-       /// after use since retrying failed payment paths require write access.
-       pub fn scorer(&'_ self) -> ReadOnlyScorer<'_, S> {
-               ReadOnlyScorer(self.scorer.read().unwrap())
-       }
-
        /// Pays the given [`Invoice`], caching it for later use in case a retry is needed.
        pub fn pay_invoice(&self, invoice: &Invoice) -> Result<PaymentId, PaymentError> {
                if invoice.amount_milli_satoshis().is_none() {
@@ -278,7 +263,7 @@ where
                                        &payer,
                                        &params,
                                        Some(&first_hops.iter().collect::<Vec<_>>()),
-                                       &*self.scorer.read().unwrap(),
+                                       &self.scorer.lock(),
                                ).map_err(|e| PaymentError::Routing(e))?;
 
                                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
@@ -299,7 +284,7 @@ where
                let first_hops = self.payer.first_hops();
                let route = self.router.find_route(
                        &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()),
-                       &*self.scorer.read().unwrap()
+                       &self.scorer.lock()
                ).map_err(|e| PaymentError::Routing(e))?;
                self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e))
        }
@@ -322,11 +307,11 @@ fn has_expired(params: &RouteParameters) -> bool {
        Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time)
 }
 
-impl<P: Deref, R, S, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
+impl<'a: 'b, 'b, P: Deref, R, S, L: Deref, E> EventHandler for InvoicePayer<'a, 'b, P, R, S, L, E>
 where
        P::Target: Payer,
        R: Router,
-       S: routing::Score,
+       S: routing::LockableScore<'a, 'b>,
        L::Target: Logger,
        E: EventHandler,
 {
@@ -336,7 +321,7 @@ where
                                payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, ..
                        } => {
                                if let Some(short_channel_id) = short_channel_id {
-                                       self.scorer.write().unwrap().payment_path_failed(path, *short_channel_id);
+                                       self.scorer.lock().payment_path_failed(path, *short_channel_id);
                                }
 
                                let mut payment_cache = self.payment_cache.lock().unwrap();
@@ -391,6 +376,51 @@ where
        }
 }
 
+/////
+//pub trait WriteableScore: routing::Score + Writeable {}
+//
+/////
+//pub struct ScorePersister<I, P: Deref, R, S, L: Deref, E>
+//where
+//     I: Deref<Target=InvoicePayer<P, R, S, L, E>>,
+//     P::Target: Payer,
+//     R: Router,
+//     S: WriteableScore,
+//     L::Target: Logger,
+//     E: EventHandler,
+//{
+//     invoice_payer: I,
+//}
+//
+//impl<I, P: Deref, R, S, L: Deref, E> ScorePersister<I, P, R, S, L, E>
+//where
+//     I: Deref<Target=InvoicePayer<P, R, S, L, E>>,
+//     P::Target: Payer,
+//     R: Router,
+//     S: WriteableScore,
+//     L::Target: Logger,
+//     E: EventHandler,
+//{
+//     ///
+//     pub fn new(invoice_payer: I) -> Self {
+//             Self { invoice_payer }
+//     }
+//}
+//
+//impl<I, P: Deref, R, S, L: Deref, E> Writeable for ScorePersister<I, P, R, S, L, E>
+//where
+//     I: Deref<Target=InvoicePayer<P, R, S, L, E>>,
+//     P::Target: Payer,
+//     R: Router,
+//     S: WriteableScore,
+//     L::Target: Logger,
+//     E: EventHandler,
+//{
+//     fn write<W: Writer>(&self, writer: &mut W) -> Result<(), std::io::Error> {
+//             self.invoice_payer.scorer.read().unwrap().write(writer)
+//     }
+//}
+
 #[cfg(test)]
 mod tests {
        use super::*;
index 42669986015af593e1c851656026e00f2b346233..4c5469c5f15f10acc543bcd6d7cb0f907d1c5fd3 100644 (file)
@@ -17,6 +17,8 @@ use routing::network_graph::NodeId;
 use routing::router::RouteHop;
 
 use prelude::*;
+use core::ops::{Deref, DerefMut};
+use sync::{Mutex, MutexGuard};
 
 /// An interface used to score payment channels for path finding.
 ///
@@ -29,3 +31,27 @@ pub trait Score {
        /// Handles updating channel penalties after failing to route through a channel.
        fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64);
 }
+
+pub trait LockableScore<'a: 'b, 'b> {
+       type Locked: Score;
+
+       fn lock(&'a self) -> Self::Locked;
+}
+
+impl<'a: 'b, 'b, S: 'b + Score, T: Deref<Target=Mutex<S>>> LockableScore<'a, 'b> for T {
+       type Locked = MutexGuard<'b, S>;
+
+       fn lock(&'a self) -> Self::Locked {
+               self.deref().lock().unwrap()
+       }
+}
+
+impl<'a, S: Score> Score for MutexGuard<'a, S> {
+       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
+               self.deref().channel_penalty_msat(short_channel_id, source, target)
+       }
+
+       fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64) {
+               self.deref_mut().payment_path_failed(path, short_channel_id)
+       }
+}