Merge pull request #2466 from TheBlueMatt/2023-07-expose-success-prob
[rust-lightning] / lightning / src / routing / scoring.rs
index 9c337f9b9edbe9ef98ac4c44016233d975d15e24..3235d85a8c2a29edf8914cfcef0fdee8ea844a96 100644 (file)
@@ -157,8 +157,11 @@ define_score!();
 ///
 /// [`find_route`]: crate::routing::router::find_route
 pub trait LockableScore<'a> {
+       /// The [`Score`] type.
+       type Score: 'a + Score;
+
        /// The locked [`Score`] type.
-       type Locked: 'a + Score;
+       type Locked: DerefMut<Target = Self::Score> + Sized;
 
        /// Returns the locked scorer.
        fn lock(&'a self) -> Self::Locked;
@@ -172,62 +175,38 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {}
 
 #[cfg(not(c_bindings))]
 impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {}
-/// This is not exported to bindings users
+#[cfg(not(c_bindings))]
 impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
+       type Score = T;
        type Locked = MutexGuard<'a, T>;
 
-       fn lock(&'a self) -> MutexGuard<'a, T> {
+       fn lock(&'a self) -> Self::Locked {
                Mutex::lock(self).unwrap()
        }
 }
 
+#[cfg(not(c_bindings))]
 impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
+       type Score = T;
        type Locked = RefMut<'a, T>;
 
-       fn lock(&'a self) -> RefMut<'a, T> {
+       fn lock(&'a self) -> Self::Locked {
                self.borrow_mut()
        }
 }
 
 #[cfg(c_bindings)]
 /// A concrete implementation of [`LockableScore`] which supports multi-threading.
-pub struct MultiThreadedLockableScore<S: Score> {
-       score: Mutex<S>,
-}
-#[cfg(c_bindings)]
-/// A locked `MultiThreadedLockableScore`.
-pub struct MultiThreadedScoreLock<'a, S: Score>(MutexGuard<'a, S>);
-#[cfg(c_bindings)]
-impl<'a, T: Score + 'a> Score for MultiThreadedScoreLock<'a, T> {
-       type ScoreParams = <T as Score>::ScoreParams;
-       fn channel_penalty_msat(&self, scid: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 {
-               self.0.channel_penalty_msat(scid, source, target, usage, score_params)
-       }
-       fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64) {
-               self.0.payment_path_failed(path, short_channel_id)
-       }
-       fn payment_path_successful(&mut self, path: &Path) {
-               self.0.payment_path_successful(path)
-       }
-       fn probe_failed(&mut self, path: &Path, short_channel_id: u64) {
-               self.0.probe_failed(path, short_channel_id)
-       }
-       fn probe_successful(&mut self, path: &Path) {
-               self.0.probe_successful(path)
-       }
-}
-#[cfg(c_bindings)]
-impl<'a, T: Score + 'a> Writeable for MultiThreadedScoreLock<'a, T> {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
-               self.0.write(writer)
-       }
+pub struct MultiThreadedLockableScore<T: Score> {
+       score: Mutex<T>,
 }
 
 #[cfg(c_bindings)]
-impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
+impl<'a, T: 'a + Score> LockableScore<'a> for MultiThreadedLockableScore<T> {
+       type Score = T;
        type Locked = MultiThreadedScoreLock<'a, T>;
 
-       fn lock(&'a self) -> MultiThreadedScoreLock<'a, T> {
+       fn lock(&'a self) -> Self::Locked {
                MultiThreadedScoreLock(Mutex::lock(&self.score).unwrap())
        }
 }
@@ -240,7 +219,7 @@ impl<T: Score> Writeable for MultiThreadedLockableScore<T> {
 }
 
 #[cfg(c_bindings)]
-impl<'a, T: Score + 'a> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
+impl<'a, T: 'a + Score> WriteableScore<'a> for MultiThreadedLockableScore<T> {}
 
 #[cfg(c_bindings)]
 impl<T: Score> MultiThreadedLockableScore<T> {
@@ -251,21 +230,34 @@ impl<T: Score> MultiThreadedLockableScore<T> {
 }
 
 #[cfg(c_bindings)]
-/// This is not exported to bindings users
-impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
+/// A locked `MultiThreadedLockableScore`.
+pub struct MultiThreadedScoreLock<'a, T: Score>(MutexGuard<'a, T>);
+
+#[cfg(c_bindings)]
+impl<'a, T: 'a + Score> Writeable for MultiThreadedScoreLock<'a, T> {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
-               T::write(&**self, writer)
+               self.0.write(writer)
        }
 }
 
 #[cfg(c_bindings)]
-/// This is not exported to bindings users
-impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
-               S::write(&**self, writer)
-       }
+impl<'a, T: 'a + Score> DerefMut for MultiThreadedScoreLock<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0.deref_mut()
+    }
+}
+
+#[cfg(c_bindings)]
+impl<'a, T: 'a + Score> Deref for MultiThreadedScoreLock<'a, T> {
+       type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.deref()
+    }
 }
 
+
+
 /// Proposed use of a channel passed as a parameter to [`Score::channel_penalty_msat`].
 #[derive(Clone, Copy, Debug, PartialEq)]
 pub struct ChannelUsage {
@@ -325,7 +317,7 @@ impl ReadableArgs<u64> for FixedPenaltyScorer {
 }
 
 #[cfg(not(feature = "no-std"))]
-type ConfiguredTime = std::time::Instant;
+type ConfiguredTime = crate::util::time::MonotonicTime;
 #[cfg(feature = "no-std")]
 use crate::util::time::Eternity;
 #[cfg(feature = "no-std")]
@@ -491,7 +483,7 @@ pub struct ProbabilisticScoringFeeParameters {
        pub manual_node_penalties: HashMap<NodeId, u64>,
 
        /// This penalty is applied when `htlc_maximum_msat` is equal to or larger than half of the
-       /// channel's capacity, (ie. htlc_maximum_msat  0.5 * channel_capacity) which makes us
+       /// channel's capacity, (ie. htlc_maximum_msat >= 0.5 * channel_capacity) which makes us
        /// prefer nodes with a smaller `htlc_maximum_msat`. We treat such nodes preferentially
        /// as this makes balance discovery attacks harder to execute, thereby creating an incentive
        /// to restrict `htlc_maximum_msat` and improve privacy.
@@ -673,8 +665,7 @@ struct ChannelLiquidity<T: Time> {
 struct DirectedChannelLiquidity<L: Deref<Target = u64>, BRT: Deref<Target = HistoricalBucketRangeTracker>, T: Time, U: Deref<Target = T>> {
        min_liquidity_offset_msat: L,
        max_liquidity_offset_msat: L,
-       min_liquidity_offset_history: BRT,
-       max_liquidity_offset_history: BRT,
+       liquidity_history: HistoricalMinMaxBuckets<BRT>,
        inflight_htlc_msat: u64,
        capacity_msat: u64,
        last_updated: U,
@@ -715,12 +706,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
                                                let amt = directed_info.effective_capacity().as_msat();
                                                let dir_liq = liq.as_directed(source, target, 0, amt, self.decay_params);
 
-                                               let buckets = HistoricalMinMaxBuckets {
-                                                       min_liquidity_offset_history: &dir_liq.min_liquidity_offset_history,
-                                                       max_liquidity_offset_history: &dir_liq.max_liquidity_offset_history,
-                                               };
-                                               let (min_buckets, max_buckets, _) = buckets.get_decayed_buckets(now,
-                                                       *dir_liq.last_updated, self.decay_params.historical_no_updates_half_life);
+                                               let (min_buckets, max_buckets, _) = dir_liq.liquidity_history
+                                                       .get_decayed_buckets(now, *dir_liq.last_updated,
+                                                               self.decay_params.historical_no_updates_half_life);
 
                                                log_debug!(self.logger, core::concat!(
                                                        "Liquidity from {} to {} via {} is in the range ({}, {}).\n",
@@ -797,12 +785,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
                                        let amt = directed_info.effective_capacity().as_msat();
                                        let dir_liq = liq.as_directed(source, target, 0, amt, self.decay_params);
 
-                                       let buckets = HistoricalMinMaxBuckets {
-                                               min_liquidity_offset_history: &dir_liq.min_liquidity_offset_history,
-                                               max_liquidity_offset_history: &dir_liq.max_liquidity_offset_history,
-                                       };
-                                       let (min_buckets, mut max_buckets, _) = buckets.get_decayed_buckets(dir_liq.now,
-                                               *dir_liq.last_updated, self.decay_params.historical_no_updates_half_life);
+                                       let (min_buckets, mut max_buckets, _) = dir_liq.liquidity_history
+                                               .get_decayed_buckets(dir_liq.now, *dir_liq.last_updated,
+                                                       self.decay_params.historical_no_updates_half_life);
                                        // Note that the liquidity buckets are an offset from the edge, so we inverse
                                        // the max order to get the probabilities from zero.
                                        max_buckets.reverse();
@@ -831,14 +816,9 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> ProbabilisticScorerU
                                        let capacity_msat = directed_info.effective_capacity().as_msat();
                                        let dir_liq = liq.as_directed(source, target, 0, capacity_msat, self.decay_params);
 
-                                       let buckets = HistoricalMinMaxBuckets {
-                                               min_liquidity_offset_history: &dir_liq.min_liquidity_offset_history,
-                                               max_liquidity_offset_history: &dir_liq.max_liquidity_offset_history,
-                                       };
-
-                                       return buckets.calculate_success_probability_times_billion(dir_liq.now,
-                                               *dir_liq.last_updated, self.decay_params.historical_no_updates_half_life,
-                                               amount_msat, capacity_msat
+                                       return dir_liq.liquidity_history.calculate_success_probability_times_billion(
+                                               dir_liq.now, *dir_liq.last_updated,
+                                               self.decay_params.historical_no_updates_half_life, amount_msat, capacity_msat
                                        ).map(|p| p as f64 / (1024 * 1024 * 1024) as f64);
                                }
                        }
@@ -876,8 +856,10 @@ impl<T: Time> ChannelLiquidity<T> {
                DirectedChannelLiquidity {
                        min_liquidity_offset_msat,
                        max_liquidity_offset_msat,
-                       min_liquidity_offset_history,
-                       max_liquidity_offset_history,
+                       liquidity_history: HistoricalMinMaxBuckets {
+                               min_liquidity_offset_history,
+                               max_liquidity_offset_history,
+                       },
                        inflight_htlc_msat,
                        capacity_msat,
                        last_updated: &self.last_updated,
@@ -903,8 +885,10 @@ impl<T: Time> ChannelLiquidity<T> {
                DirectedChannelLiquidity {
                        min_liquidity_offset_msat,
                        max_liquidity_offset_msat,
-                       min_liquidity_offset_history,
-                       max_liquidity_offset_history,
+                       liquidity_history: HistoricalMinMaxBuckets {
+                               min_liquidity_offset_history,
+                               max_liquidity_offset_history,
+                       },
                        inflight_htlc_msat,
                        capacity_msat,
                        last_updated: &mut self.last_updated,
@@ -973,11 +957,7 @@ impl<L: Deref<Target = u64>, BRT: Deref<Target = HistoricalBucketRangeTracker>,
 
                if score_params.historical_liquidity_penalty_multiplier_msat != 0 ||
                   score_params.historical_liquidity_penalty_amount_multiplier_msat != 0 {
-                       let buckets = HistoricalMinMaxBuckets {
-                               min_liquidity_offset_history: &self.min_liquidity_offset_history,
-                               max_liquidity_offset_history: &self.max_liquidity_offset_history,
-                       };
-                       if let Some(cumulative_success_prob_times_billion) = buckets
+                       if let Some(cumulative_success_prob_times_billion) = self.liquidity_history
                                .calculate_success_probability_times_billion(self.now, *self.last_updated,
                                        self.decay_params.historical_no_updates_half_life, amount_msat, self.capacity_msat)
                        {
@@ -1041,10 +1021,25 @@ impl<L: Deref<Target = u64>, BRT: Deref<Target = HistoricalBucketRangeTracker>,
        }
 
        fn decayed_offset_msat(&self, offset_msat: u64) -> u64 {
-               self.now.duration_since(*self.last_updated).as_secs()
-                       .checked_div(self.decay_params.liquidity_offset_half_life.as_secs())
-                       .and_then(|decays| offset_msat.checked_shr(decays as u32))
-                       .unwrap_or(0)
+               let half_life = self.decay_params.liquidity_offset_half_life.as_secs();
+               if half_life != 0 {
+                       // Decay the offset by the appropriate number of half lives. If half of the next half
+                       // life has passed, approximate an additional three-quarter life to help smooth out the
+                       // decay.
+                       let elapsed_time = self.now.duration_since(*self.last_updated).as_secs();
+                       let half_decays = elapsed_time / (half_life / 2);
+                       let decays = half_decays / 2;
+                       let decayed_offset_msat = offset_msat.checked_shr(decays as u32).unwrap_or(0);
+                       if half_decays % 2 == 0 {
+                               decayed_offset_msat
+                       } else {
+                               // 11_585 / 16_384 ~= core::f64::consts::FRAC_1_SQRT_2
+                               // 16_384 == 2^14
+                               (decayed_offset_msat as u128 * 11_585 / 16_384) as u64
+                       }
+               } else {
+                       0
+               }
        }
 }
 
@@ -1087,15 +1082,15 @@ impl<L: DerefMut<Target = u64>, BRT: DerefMut<Target = HistoricalBucketRangeTrac
                let half_lives = self.now.duration_since(*self.last_updated).as_secs()
                        .checked_div(self.decay_params.historical_no_updates_half_life.as_secs())
                        .map(|v| v.try_into().unwrap_or(u32::max_value())).unwrap_or(u32::max_value());
-               self.min_liquidity_offset_history.time_decay_data(half_lives);
-               self.max_liquidity_offset_history.time_decay_data(half_lives);
+               self.liquidity_history.min_liquidity_offset_history.time_decay_data(half_lives);
+               self.liquidity_history.max_liquidity_offset_history.time_decay_data(half_lives);
 
                let min_liquidity_offset_msat = self.decayed_offset_msat(*self.min_liquidity_offset_msat);
-               self.min_liquidity_offset_history.track_datapoint(
+               self.liquidity_history.min_liquidity_offset_history.track_datapoint(
                        min_liquidity_offset_msat, self.capacity_msat
                );
                let max_liquidity_offset_msat = self.decayed_offset_msat(*self.max_liquidity_offset_msat);
-               self.max_liquidity_offset_history.track_datapoint(
+               self.liquidity_history.max_liquidity_offset_history.track_datapoint(
                        max_liquidity_offset_msat, self.capacity_msat
                );
        }
@@ -1138,8 +1133,10 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
 
                let mut anti_probing_penalty_msat = 0;
                match usage.effective_capacity {
-                       EffectiveCapacity::ExactLiquidity { liquidity_msat } => {
-                               if usage.amount_msat > liquidity_msat {
+                       EffectiveCapacity::ExactLiquidity { liquidity_msat: amount_msat } |
+                               EffectiveCapacity::HintMaxHTLC { amount_msat } =>
+                       {
+                               if usage.amount_msat > amount_msat {
                                        return u64::max_value();
                                } else {
                                        return base_penalty_msat;
@@ -1607,12 +1604,12 @@ mod bucketed_history {
 
        impl_writeable_tlv_based!(HistoricalBucketRangeTracker, { (0, buckets, required) });
 
-       pub(super) struct HistoricalMinMaxBuckets<'a> {
-               pub(super) min_liquidity_offset_history: &'a HistoricalBucketRangeTracker,
-               pub(super) max_liquidity_offset_history: &'a HistoricalBucketRangeTracker,
+       pub(super) struct HistoricalMinMaxBuckets<D: Deref<Target = HistoricalBucketRangeTracker>> {
+               pub(super) min_liquidity_offset_history: D,
+               pub(super) max_liquidity_offset_history: D,
        }
 
-       impl HistoricalMinMaxBuckets<'_> {
+       impl<D: Deref<Target = HistoricalBucketRangeTracker>> HistoricalMinMaxBuckets<D> {
                #[inline]
                pub(super) fn get_decayed_buckets<T: Time>(&self, now: T, last_updated: T, half_life: Duration)
                -> ([u16; 8], [u16; 8], u32) {
@@ -2453,6 +2450,7 @@ mod tests {
                scorer.payment_path_failed(&payment_path_for_amount(768), 42);
                scorer.payment_path_failed(&payment_path_for_amount(128), 43);
 
+               // Initial penalties
                let usage = ChannelUsage { amount_msat: 128, ..usage };
                assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), 0);
                let usage = ChannelUsage { amount_msat: 256, ..usage };
@@ -2462,7 +2460,8 @@ mod tests {
                let usage = ChannelUsage { amount_msat: 896, ..usage };
                assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), u64::max_value());
 
-               SinceEpoch::advance(Duration::from_secs(9));
+               // No decay
+               SinceEpoch::advance(Duration::from_secs(4));
                let usage = ChannelUsage { amount_msat: 128, ..usage };
                assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), 0);
                let usage = ChannelUsage { amount_msat: 256, ..usage };
@@ -2472,7 +2471,19 @@ mod tests {
                let usage = ChannelUsage { amount_msat: 896, ..usage };
                assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), u64::max_value());
 
+               // Half decay (i.e., three-quarter life)
                SinceEpoch::advance(Duration::from_secs(1));
+               let usage = ChannelUsage { amount_msat: 128, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), 22);
+               let usage = ChannelUsage { amount_msat: 256, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), 106);
+               let usage = ChannelUsage { amount_msat: 768, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), 916);
+               let usage = ChannelUsage { amount_msat: 896, ..usage };
+               assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), u64::max_value());
+
+               // One decay (i.e., half life)
+               SinceEpoch::advance(Duration::from_secs(5));
                let usage = ChannelUsage { amount_msat: 64, ..usage };
                assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage, &params), 0);
                let usage = ChannelUsage { amount_msat: 128, ..usage };
@@ -2944,7 +2955,7 @@ mod tests {
                let usage = ChannelUsage {
                        amount_msat: 1,
                        inflight_htlc_msat: 0,
-                       effective_capacity: EffectiveCapacity::MaximumHTLC { amount_msat: 0 },
+                       effective_capacity: EffectiveCapacity::AdvertisedMaxHTLC { amount_msat: 0 },
                };
                assert_eq!(scorer.channel_penalty_msat(42, &target, &source, usage, &params), 2048);