From: Jeffrey Czyz Date: Thu, 28 Oct 2021 04:12:44 +0000 (-0500) Subject: WIP: LockableScore X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=dbedf14f54f150a40f79c004024b9e1c0c214868;p=rust-lightning WIP: LockableScore --- diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index c273a5ac2..fe6a25f1a 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -117,33 +117,36 @@ use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; use lightning::ln::msgs::LightningError; use lightning::routing; +use lightning::routing::Score; use lightning::routing::router::{Payee, Route, RouteParameters}; use lightning::util::events::{Event, EventHandler}; use lightning::util::logger::Logger; +//use lightning::util::ser::{Writeable, Writer}; use secp256k1::key::PublicKey; use std::collections::hash_map::{self, HashMap}; use std::ops::Deref; -use std::sync::{Mutex, RwLock}; +use std::sync::Mutex; use std::time::{Duration, SystemTime}; /// A utility for paying [`Invoice]`s. -pub struct InvoicePayer +pub struct InvoicePayer<'a: 'b, 'b, P: Deref, R, S, L: Deref, E> where P::Target: Payer, R: Router, - S: routing::Score, + S: routing::LockableScore<'a, 'b>, L::Target: Logger, E: EventHandler, { payer: P, router: R, - scorer: RwLock, + scorer: S, logger: L, event_handler: E, payment_cache: Mutex>, retry_attempts: RetryAttempts, + phantom: std::marker::PhantomData<(&'a (), &'b ())>, } /// A trait defining behavior of an [`Invoice`] payer. @@ -187,22 +190,11 @@ pub enum PaymentError { Sending(PaymentSendFailure), } -/// A read-only version of the scorer. -pub struct ReadOnlyScorer<'a, S: routing::Score>(std::sync::RwLockReadGuard<'a, S>); - -impl<'a, S: routing::Score> Deref for ReadOnlyScorer<'a, S> { - type Target = S; - - fn deref(&self) -> &Self::Target { - &*self.0 - } -} - -impl InvoicePayer +impl<'a: 'b, 'b, P: Deref, R, S, L: Deref, E> InvoicePayer<'a, 'b, P, R, S, L, E> where P::Target: Payer, R: Router, - S: routing::Score, + S: routing::LockableScore<'a, 'b>, L::Target: Logger, E: EventHandler, { @@ -216,22 +208,15 @@ where Self { payer, router, - scorer: RwLock::new(scorer), + scorer, logger, event_handler, payment_cache: Mutex::new(HashMap::new()), retry_attempts, + phantom: std::marker::PhantomData, } } - /// Returns a read-only reference to the parameterized [`routing::Score`]. - /// - /// Useful if the scorer needs to be persisted. Be sure to drop the returned guard immediately - /// after use since retrying failed payment paths require write access. - pub fn scorer(&'_ self) -> ReadOnlyScorer<'_, S> { - ReadOnlyScorer(self.scorer.read().unwrap()) - } - /// Pays the given [`Invoice`], caching it for later use in case a retry is needed. pub fn pay_invoice(&self, invoice: &Invoice) -> Result { if invoice.amount_milli_satoshis().is_none() { @@ -278,7 +263,7 @@ where &payer, ¶ms, Some(&first_hops.iter().collect::>()), - &*self.scorer.read().unwrap(), + &self.scorer.lock(), ).map_err(|e| PaymentError::Routing(e))?; let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner()); @@ -299,7 +284,7 @@ where let first_hops = self.payer.first_hops(); let route = self.router.find_route( &payer, ¶ms, Some(&first_hops.iter().collect::>()), - &*self.scorer.read().unwrap() + &self.scorer.lock() ).map_err(|e| PaymentError::Routing(e))?; self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e)) } @@ -322,11 +307,11 @@ fn has_expired(params: &RouteParameters) -> bool { Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time) } -impl EventHandler for InvoicePayer +impl<'a: 'b, 'b, P: Deref, R, S, L: Deref, E> EventHandler for InvoicePayer<'a, 'b, P, R, S, L, E> where P::Target: Payer, R: Router, - S: routing::Score, + S: routing::LockableScore<'a, 'b>, L::Target: Logger, E: EventHandler, { @@ -336,7 +321,7 @@ where payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, .. } => { if let Some(short_channel_id) = short_channel_id { - self.scorer.write().unwrap().payment_path_failed(path, *short_channel_id); + self.scorer.lock().payment_path_failed(path, *short_channel_id); } let mut payment_cache = self.payment_cache.lock().unwrap(); @@ -391,6 +376,51 @@ where } } +///// +//pub trait WriteableScore: routing::Score + Writeable {} +// +///// +//pub struct ScorePersister +//where +// I: Deref>, +// P::Target: Payer, +// R: Router, +// S: WriteableScore, +// L::Target: Logger, +// E: EventHandler, +//{ +// invoice_payer: I, +//} +// +//impl ScorePersister +//where +// I: Deref>, +// P::Target: Payer, +// R: Router, +// S: WriteableScore, +// L::Target: Logger, +// E: EventHandler, +//{ +// /// +// pub fn new(invoice_payer: I) -> Self { +// Self { invoice_payer } +// } +//} +// +//impl Writeable for ScorePersister +//where +// I: Deref>, +// P::Target: Payer, +// R: Router, +// S: WriteableScore, +// L::Target: Logger, +// E: EventHandler, +//{ +// fn write(&self, writer: &mut W) -> Result<(), std::io::Error> { +// self.invoice_payer.scorer.read().unwrap().write(writer) +// } +//} + #[cfg(test)] mod tests { use super::*; diff --git a/lightning/src/routing/mod.rs b/lightning/src/routing/mod.rs index 426699860..4c5469c5f 100644 --- a/lightning/src/routing/mod.rs +++ b/lightning/src/routing/mod.rs @@ -17,6 +17,8 @@ use routing::network_graph::NodeId; use routing::router::RouteHop; use prelude::*; +use core::ops::{Deref, DerefMut}; +use sync::{Mutex, MutexGuard}; /// An interface used to score payment channels for path finding. /// @@ -29,3 +31,27 @@ pub trait Score { /// Handles updating channel penalties after failing to route through a channel. fn payment_path_failed(&mut self, path: &Vec, short_channel_id: u64); } + +pub trait LockableScore<'a: 'b, 'b> { + type Locked: Score; + + fn lock(&'a self) -> Self::Locked; +} + +impl<'a: 'b, 'b, S: 'b + Score, T: Deref>> LockableScore<'a, 'b> for T { + type Locked = MutexGuard<'b, S>; + + fn lock(&'a self) -> Self::Locked { + self.deref().lock().unwrap() + } +} + +impl<'a, S: Score> Score for MutexGuard<'a, S> { + fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 { + self.deref().channel_penalty_msat(short_channel_id, source, target) + } + + fn payment_path_failed(&mut self, path: &Vec, short_channel_id: u64) { + self.deref_mut().payment_path_failed(path, short_channel_id) + } +}