From: henghonglee Date: Thu, 29 Jun 2023 02:41:38 +0000 (+0800) Subject: Fix DefaultRouter type restrained to only MutexGuard X-Git-Tag: v0.0.116-rc1~18^2 X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=54bcb6eb02c600df0adf3d438b7595bdb56ad7fb;p=rust-lightning Fix DefaultRouter type restrained to only MutexGuard Type of DerefMut for DefaultRouter was specialized to only MutexGuard. It should be generic around RefMut and MutexGuard. This commit fixes that --- diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 517e693de..1ed6a2a83 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -885,7 +885,22 @@ mod tests { fn disconnect_socket(&mut self) {} } - type ChannelManager = channelmanager::ChannelManager, Arc, Arc, Arc, Arc, Arc, Arc>>, Arc, Arc>, (), TestScorer>>, Arc>; + type ChannelManager = + channelmanager::ChannelManager< + Arc, + Arc, + Arc, + Arc, + Arc, + Arc, + Arc>>, + Arc, + Arc>, + (), + TestScorer> + >, + Arc>; type ChainMonitor = chainmonitor::ChainMonitor, Arc, Arc, Arc, Arc>; diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 65f442f1c..980616770 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -752,7 +752,23 @@ pub type SimpleArcChannelManager = ChannelManager< /// of [`KeysManager`] and [`DefaultRouter`]. /// /// This is not exported to bindings users as Arcs don't make sense in bindings -pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex, &'g L>>, ProbabilisticScoringFeeParameters, ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>, &'g L>; +pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = + ChannelManager< + &'a M, + &'b T, + &'c KeysManager, + &'c KeysManager, + &'c KeysManager, + &'d F, + &'e DefaultRouter< + &'f NetworkGraph<&'g L>, + &'g L, + &'h Mutex, &'g L>>, + ProbabilisticScoringFeeParameters, + ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L> + >, + &'g L + >; macro_rules! define_test_pub_trait { ($vis: vis) => { /// A trivial trait which describes any [`ChannelManager`] used in testing. diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index a8da26c23..ed5b772ed 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -27,15 +27,15 @@ use crate::util::chacha20::ChaCha20; use crate::io; use crate::prelude::*; -use crate::sync::{Mutex, MutexGuard}; +use crate::sync::{Mutex}; use alloc::collections::BinaryHeap; use core::{cmp, fmt}; -use core::ops::Deref; +use core::ops::{Deref, DerefMut}; /// A [`Router`] implemented using [`find_route`]. pub struct DefaultRouter>, L: Deref, S: Deref, SP: Sized, Sc: Score> where L::Target: Logger, - S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>, + S::Target: for <'a> LockableScore<'a, Score = Sc>, { network_graph: G, logger: L, @@ -46,7 +46,7 @@ pub struct DefaultRouter>, L: Deref, S: Deref, impl>, L: Deref, S: Deref, SP: Sized, Sc: Score> DefaultRouter where L::Target: Logger, - S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>, + S::Target: for <'a> LockableScore<'a, Score = Sc>, { /// Creates a new router. pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self { @@ -55,9 +55,9 @@ impl>, L: Deref, S: Deref, SP: Sized, Sc: Scor } } -impl< G: Deref>, L: Deref, S: Deref, SP: Sized, Sc: Score> Router for DefaultRouter where +impl< G: Deref>, L: Deref, S: Deref, SP: Sized, Sc: Score> Router for DefaultRouter where L::Target: Logger, - S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>, + S::Target: for <'a> LockableScore<'a, Score = Sc>, { fn find_route( &self, @@ -73,7 +73,7 @@ impl< G: Deref>, L: Deref, S: Deref, SP: Sized, Sc: Sc }; find_route( payer, params, &self.network_graph, first_hops, &*self.logger, - &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs), + &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs), &self.score_params, &random_seed_bytes ) @@ -104,15 +104,15 @@ pub trait Router { /// [`find_route`]. /// /// [`Score`]: crate::routing::scoring::Score -pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score> { - scorer: S, +pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score, SP: Sized> { + scorer: &'a mut S, // Maps a channel's short channel id and its direction to the liquidity used up. inflight_htlcs: &'a InFlightHtlcs, } -impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> { +impl<'a, S: Score, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> { /// Initialize a new `ScorerAccountingForInFlightHtlcs`. - pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self { + pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self { ScorerAccountingForInFlightHtlcs { scorer, inflight_htlcs @@ -121,11 +121,11 @@ impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> { } #[cfg(c_bindings)] -impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> { +impl<'a, S: Score, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> { fn write(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) } } -impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> { +impl<'a, S: Score, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP> { type ScoreParams = S::ScoreParams; fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 { if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat( diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index eca5ee605..d83adaf97 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -157,8 +157,11 @@ define_score!(); /// /// [`find_route`]: crate::routing::router::find_route pub trait LockableScore<'a> { + /// The [`Score`] type. + type Score: 'a + Score; + /// The locked [`Score`] type. - type Locked: 'a + Score; + type Locked: DerefMut + Sized; /// Returns the locked scorer. fn lock(&'a self) -> Self::Locked; @@ -174,60 +177,35 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {} impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {} /// This is not exported to bindings users impl<'a, T: 'a + Score> LockableScore<'a> for Mutex { + type Score = T; type Locked = MutexGuard<'a, T>; - fn lock(&'a self) -> MutexGuard<'a, T> { + fn lock(&'a self) -> Self::Locked { Mutex::lock(self).unwrap() } } impl<'a, T: 'a + Score> LockableScore<'a> for RefCell { + type Score = T; type Locked = RefMut<'a, T>; - fn lock(&'a self) -> RefMut<'a, T> { + fn lock(&'a self) -> Self::Locked { self.borrow_mut() } } #[cfg(c_bindings)] /// A concrete implementation of [`LockableScore`] which supports multi-threading. -pub struct MultiThreadedLockableScore { - score: Mutex, -} -#[cfg(c_bindings)] -/// A locked `MultiThreadedLockableScore`. -pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>); -#[cfg(c_bindings)] -impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> { - type ScoreParams = ::ScoreParams; - fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 { - self.0.channel_penalty_msat(scid, source, target, usage, score_params) - } - fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) { - self.0.payment_path_failed(path, short_channel_id) - } - fn payment_path_successful(&mut self, path: &Path) { - self.0.payment_path_successful(path) - } - fn probe_failed(&mut self, path: &Path, short_channel_id: u64) { - self.0.probe_failed(path, short_channel_id) - } - fn probe_successful(&mut self, path: &Path) { - self.0.probe_successful(path) - } -} -#[cfg(c_bindings)] -impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> { - fn write(&self, writer: &mut W) -> Result<(), io::Error> { - self.0.write(writer) - } +pub struct MultiThreadedLockableScore { + score: Mutex, } #[cfg(c_bindings)] -impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore { +impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore { + type Score = T; type Locked = MultiThreadedScoreLock<'a, T>; - fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> { + fn lock(&'a self) -> Self::Locked { MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap()) } } @@ -240,7 +218,7 @@ impl Writeable for MultiThreadedLockableScore { } #[cfg(c_bindings)] -impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore {} +impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore {} #[cfg(c_bindings)] impl MultiThreadedLockableScore { @@ -250,6 +228,33 @@ impl MultiThreadedLockableScore { } } +#[cfg(c_bindings)] +/// A locked `MultiThreadedLockableScore`. +pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>); + +#[cfg(c_bindings)] +impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> { + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + self.0.write(writer) + } +} + +#[cfg(c_bindings)] +impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } +} + +#[cfg(c_bindings)] +impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + #[cfg(c_bindings)] /// This is not exported to bindings users impl<'a, T: Writeable> Writeable for RefMut<'a, T> { diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 24979036e..212f1f4b6 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -51,6 +51,7 @@ use regex; use crate::io; use crate::prelude::*; use core::cell::RefCell; +use core::ops::DerefMut; use core::time::Duration; use crate::sync::{Mutex, Arc}; use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; @@ -113,8 +114,8 @@ impl<'a> Router for TestRouter<'a> { if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() { assert_eq!(find_route_query, *params); if let Ok(ref route) = find_route_res { - let locked_scorer = self.scorer.lock().unwrap(); - let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs); + let mut binding = self.scorer.lock().unwrap(); + let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs); for path in &route.paths { let mut aggregate_msat = 0u64; for (idx, hop) in path.hops.iter().rev().enumerate() { @@ -139,10 +140,9 @@ impl<'a> Router for TestRouter<'a> { return find_route_res; } let logger = TestLogger::new(); - let scorer = self.scorer.lock().unwrap(); find_route( payer, params, &self.network_graph, first_hops, &logger, - &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs), &(), + &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().unwrap().deref_mut(), &inflight_htlcs), &(), &[42; 32] ) }