Split LockableScore responsibilities between read & write operations
[rust-lightning] / lightning / src / routing / router.rs
index ded4a5595608c1343ad16bf569cf9f32b6ae0d23..962080b0db47080d2d6fe90169b8276aef72b322 100644 (file)
@@ -20,22 +20,22 @@ use crate::ln::features::{Bolt11InvoiceFeatures, Bolt12InvoiceFeatures, ChannelF
 use crate::ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT};
 use crate::offers::invoice::{BlindedPayInfo, Bolt12Invoice};
 use crate::routing::gossip::{DirectedChannelInfo, EffectiveCapacity, ReadOnlyNetworkGraph, NetworkGraph, NodeId, RoutingFees};
-use crate::routing::scoring::{ChannelUsage, LockableScore, Score};
+use crate::routing::scoring::{ChannelUsage, LockableScore, ScoreLookUp};
 use crate::util::ser::{Writeable, Readable, ReadableArgs, Writer};
 use crate::util::logger::{Level, Logger};
 use crate::util::chacha20::ChaCha20;
 
 use crate::io;
 use crate::prelude::*;
-use crate::sync::{Mutex};
+use crate::sync::Mutex;
 use alloc::collections::BinaryHeap;
 use core::{cmp, fmt};
-use core::ops::{Deref, DerefMut};
+use core::ops::Deref;
 
 /// A [`Router`] implemented using [`find_route`].
-pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> where
+pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: ScoreLookUp<ScoreParams = SP>> where
        L::Target: Logger,
-       S::Target: for <'a> LockableScore<'a, Score = Sc>,
+       S::Target: for <'a> LockableScore<'a, ScoreLookUp = Sc>,
 {
        network_graph: G,
        logger: L,
@@ -44,9 +44,9 @@ pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref,
        score_params: SP
 }
 
-impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
+impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: ScoreLookUp<ScoreParams = SP>> DefaultRouter<G, L, S, SP, Sc> where
        L::Target: Logger,
-       S::Target: for <'a> LockableScore<'a, Score = Sc>,
+       S::Target: for <'a> LockableScore<'a, ScoreLookUp = Sc>,
 {
        /// Creates a new router.
        pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self {
@@ -55,9 +55,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Scor
        }
 }
 
-impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Score<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
+impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: ScoreLookUp<ScoreParams = SP>> Router for DefaultRouter<G, L, S, SP, Sc> where
        L::Target: Logger,
-       S::Target: for <'a> LockableScore<'a, Score = Sc>,
+       S::Target: for <'a> LockableScore<'a, ScoreLookUp = Sc>,
 {
        fn find_route(
                &self,
@@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sco
                };
                find_route(
                        payer, params, &self.network_graph, first_hops, &*self.logger,
-                       &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), &inflight_htlcs),
+                       &ScorerAccountingForInFlightHtlcs::new(self.scorer.read_lock(), &inflight_htlcs),
                        &self.score_params,
                        &random_seed_bytes
                )
@@ -106,21 +106,20 @@ pub trait Router {
        }
 }
 
-/// [`Score`] implementation that factors in in-flight HTLC liquidity.
+/// [`ScoreLookUp`] implementation that factors in in-flight HTLC liquidity.
 ///
-/// Useful for custom [`Router`] implementations to wrap their [`Score`] on-the-fly when calling
+/// Useful for custom [`Router`] implementations to wrap their [`ScoreLookUp`] on-the-fly when calling
 /// [`find_route`].
 ///
-/// [`Score`]: crate::routing::scoring::Score
-pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP: Sized> {
-       scorer: &'a mut S,
+/// [`ScoreLookUp`]: crate::routing::scoring::ScoreLookUp
+pub struct ScorerAccountingForInFlightHtlcs<'a, SP: Sized, Sc: 'a + ScoreLookUp<ScoreParams = SP>, S: Deref<Target = Sc>> {
+       scorer: S,
        // Maps a channel's short channel id and its direction to the liquidity used up.
        inflight_htlcs: &'a InFlightHtlcs,
 }
-
-impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
+impl<'a, SP: Sized, Sc: ScoreLookUp<ScoreParams = SP>, S: Deref<Target = Sc>> ScorerAccountingForInFlightHtlcs<'a, SP, Sc, S> {
        /// Initialize a new `ScorerAccountingForInFlightHtlcs`.
-       pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
+       pub fn new(scorer: S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
                ScorerAccountingForInFlightHtlcs {
                        scorer,
                        inflight_htlcs
@@ -129,12 +128,12 @@ impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs
 }
 
 #[cfg(c_bindings)]
-impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Writeable for ScorerAccountingForInFlightHtlcs<'a, S, SP> {
+impl<'a, SP: Sized, Sc: ScoreLookUp<ScoreParams = SP>, S: Deref<Target = Sc>> Writeable for ScorerAccountingForInFlightHtlcs<'a, SP, Sc, S> {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> { self.scorer.write(writer) }
 }
 
-impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInFlightHtlcs<'a, S, SP>  {
-       type ScoreParams = S::ScoreParams;
+impl<'a, SP: Sized, Sc: 'a + ScoreLookUp<ScoreParams = SP>, S: Deref<Target = Sc>> ScoreLookUp for ScorerAccountingForInFlightHtlcs<'a, SP, Sc, S> {
+       type ScoreParams = Sc::ScoreParams;
        fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
                if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat(
                        source, target, short_channel_id
@@ -149,22 +148,6 @@ impl<'a, S: Score<ScoreParams = SP>, SP: Sized> Score for ScorerAccountingForInF
                        self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params)
                }
        }
-
-       fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
-               self.scorer.payment_path_failed(path, short_channel_id)
-       }
-
-       fn payment_path_successful(&mut self, path: &Path) {
-               self.scorer.payment_path_successful(path)
-       }
-
-       fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
-               self.scorer.probe_failed(path, short_channel_id)
-       }
-
-       fn probe_successful(&mut self, path: &Path) {
-               self.scorer.probe_successful(path)
-       }
 }
 
 /// A data structure for tracking in-flight HTLCs. May be used during pathfinding to account for
@@ -1410,7 +1393,7 @@ fn sort_first_hop_channels(
 /// [`ChannelManager::list_usable_channels`]: crate::ln::channelmanager::ChannelManager::list_usable_channels
 /// [`Event::PaymentPathFailed`]: crate::events::Event::PaymentPathFailed
 /// [`NetworkGraph`]: crate::routing::gossip::NetworkGraph
-pub fn find_route<L: Deref, GL: Deref, S: Score>(
+pub fn find_route<L: Deref, GL: Deref, S: ScoreLookUp>(
        our_node_pubkey: &PublicKey, route_params: &RouteParameters,
        network_graph: &NetworkGraph<GL>, first_hops: Option<&[&ChannelDetails]>, logger: L,
        scorer: &S, score_params: &S::ScoreParams, random_seed_bytes: &[u8; 32]
@@ -1424,7 +1407,7 @@ where L::Target: Logger, GL::Target: Logger {
        Ok(route)
 }
 
-pub(crate) fn get_route<L: Deref, S: Score>(
+pub(crate) fn get_route<L: Deref, S: ScoreLookUp>(
        our_node_pubkey: &PublicKey, payment_params: &PaymentParameters, network_graph: &ReadOnlyNetworkGraph,
        first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, logger: L, scorer: &S, score_params: &S::ScoreParams,
        _random_seed_bytes: &[u8; 32]
@@ -2614,7 +2597,7 @@ fn build_route_from_hops_internal<L: Deref>(
                hop_ids: [Option<NodeId>; MAX_PATH_LENGTH_ESTIMATE as usize],
        }
 
-       impl Score for HopScorer {
+       impl ScoreLookUp for HopScorer {
                type ScoreParams = ();
                fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
                        _usage: ChannelUsage, _score_params: &Self::ScoreParams) -> u64
@@ -2632,14 +2615,6 @@ fn build_route_from_hops_internal<L: Deref>(
                        }
                        u64::max_value()
                }
-
-               fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
-
-               fn payment_path_successful(&mut self, _path: &Path) {}
-
-               fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
-
-               fn probe_successful(&mut self, _path: &Path) {}
        }
 
        impl<'a> Writeable for HopScorer {
@@ -2673,7 +2648,7 @@ mod tests {
        use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
                BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
                DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
-       use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, Score, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
+       use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, ScoreLookUp, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
        use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
        use crate::chain::transaction::OutPoint;
        use crate::sign::EntropySource;
@@ -5720,16 +5695,11 @@ mod tests {
        impl Writeable for BadChannelScorer {
                fn write<W: Writer>(&self, _w: &mut W) -> Result<(), crate::io::Error> { unimplemented!() }
        }
-       impl Score for BadChannelScorer {
+       impl ScoreLookUp for BadChannelScorer {
                type ScoreParams = ();
                fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
-
-               fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
-               fn payment_path_successful(&mut self, _path: &Path) {}
-               fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
-               fn probe_successful(&mut self, _path: &Path) {}
        }
 
        struct BadNodeScorer {
@@ -5741,16 +5711,11 @@ mod tests {
                fn write<W: Writer>(&self, _w: &mut W) -> Result<(), crate::io::Error> { unimplemented!() }
        }
 
-       impl Score for BadNodeScorer {
+       impl ScoreLookUp for BadNodeScorer {
                type ScoreParams = ();
                fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
-
-               fn payment_path_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
-               fn payment_path_successful(&mut self, _path: &Path) {}
-               fn probe_failed(&mut self, _path: &Path, _short_channel_id: u64) {}
-               fn probe_successful(&mut self, _path: &Path) {}
        }
 
        #[test]
@@ -6721,6 +6686,7 @@ pub(crate) mod bench_utils {
        use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
 
        use crate::chain::transaction::OutPoint;
+       use crate::routing::scoring::ScoreUpdate;
        use crate::sign::{EntropySource, KeysManager};
        use crate::ln::channelmanager::{self, ChannelCounterparty, ChannelDetails};
        use crate::ln::features::Bolt11InvoiceFeatures;
@@ -6813,7 +6779,7 @@ pub(crate) mod bench_utils {
                }
        }
 
-       pub(crate) fn generate_test_routes<S: Score>(graph: &NetworkGraph<&TestLogger>, scorer: &mut S,
+       pub(crate) fn generate_test_routes<S: ScoreLookUp + ScoreUpdate>(graph: &NetworkGraph<&TestLogger>, scorer: &mut S,
                score_params: &S::ScoreParams, features: Bolt11InvoiceFeatures, mut seed: u64,
                starting_amount: u64, route_count: usize,
        ) -> Vec<(ChannelDetails, PaymentParameters, u64)> {
@@ -6839,7 +6805,7 @@ pub(crate) mod bench_utils {
                                let amt = starting_amount + seed % 1_000_000;
                                let path_exists =
                                        get_route(&payer, &params, &graph.read_only(), Some(&[&first_hop]),
-                                               amt, &TestLogger::new(), &scorer, score_params, &random_seed_bytes).is_ok();
+                                               amt, &TestLogger::new(), scorer, score_params, &random_seed_bytes).is_ok();
                                if path_exists {
                                        // ...and seed the scorer with success and failure data...
                                        seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
@@ -6853,7 +6819,7 @@ pub(crate) mod bench_utils {
                                                        .with_bolt11_features(mpp_features).unwrap();
 
                                                let route_res = get_route(&payer, &params, &graph.read_only(),
-                                                       Some(&[&first_hop]), score_amt, &TestLogger::new(), &scorer,
+                                                       Some(&[&first_hop]), score_amt, &TestLogger::new(), scorer,
                                                        score_params, &random_seed_bytes);
                                                if let Ok(route) = route_res {
                                                        for path in route.paths {
@@ -6882,7 +6848,7 @@ pub(crate) mod bench_utils {
                // requires a too-high CLTV delta.
                route_endpoints.retain(|(first_hop, params, amt)| {
                        get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt,
-                               &TestLogger::new(), &scorer, score_params, &random_seed_bytes).is_ok()
+                               &TestLogger::new(), scorer, score_params, &random_seed_bytes).is_ok()
                });
                route_endpoints.truncate(route_count);
                assert_eq!(route_endpoints.len(), route_count);
@@ -6893,6 +6859,7 @@ pub(crate) mod bench_utils {
 #[cfg(ldk_bench)]
 pub mod benches {
        use super::*;
+       use crate::routing::scoring::{ScoreUpdate, ScoreLookUp};
        use crate::sign::{EntropySource, KeysManager};
        use crate::ln::channelmanager;
        use crate::ln::features::Bolt11InvoiceFeatures;
@@ -6955,7 +6922,7 @@ pub mod benches {
                        "generate_large_mpp_routes_with_probabilistic_scorer");
        }
 
-       fn generate_routes<S: Score>(
+       fn generate_routes<S: ScoreLookUp + ScoreUpdate>(
                bench: &mut Criterion, graph: &NetworkGraph<&TestLogger>, mut scorer: S,
                score_params: &S::ScoreParams, features: Bolt11InvoiceFeatures, starting_amount: u64,
                bench_name: &'static str,