100363b689ce5a4e1f87e86e69fa6739082b9aa6
[rust-lightning] / lightning / src / routing / scoring.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Utilities for scoring payment channels.
11 //!
12 //! [`ProbabilisticScorer`] may be given to [`find_route`] to score payment channels during path
13 //! finding when a custom [`Score`] implementation is not needed.
14 //!
15 //! # Example
16 //!
17 //! ```
18 //! # extern crate secp256k1;
19 //! #
20 //! # use lightning::routing::network_graph::NetworkGraph;
21 //! # use lightning::routing::router::{RouteParameters, find_route};
22 //! # use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters, Scorer, ScoringParameters};
23 //! # use lightning::util::logger::{Logger, Record};
24 //! # use secp256k1::key::PublicKey;
25 //! #
26 //! # struct FakeLogger {};
27 //! # impl Logger for FakeLogger {
28 //! #     fn log(&self, record: &Record) { unimplemented!() }
29 //! # }
30 //! # fn find_scored_route(payer: PublicKey, route_params: RouteParameters, network_graph: NetworkGraph) {
31 //! # let logger = FakeLogger {};
32 //! #
33 //! // Use the default channel penalties.
34 //! let params = ProbabilisticScoringParameters::default();
35 //! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
36 //!
37 //! // Or use custom channel penalties.
38 //! let params = ProbabilisticScoringParameters {
39 //!     liquidity_penalty_multiplier_msat: 2 * 1000,
40 //!     ..ProbabilisticScoringParameters::default()
41 //! };
42 //! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
43 //!
44 //! let route = find_route(&payer, &route_params, &network_graph, None, &logger, &scorer);
45 //! # }
46 //! ```
47 //!
48 //! # Note
49 //!
//! Persisting a scorer built with the `no-std` feature and restoring it in a build without that
//! feature, or vice versa, uses different types and thus results in undefined behavior.
52 //!
53 //! [`find_route`]: crate::routing::router::find_route
54
55 use bitcoin::secp256k1::key::PublicKey;
56
57 use ln::msgs::DecodeError;
58 use routing::network_graph::{EffectiveCapacity, NetworkGraph, NodeId};
59 use routing::router::RouteHop;
60 use util::ser::{Readable, ReadableArgs, Writeable, Writer};
61
62 use prelude::*;
63 use core::cell::{RefCell, RefMut};
64 use core::ops::{Deref, DerefMut};
65 use core::time::Duration;
66 use io::{self, Read};
67 use sync::{Mutex, MutexGuard};
68
69 /// We define Score ever-so-slightly differently based on whether we are being built for C bindings
70 /// or not. For users, `LockableScore` must somehow be writeable to disk. For Rust users, this is
71 /// no problem - you move a `Score` that implements `Writeable` into a `Mutex`, lock it, and now
72 /// you have the original, concrete, `Score` type, which presumably implements `Writeable`.
73 ///
74 /// For C users, once you've moved the `Score` into a `LockableScore` all you have after locking it
75 /// is an opaque trait object with an opaque pointer with no type info. Users could take the unsafe
76 /// approach of blindly casting that opaque pointer to a concrete type and calling `Writeable` from
77 /// there, but other languages downstream of the C bindings (e.g. Java) can't even do that.
78 /// Instead, we really want `Score` and `LockableScore` to implement `Writeable` directly, which we
79 /// do here by defining `Score` differently for `cfg(c_bindings)`.
80 macro_rules! define_score { ($($supertrait: path)*) => {
81 /// An interface used to score payment channels for path finding.
82 ///
83 ///     Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
84 pub trait Score $(: $supertrait)* {
85         /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the
86         /// given channel in the direction from `source` to `target`.
87         ///
88         /// The channel's capacity (less any other MPP parts that are also being considered for use in
89         /// the same payment) is given by `capacity_msat`. It may be determined from various sources
90         /// such as a chain data, network gossip, or invoice hints. For invoice hints, a capacity near
91         /// [`u64::max_value`] is given to indicate sufficient capacity for the invoice's full amount.
92         /// Thus, implementations should be overflow-safe.
93         fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64;
94
95         /// Handles updating channel penalties after failing to route through a channel.
96         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64);
97
98         /// Handles updating channel penalties after successfully routing along a path.
99         fn payment_path_successful(&mut self, path: &[&RouteHop]);
100 }
101
102 impl<S: Score, T: DerefMut<Target=S> $(+ $supertrait)*> Score for T {
103         fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64 {
104                 self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, capacity_msat, source, target)
105         }
106
107         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
108                 self.deref_mut().payment_path_failed(path, short_channel_id)
109         }
110
111         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
112                 self.deref_mut().payment_path_successful(path)
113         }
114 }
115 } }
116
117 #[cfg(c_bindings)]
118 define_score!(Writeable);
119 #[cfg(not(c_bindings))]
120 define_score!();
121
122 /// A scorer that is accessed under a lock.
123 ///
124 /// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
125 /// having shared ownership of a scorer but without requiring internal locking in [`Score`]
126 /// implementations. Internal locking would be detrimental to route finding performance and could
127 /// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
128 ///
129 /// [`find_route`]: crate::routing::router::find_route
130 pub trait LockableScore<'a> {
131         /// The locked [`Score`] type.
132         type Locked: 'a + Score;
133
134         /// Returns the locked scorer.
135         fn lock(&'a self) -> Self::Locked;
136 }
137
138 /// (C-not exported)
139 impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
140         type Locked = MutexGuard<'a, T>;
141
142         fn lock(&'a self) -> MutexGuard<'a, T> {
143                 Mutex::lock(self).unwrap()
144         }
145 }
146
147 impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
148         type Locked = RefMut<'a, T>;
149
150         fn lock(&'a self) -> RefMut<'a, T> {
151                 self.borrow_mut()
152         }
153 }
154
#[cfg(c_bindings)]
/// A concrete, thread-safe implementation of [`LockableScore`] that guards its [`Score`] with a
/// mutex.
pub struct MultiThreadedLockableScore<S: Score> {
	score: Mutex<S>,
}
#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
	type Locked = MutexGuard<'a, T>;

	fn lock(&'a self) -> MutexGuard<'a, T> {
		// Panics if acquiring the inner lock fails (e.g. it was poisoned).
		let lock_result = Mutex::lock(&self.score);
		lock_result.unwrap()
	}
}

#[cfg(c_bindings)]
impl<T: Score> MultiThreadedLockableScore<T> {
	/// Wraps the given [`Score`] so that it can be locked from multiple threads.
	pub fn new(score: T) -> Self {
		let score = Mutex::new(score);
		Self { score }
	}
}
177
#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
	// Serializes the borrowed value exactly as `T` serializes itself.
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		let borrowed: &T = &**self;
		<T as Writeable>::write(borrowed, writer)
	}
}

#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> {
	// Serializes the guarded value exactly as `S` serializes itself.
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		let guarded: &S = &**self;
		<S as Writeable>::write(guarded, writer)
	}
}
193
194 /// [`Score`] implementation that provides reasonable default behavior.
195 ///
196 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
197 /// slightly higher fees are available. Will further penalize channels that fail to relay payments.
198 ///
199 /// See [module-level documentation] for usage.
200 ///
201 /// [module-level documentation]: crate::routing::scoring
202 #[cfg(not(feature = "no-std"))]
203 pub type Scorer = ScorerUsingTime::<std::time::Instant>;
204 /// [`Score`] implementation that provides reasonable default behavior.
205 ///
206 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
207 /// slightly higher fees are available. Will further penalize channels that fail to relay payments.
208 ///
209 /// See [module-level documentation] for usage and [`ScoringParameters`] for customization.
210 ///
211 /// [module-level documentation]: crate::routing::scoring
212 #[cfg(feature = "no-std")]
213 pub type Scorer = ScorerUsingTime::<time::Eternity>;
214
215 // Note that ideally we'd hide ScorerUsingTime from public view by sealing it as well, but rustdoc
216 // doesn't handle this well - instead exposing a `Scorer` which has no trait implementation(s) or
217 // methods at all.
218
219 /// [`Score`] implementation.
220 ///
221 /// See [`Scorer`] for details.
222 ///
223 /// # Note
224 ///
225 /// Mixing the `no-std` feature between serialization and deserialization results in undefined
226 /// behavior.
227 ///
228 /// (C-not exported) generally all users should use the [`Scorer`] type alias.
229 pub struct ScorerUsingTime<T: Time> {
230         params: ScoringParameters,
231         // TODO: Remove entries of closed channels.
232         channel_failures: HashMap<u64, ChannelFailure<T>>,
233 }
234
235 /// Parameters for configuring [`Scorer`].
236 pub struct ScoringParameters {
237         /// A fixed penalty in msats to apply to each channel.
238         ///
239         /// Default value: 500 msat
240         pub base_penalty_msat: u64,
241
242         /// A penalty in msats to apply to a channel upon failing to relay a payment.
243         ///
244         /// This accumulates for each failure but may be reduced over time based on
245         /// [`failure_penalty_half_life`] or when successfully routing through a channel.
246         ///
247         /// Default value: 1,024,000 msat
248         ///
249         /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life
250         pub failure_penalty_msat: u64,
251
252         /// When the amount being sent over a channel is this many 1024ths of the total channel
253         /// capacity, we begin applying [`overuse_penalty_msat_per_1024th`].
254         ///
255         /// Default value: 128 1024ths (i.e. begin penalizing when an HTLC uses 1/8th of a channel)
256         ///
257         /// [`overuse_penalty_msat_per_1024th`]: Self::overuse_penalty_msat_per_1024th
258         pub overuse_penalty_start_1024th: u16,
259
260         /// A penalty applied, per whole 1024ths of the channel capacity which the amount being sent
261         /// over the channel exceeds [`overuse_penalty_start_1024th`] by.
262         ///
263         /// Default value: 20 msat (i.e. 2560 msat penalty to use 1/4th of a channel, 7680 msat penalty
264         ///                to use half a channel, and 12,560 msat penalty to use 3/4ths of a channel)
265         ///
266         /// [`overuse_penalty_start_1024th`]: Self::overuse_penalty_start_1024th
267         pub overuse_penalty_msat_per_1024th: u64,
268
269         /// The time required to elapse before any accumulated [`failure_penalty_msat`] penalties are
270         /// cut in half.
271         ///
272         /// Successfully routing through a channel will immediately cut the penalty in half as well.
273         ///
274         /// # Note
275         ///
276         /// When built with the `no-std` feature, time will never elapse. Therefore, this penalty will
277         /// never decay.
278         ///
279         /// [`failure_penalty_msat`]: Self::failure_penalty_msat
280         pub failure_penalty_half_life: Duration,
281 }
282
283 impl_writeable_tlv_based!(ScoringParameters, {
284         (0, base_penalty_msat, required),
285         (1, overuse_penalty_start_1024th, (default_value, 128)),
286         (2, failure_penalty_msat, required),
287         (3, overuse_penalty_msat_per_1024th, (default_value, 20)),
288         (4, failure_penalty_half_life, required),
289 });
290
291 /// Accounting for penalties against a channel for failing to relay any payments.
292 ///
293 /// Penalties decay over time, though accumulate as more failures occur.
294 struct ChannelFailure<T: Time> {
295         /// Accumulated penalty in msats for the channel as of `last_updated`.
296         undecayed_penalty_msat: u64,
297
298         /// Last time the channel either failed to route or successfully routed a payment. Used to decay
299         /// `undecayed_penalty_msat`.
300         last_updated: T,
301 }
302
303 impl<T: Time> ScorerUsingTime<T> {
304         /// Creates a new scorer using the given scoring parameters.
305         pub fn new(params: ScoringParameters) -> Self {
306                 Self {
307                         params,
308                         channel_failures: HashMap::new(),
309                 }
310         }
311
312         /// Creates a new scorer using `penalty_msat` as a fixed channel penalty.
313         #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
314         pub fn with_fixed_penalty(penalty_msat: u64) -> Self {
315                 Self::new(ScoringParameters {
316                         base_penalty_msat: penalty_msat,
317                         failure_penalty_msat: 0,
318                         failure_penalty_half_life: Duration::from_secs(0),
319                         overuse_penalty_start_1024th: 1024,
320                         overuse_penalty_msat_per_1024th: 0,
321                 })
322         }
323 }
324
325 impl<T: Time> ChannelFailure<T> {
326         fn new(failure_penalty_msat: u64) -> Self {
327                 Self {
328                         undecayed_penalty_msat: failure_penalty_msat,
329                         last_updated: T::now(),
330                 }
331         }
332
333         fn add_penalty(&mut self, failure_penalty_msat: u64, half_life: Duration) {
334                 self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) + failure_penalty_msat;
335                 self.last_updated = T::now();
336         }
337
338         fn reduce_penalty(&mut self, half_life: Duration) {
339                 self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) >> 1;
340                 self.last_updated = T::now();
341         }
342
343         fn decayed_penalty_msat(&self, half_life: Duration) -> u64 {
344                 self.last_updated.elapsed().as_secs()
345                         .checked_div(half_life.as_secs())
346                         .and_then(|decays| self.undecayed_penalty_msat.checked_shr(decays as u32))
347                         .unwrap_or(0)
348         }
349 }
350
351 impl<T: Time> Default for ScorerUsingTime<T> {
352         fn default() -> Self {
353                 Self::new(ScoringParameters::default())
354         }
355 }
356
357 impl Default for ScoringParameters {
358         fn default() -> Self {
359                 Self {
360                         base_penalty_msat: 500,
361                         failure_penalty_msat: 1024 * 1000,
362                         failure_penalty_half_life: Duration::from_secs(3600),
363                         overuse_penalty_start_1024th: 1024 / 8,
364                         overuse_penalty_msat_per_1024th: 20,
365                 }
366         }
367 }
368
369 impl<T: Time> Score for ScorerUsingTime<T> {
370         fn channel_penalty_msat(
371                 &self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, _source: &NodeId, _target: &NodeId
372         ) -> u64 {
373                 let failure_penalty_msat = self.channel_failures
374                         .get(&short_channel_id)
375                         .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life));
376
377                 let mut penalty_msat = self.params.base_penalty_msat + failure_penalty_msat;
378                 let send_1024ths = send_amt_msat.checked_mul(1024).unwrap_or(u64::max_value()) / capacity_msat;
379                 if send_1024ths > self.params.overuse_penalty_start_1024th as u64 {
380                         penalty_msat = penalty_msat.checked_add(
381                                         (send_1024ths - self.params.overuse_penalty_start_1024th as u64)
382                                         .checked_mul(self.params.overuse_penalty_msat_per_1024th).unwrap_or(u64::max_value()))
383                                 .unwrap_or(u64::max_value());
384                 }
385
386                 penalty_msat
387         }
388
389         fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) {
390                 let failure_penalty_msat = self.params.failure_penalty_msat;
391                 let half_life = self.params.failure_penalty_half_life;
392                 self.channel_failures
393                         .entry(short_channel_id)
394                         .and_modify(|failure| failure.add_penalty(failure_penalty_msat, half_life))
395                         .or_insert_with(|| ChannelFailure::new(failure_penalty_msat));
396         }
397
398         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
399                 let half_life = self.params.failure_penalty_half_life;
400                 for hop in path.iter() {
401                         self.channel_failures
402                                 .entry(hop.short_channel_id)
403                                 .and_modify(|failure| failure.reduce_penalty(half_life));
404                 }
405         }
406 }
407
408 impl<T: Time> Writeable for ScorerUsingTime<T> {
409         #[inline]
410         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
411                 self.params.write(w)?;
412                 self.channel_failures.write(w)?;
413                 write_tlv_fields!(w, {});
414                 Ok(())
415         }
416 }
417
418 impl<T: Time> Readable for ScorerUsingTime<T> {
419         #[inline]
420         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
421                 let res = Ok(Self {
422                         params: Readable::read(r)?,
423                         channel_failures: Readable::read(r)?,
424                 });
425                 read_tlv_fields!(r, {});
426                 res
427         }
428 }
429
430 impl<T: Time> Writeable for ChannelFailure<T> {
431         #[inline]
432         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
433                 let duration_since_epoch = T::duration_since_epoch() - self.last_updated.elapsed();
434                 write_tlv_fields!(w, {
435                         (0, self.undecayed_penalty_msat, required),
436                         (2, duration_since_epoch, required),
437                 });
438                 Ok(())
439         }
440 }
441
442 impl<T: Time> Readable for ChannelFailure<T> {
443         #[inline]
444         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
445                 let mut undecayed_penalty_msat = 0;
446                 let mut duration_since_epoch = Duration::from_secs(0);
447                 read_tlv_fields!(r, {
448                         (0, undecayed_penalty_msat, required),
449                         (2, duration_since_epoch, required),
450                 });
451                 Ok(Self {
452                         undecayed_penalty_msat,
453                         last_updated: T::now() - (T::duration_since_epoch() - duration_since_epoch),
454                 })
455         }
456 }
457
458 /// [`Score`] implementation using channel success probability distributions.
459 ///
460 /// Based on *Optimally Reliable & Cheap Payment Flows on the Lightning Network* by Rene Pickhardt
461 /// and Stefan Richter [[1]]. Given the uncertainty of channel liquidity balances, probability
462 /// distributions are defined based on knowledge learned from successful and unsuccessful attempts.
463 /// Then the negative log of the success probability is used to determine the cost of routing a
464 /// specific HTLC amount through a channel.
465 ///
466 /// [1]: https://arxiv.org/abs/2107.05322
467 pub struct ProbabilisticScorer<G: Deref<Target = NetworkGraph>> {
468         params: ProbabilisticScoringParameters,
469         node_id: NodeId,
470         network_graph: G,
471         // TODO: Remove entries of closed channels.
472         channel_liquidities: HashMap<u64, ChannelLiquidity>,
473 }
474
475 /// Parameters for configuring [`ProbabilisticScorer`].
476 pub struct ProbabilisticScoringParameters {
477         /// A penalty applied after multiplying by the negative log of the channel's success probability
478         /// for a payment.
479         ///
480         /// The success probability is determined by the effective channel capacity, the payment amount,
481         /// and knowledge learned from prior successful and unsuccessful payments. The lower bound of
482         /// the success probability is 0.01, effectively limiting the penalty to the range
483         /// `0..=2*liquidity_penalty_multiplier_msat`.
484         ///
485         /// Default value: 1,000 msat
486         pub liquidity_penalty_multiplier_msat: u64,
487 }
488
489 impl_writeable_tlv_based!(ProbabilisticScoringParameters, {
490         (0, liquidity_penalty_multiplier_msat, required),
491 });
492
/// Learned bounds on a channel's liquidity balance.
///
/// The offsets describe the direction from the lesser to the greater of the channel's two
/// counterparties, according to [`NodeId`] ordering. In that direction, the lower liquidity bound
/// is `min_liquidity_offset_msat` and the upper bound is the channel capacity minus
/// `max_liquidity_offset_msat`. Swapping the two offsets gives the bounds for the opposite
/// direction.
struct ChannelLiquidity {
	/// Lower bound on the liquidity balance, in msats.
	min_liquidity_offset_msat: u64,
	/// Amount, in msats, by which the upper liquidity bound falls short of the channel capacity.
	max_liquidity_offset_msat: u64,
}
502
/// A view of [`ChannelLiquidity`] oriented in a single direction, for a given channel capacity.
///
/// In practice `L` is `&u64` for read-only views and `&mut u64` for views that update the
/// underlying offsets.
struct DirectedChannelLiquidity<L: Deref<Target = u64>> {
	min_liquidity_offset_msat: L,
	max_liquidity_offset_msat: L,
	capacity_msat: u64,
}
509
510 impl<G: Deref<Target = NetworkGraph>> ProbabilisticScorer<G> {
511         /// Creates a new scorer using the given scoring parameters for sending payments from a node
512         /// through a network graph.
513         pub fn new(
514                 params: ProbabilisticScoringParameters, node_pubkey: &PublicKey, network_graph: G
515         ) -> Self {
516                 Self {
517                         params,
518                         node_id: NodeId::from_pubkey(node_pubkey),
519                         network_graph,
520                         channel_liquidities: HashMap::new(),
521                 }
522         }
523
524         #[cfg(test)]
525         fn with_channel(mut self, short_channel_id: u64, liquidity: ChannelLiquidity) -> Self {
526                 assert!(self.channel_liquidities.insert(short_channel_id, liquidity).is_none());
527                 self
528         }
529 }
530
531 impl Default for ProbabilisticScoringParameters {
532         fn default() -> Self {
533                 Self {
534                         liquidity_penalty_multiplier_msat: 1000,
535                 }
536         }
537 }
538
539 impl ChannelLiquidity {
540         #[inline]
541         fn new() -> Self {
542                 Self {
543                         min_liquidity_offset_msat: 0,
544                         max_liquidity_offset_msat: 0,
545                 }
546         }
547
548         /// Returns a view of the channel liquidity directed from `source` to `target` assuming
549         /// `capacity_msat`.
550         fn as_directed(
551                 &self, source: &NodeId, target: &NodeId, capacity_msat: u64
552         ) -> DirectedChannelLiquidity<&u64> {
553                 let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target {
554                         (&self.min_liquidity_offset_msat, &self.max_liquidity_offset_msat)
555                 } else {
556                         (&self.max_liquidity_offset_msat, &self.min_liquidity_offset_msat)
557                 };
558
559                 DirectedChannelLiquidity {
560                         min_liquidity_offset_msat,
561                         max_liquidity_offset_msat,
562                         capacity_msat,
563                 }
564         }
565
566         /// Returns a mutable view of the channel liquidity directed from `source` to `target` assuming
567         /// `capacity_msat`.
568         fn as_directed_mut(
569                 &mut self, source: &NodeId, target: &NodeId, capacity_msat: u64
570         ) -> DirectedChannelLiquidity<&mut u64> {
571                 let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target {
572                         (&mut self.min_liquidity_offset_msat, &mut self.max_liquidity_offset_msat)
573                 } else {
574                         (&mut self.max_liquidity_offset_msat, &mut self.min_liquidity_offset_msat)
575                 };
576
577                 DirectedChannelLiquidity {
578                         min_liquidity_offset_msat,
579                         max_liquidity_offset_msat,
580                         capacity_msat,
581                 }
582         }
583 }
584
585 impl<L: Deref<Target = u64>> DirectedChannelLiquidity<L> {
586         /// Returns the success probability of routing the given HTLC `amount_msat` through the channel
587         /// in this direction.
588         fn success_probability(&self, amount_msat: u64) -> f64 {
589                 let max_liquidity_msat = self.max_liquidity_msat();
590                 let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat);
591                 if amount_msat > max_liquidity_msat {
592                         0.0
593                 } else if amount_msat <= min_liquidity_msat {
594                         1.0
595                 } else {
596                         let numerator = max_liquidity_msat + 1 - amount_msat;
597                         let denominator = max_liquidity_msat + 1 - min_liquidity_msat;
598                         numerator as f64 / denominator as f64
599                 }.max(0.01) // Lower bound the success probability to ensure some channel is selected.
600         }
601
602         /// Returns the lower bound of the channel liquidity balance in this direction.
603         fn min_liquidity_msat(&self) -> u64 {
604                 *self.min_liquidity_offset_msat
605         }
606
607         /// Returns the upper bound of the channel liquidity balance in this direction.
608         fn max_liquidity_msat(&self) -> u64 {
609                 self.capacity_msat.checked_sub(*self.max_liquidity_offset_msat).unwrap_or(0)
610         }
611 }
612
613 impl<L: DerefMut<Target = u64>> DirectedChannelLiquidity<L> {
614         /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat`.
615         fn failed_at_channel(&mut self, amount_msat: u64) {
616                 if amount_msat < self.max_liquidity_msat() {
617                         self.set_max_liquidity_msat(amount_msat);
618                 }
619         }
620
621         /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat` downstream.
622         fn failed_downstream(&mut self, amount_msat: u64) {
623                 if amount_msat > self.min_liquidity_msat() {
624                         self.set_min_liquidity_msat(amount_msat);
625                 }
626         }
627
628         /// Adjusts the channel liquidity balance bounds when successfully routing `amount_msat`.
629         fn successful(&mut self, amount_msat: u64) {
630                 let max_liquidity_msat = self.max_liquidity_msat().checked_sub(amount_msat).unwrap_or(0);
631                 self.set_max_liquidity_msat(max_liquidity_msat);
632         }
633
634         /// Adjusts the lower bound of the channel liquidity balance in this direction.
635         fn set_min_liquidity_msat(&mut self, amount_msat: u64) {
636                 *self.min_liquidity_offset_msat = amount_msat;
637
638                 if amount_msat > self.max_liquidity_msat() {
639                         *self.max_liquidity_offset_msat = 0;
640                 }
641         }
642
643         /// Adjusts the upper bound of the channel liquidity balance in this direction.
644         fn set_max_liquidity_msat(&mut self, amount_msat: u64) {
645                 *self.max_liquidity_offset_msat = self.capacity_msat.checked_sub(amount_msat).unwrap_or(0);
646
647                 if amount_msat < self.min_liquidity_msat() {
648                         *self.min_liquidity_offset_msat = 0;
649                 }
650         }
651 }
652
653 impl<G: Deref<Target = NetworkGraph>> Score for ProbabilisticScorer<G> {
654         #[allow(clippy::float_cmp)]
655         fn channel_penalty_msat(
656                 &self, short_channel_id: u64, amount_msat: u64, capacity_msat: u64, source: &NodeId,
657                 target: &NodeId
658         ) -> u64 {
659                 if *source == self.node_id || *target == self.node_id {
660                         return 0;
661                 }
662
663                 let liquidity_penalty_multiplier_msat = self.params.liquidity_penalty_multiplier_msat;
664                 let success_probability = self.channel_liquidities
665                         .get(&short_channel_id)
666                         .unwrap_or(&ChannelLiquidity::new())
667                         .as_directed(source, target, capacity_msat)
668                         .success_probability(amount_msat);
669                 if success_probability == 0.0 {
670                         u64::max_value()
671                 } else if success_probability == 1.0 {
672                         0
673                 } else {
674                         (-(success_probability.log10()) * liquidity_penalty_multiplier_msat as f64) as u64
675                 }
676         }
677
678         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
679                 let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
680                 let network_graph = self.network_graph.read_only();
681                 let hop_sources = core::iter::once(self.node_id)
682                         .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey)));
683                 for (source, hop) in hop_sources.zip(path.iter()) {
684                         let target = NodeId::from_pubkey(&hop.pubkey);
685                         if source == self.node_id || target == self.node_id {
686                                 continue;
687                         }
688
689                         let capacity_msat = network_graph.channels()
690                                 .get(&hop.short_channel_id)
691                                 .and_then(|channel| channel.as_directed_to(&target).map(|(d, _)| d.effective_capacity()))
692                                 .unwrap_or(EffectiveCapacity::Unknown)
693                                 .as_msat();
694
695                         if hop.short_channel_id == short_channel_id {
696                                 self.channel_liquidities
697                                         .entry(hop.short_channel_id)
698                                         .or_insert_with(|| ChannelLiquidity::new())
699                                         .as_directed_mut(&source, &target, capacity_msat)
700                                         .failed_at_channel(amount_msat);
701                                 break;
702                         }
703
704                         self.channel_liquidities
705                                 .entry(hop.short_channel_id)
706                                 .or_insert_with(|| ChannelLiquidity::new())
707                                 .as_directed_mut(&source, &target, capacity_msat)
708                                 .failed_downstream(amount_msat);
709                 }
710         }
711
712         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
713                 let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
714                 let network_graph = self.network_graph.read_only();
715                 let hop_sources = core::iter::once(self.node_id)
716                         .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey)));
717                 for (source, hop) in hop_sources.zip(path.iter()) {
718                         let target = NodeId::from_pubkey(&hop.pubkey);
719                         if source == self.node_id || target == self.node_id {
720                                 continue;
721                         }
722
723                         let capacity_msat = network_graph.channels()
724                                 .get(&hop.short_channel_id)
725                                 .and_then(|channel| channel.as_directed_to(&target).map(|(d, _)| d.effective_capacity()))
726                                 .unwrap_or(EffectiveCapacity::Unknown)
727                                 .as_msat();
728
729                         self.channel_liquidities
730                                 .entry(hop.short_channel_id)
731                                 .or_insert_with(|| ChannelLiquidity::new())
732                                 .as_directed_mut(&source, &target, capacity_msat)
733                                 .successful(amount_msat);
734                 }
735         }
736 }
737
/// Serializes a `ProbabilisticScorer` as its learned per-channel liquidity map, followed by a TLV
/// stream that currently carries no fields.
///
/// The scoring parameters, our node id, and the network graph are not written; whoever reads the
/// scorer back must supply them again.
impl<G: Deref<Target = NetworkGraph>> Writeable for ProbabilisticScorer<G> {
	#[inline]
	fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
		self.channel_liquidities.write(w)?;
		// Empty TLV stream written after the map; the reader consumes it in the same order.
		write_tlv_fields!(w, {});
		Ok(())
	}
}
746
/// Deserializes a `ProbabilisticScorer` previously written by its `Writeable` implementation.
///
/// `args` provides everything that is not serialized: the scoring parameters, our own node's
/// public key (stored as a `NodeId`), and the network graph.
impl<G: Deref<Target = NetworkGraph>> ReadableArgs<(ProbabilisticScoringParameters, &PublicKey, G)>
for ProbabilisticScorer<G> {
	#[inline]
	fn read<R: Read>(
		r: &mut R, args: (ProbabilisticScoringParameters, &PublicKey, G)
	) -> Result<Self, DecodeError> {
		let (params, node_pubkey, network_graph) = args;
		// Read order mirrors the write order: the liquidity map comes first.
		let res = Ok(Self {
			params,
			node_id: NodeId::from_pubkey(node_pubkey),
			network_graph,
			channel_liquidities: Readable::read(r)?,
		});
		// The trailing TLV stream must be consumed before the scorer is returned.
		read_tlv_fields!(r, {});
		res
	}
}
764
/// Serializes a `ChannelLiquidity` as a TLV stream: type 0 holds `min_liquidity_offset_msat` and
/// type 2 holds `max_liquidity_offset_msat`, both marked `required`.
impl Writeable for ChannelLiquidity {
	#[inline]
	fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
		write_tlv_fields!(w, {
			(0, self.min_liquidity_offset_msat, required),
			(2, self.max_liquidity_offset_msat, required),
		});
		Ok(())
	}
}
775
/// Reads the TLV stream written by the `Writeable` implementation for `ChannelLiquidity`
/// (type 0: `min_liquidity_offset_msat`, type 2: `max_liquidity_offset_msat`).
impl Readable for ChannelLiquidity {
	#[inline]
	fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
		// Placeholder values. Both fields are `required`, so `read_tlv_fields!` presumably either
		// overwrites them or returns a `DecodeError` (NOTE(review): confirm against the macro).
		let mut min_liquidity_offset_msat = 0;
		let mut max_liquidity_offset_msat = 0;
		read_tlv_fields!(r, {
			(0, min_liquidity_offset_msat, required),
			(2, max_liquidity_offset_msat, required),
		});
		Ok(Self {
			min_liquidity_offset_msat,
			max_liquidity_offset_msat
		})
	}
}
791
pub(crate) mod time {
	use core::ops::Sub;
	use core::time::Duration;

	/// A measurement of time.
	///
	/// Implementors must support moving backwards by a [`Duration`] via `Sub<Duration>`.
	pub trait Time: Sub<Duration, Output = Self> where Self: Sized {
		/// Returns an instance corresponding to the current moment.
		fn now() -> Self;

		/// Returns the amount of time elapsed since `self` was created.
		fn elapsed(&self) -> Duration;

		/// Returns the amount of time passed since the epoch of this clock: the UNIX epoch for
		/// the `std::time::Instant` implementation, and always zero for [`Eternity`].
		///
		/// Used during (de-)serialization.
		fn duration_since_epoch() -> Duration;
	}

	/// A state in which time has no meaning.
	///
	/// All instances are identical: `elapsed` and `duration_since_epoch` always return zero, and
	/// subtracting a [`Duration`] returns the value unchanged.
	#[derive(Debug, PartialEq, Eq)]
	pub struct Eternity;

	#[cfg(not(feature = "no-std"))]
	impl Time for std::time::Instant {
		fn now() -> Self {
			std::time::Instant::now()
		}

		/// # Panics
		///
		/// Panics if the system clock reports a time before the UNIX epoch.
		fn duration_since_epoch() -> Duration {
			use std::time::SystemTime;
			SystemTime::now()
				.duration_since(SystemTime::UNIX_EPOCH)
				.expect("system clock is set before the UNIX epoch")
		}

		fn elapsed(&self) -> Duration {
			std::time::Instant::elapsed(self)
		}
	}

	impl Time for Eternity {
		fn now() -> Self {
			Self
		}

		fn duration_since_epoch() -> Duration {
			Duration::from_secs(0)
		}

		fn elapsed(&self) -> Duration {
			Duration::from_secs(0)
		}
	}

	impl Sub<Duration> for Eternity {
		type Output = Self;

		/// Time does not pass in an eternity, so subtraction is a no-op.
		fn sub(self, _other: Duration) -> Self {
			self
		}
	}
}
851
852 pub(crate) use self::time::Time;
853
854 #[cfg(test)]
855 mod tests {
856         use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorer, ScoringParameters, ScorerUsingTime, Time};
857         use super::time::Eternity;
858
859         use ln::features::{ChannelFeatures, NodeFeatures};
860         use ln::msgs::{ChannelAnnouncement, ChannelUpdate, OptionalField, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
861         use routing::scoring::Score;
862         use routing::network_graph::{NetworkGraph, NodeId};
863         use routing::router::RouteHop;
864         use util::ser::{Readable, Writeable};
865
866         use bitcoin::blockdata::constants::genesis_block;
867         use bitcoin::hashes::Hash;
868         use bitcoin::hashes::sha256d::Hash as Sha256dHash;
869         use bitcoin::network::constants::Network;
870         use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
871         use core::cell::Cell;
872         use core::ops::Sub;
873         use core::time::Duration;
874         use io;
875
876         // `Time` tests
877
878         /// Time that can be advanced manually in tests.
879         #[derive(Debug, PartialEq, Eq)]
880         struct SinceEpoch(Duration);
881
882         impl SinceEpoch {
883                 thread_local! {
884                         static ELAPSED: Cell<Duration> = core::cell::Cell::new(Duration::from_secs(0));
885                 }
886
887                 fn advance(duration: Duration) {
888                         Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration))
889                 }
890         }
891
892         impl Time for SinceEpoch {
893                 fn now() -> Self {
894                         Self(Self::duration_since_epoch())
895                 }
896
897                 fn duration_since_epoch() -> Duration {
898                         Self::ELAPSED.with(|elapsed| elapsed.get())
899                 }
900
901                 fn elapsed(&self) -> Duration {
902                         Self::duration_since_epoch() - self.0
903                 }
904         }
905
906         impl Sub<Duration> for SinceEpoch {
907                 type Output = Self;
908
909                 fn sub(self, other: Duration) -> Self {
910                         Self(self.0 - other)
911                 }
912         }
913
914         #[test]
915         fn time_passes_when_advanced() {
916                 let now = SinceEpoch::now();
917                 assert_eq!(now.elapsed(), Duration::from_secs(0));
918
919                 SinceEpoch::advance(Duration::from_secs(1));
920                 SinceEpoch::advance(Duration::from_secs(1));
921
922                 let elapsed = now.elapsed();
923                 let later = SinceEpoch::now();
924
925                 assert_eq!(elapsed, Duration::from_secs(2));
926                 assert_eq!(later - elapsed, now);
927         }
928
929         #[test]
930         fn time_never_passes_in_an_eternity() {
931                 let now = Eternity::now();
932                 let elapsed = now.elapsed();
933                 let later = Eternity::now();
934
935                 assert_eq!(now.elapsed(), Duration::from_secs(0));
936                 assert_eq!(later - elapsed, now);
937         }
938
939         // `Scorer` tests
940
941         /// A scorer for testing with time that can be manually advanced.
942         type Scorer = ScorerUsingTime::<SinceEpoch>;
943
944         fn source_privkey() -> SecretKey {
945                 SecretKey::from_slice(&[42; 32]).unwrap()
946         }
947
948         fn target_privkey() -> SecretKey {
949                 SecretKey::from_slice(&[43; 32]).unwrap()
950         }
951
952         fn source_pubkey() -> PublicKey {
953                 let secp_ctx = Secp256k1::new();
954                 PublicKey::from_secret_key(&secp_ctx, &source_privkey())
955         }
956
957         fn target_pubkey() -> PublicKey {
958                 let secp_ctx = Secp256k1::new();
959                 PublicKey::from_secret_key(&secp_ctx, &target_privkey())
960         }
961
962         fn source_node_id() -> NodeId {
963                 NodeId::from_pubkey(&source_pubkey())
964         }
965
966         fn target_node_id() -> NodeId {
967                 NodeId::from_pubkey(&target_pubkey())
968         }
969
970         #[test]
971         fn penalizes_without_channel_failures() {
972                 let scorer = Scorer::new(ScoringParameters {
973                         base_penalty_msat: 1_000,
974                         failure_penalty_msat: 512,
975                         failure_penalty_half_life: Duration::from_secs(1),
976                         overuse_penalty_start_1024th: 1024,
977                         overuse_penalty_msat_per_1024th: 0,
978                 });
979                 let source = source_node_id();
980                 let target = target_node_id();
981                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
982
983                 SinceEpoch::advance(Duration::from_secs(1));
984                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
985         }
986
987         #[test]
988         fn accumulates_channel_failure_penalties() {
989                 let mut scorer = Scorer::new(ScoringParameters {
990                         base_penalty_msat: 1_000,
991                         failure_penalty_msat: 64,
992                         failure_penalty_half_life: Duration::from_secs(10),
993                         overuse_penalty_start_1024th: 1024,
994                         overuse_penalty_msat_per_1024th: 0,
995                 });
996                 let source = source_node_id();
997                 let target = target_node_id();
998                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
999
1000                 scorer.payment_path_failed(&[], 42);
1001                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
1002
1003                 scorer.payment_path_failed(&[], 42);
1004                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
1005
1006                 scorer.payment_path_failed(&[], 42);
1007                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_192);
1008         }
1009
1010         #[test]
1011         fn decays_channel_failure_penalties_over_time() {
1012                 let mut scorer = Scorer::new(ScoringParameters {
1013                         base_penalty_msat: 1_000,
1014                         failure_penalty_msat: 512,
1015                         failure_penalty_half_life: Duration::from_secs(10),
1016                         overuse_penalty_start_1024th: 1024,
1017                         overuse_penalty_msat_per_1024th: 0,
1018                 });
1019                 let source = source_node_id();
1020                 let target = target_node_id();
1021                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1022
1023                 scorer.payment_path_failed(&[], 42);
1024                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1025
1026                 SinceEpoch::advance(Duration::from_secs(9));
1027                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1028
1029                 SinceEpoch::advance(Duration::from_secs(1));
1030                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1031
1032                 SinceEpoch::advance(Duration::from_secs(10 * 8));
1033                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_001);
1034
1035                 SinceEpoch::advance(Duration::from_secs(10));
1036                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1037
1038                 SinceEpoch::advance(Duration::from_secs(10));
1039                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1040         }
1041
1042         #[test]
1043         fn decays_channel_failure_penalties_without_shift_overflow() {
1044                 let mut scorer = Scorer::new(ScoringParameters {
1045                         base_penalty_msat: 1_000,
1046                         failure_penalty_msat: 512,
1047                         failure_penalty_half_life: Duration::from_secs(10),
1048                         overuse_penalty_start_1024th: 1024,
1049                         overuse_penalty_msat_per_1024th: 0,
1050                 });
1051                 let source = source_node_id();
1052                 let target = target_node_id();
1053                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1054
1055                 scorer.payment_path_failed(&[], 42);
1056                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1057
1058                 // An unchecked right shift 64 bits or more in ChannelFailure::decayed_penalty_msat would
1059                 // cause an overflow.
1060                 SinceEpoch::advance(Duration::from_secs(10 * 64));
1061                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1062
1063                 SinceEpoch::advance(Duration::from_secs(10));
1064                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1065         }
1066
1067         #[test]
1068         fn accumulates_channel_failure_penalties_after_decay() {
1069                 let mut scorer = Scorer::new(ScoringParameters {
1070                         base_penalty_msat: 1_000,
1071                         failure_penalty_msat: 512,
1072                         failure_penalty_half_life: Duration::from_secs(10),
1073                         overuse_penalty_start_1024th: 1024,
1074                         overuse_penalty_msat_per_1024th: 0,
1075                 });
1076                 let source = source_node_id();
1077                 let target = target_node_id();
1078                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1079
1080                 scorer.payment_path_failed(&[], 42);
1081                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1082
1083                 SinceEpoch::advance(Duration::from_secs(10));
1084                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1085
1086                 scorer.payment_path_failed(&[], 42);
1087                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_768);
1088
1089                 SinceEpoch::advance(Duration::from_secs(10));
1090                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_384);
1091         }
1092
1093         #[test]
1094         fn reduces_channel_failure_penalties_after_success() {
1095                 let mut scorer = Scorer::new(ScoringParameters {
1096                         base_penalty_msat: 1_000,
1097                         failure_penalty_msat: 512,
1098                         failure_penalty_half_life: Duration::from_secs(10),
1099                         overuse_penalty_start_1024th: 1024,
1100                         overuse_penalty_msat_per_1024th: 0,
1101                 });
1102                 let source = source_node_id();
1103                 let target = target_node_id();
1104                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1105
1106                 scorer.payment_path_failed(&[], 42);
1107                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1108
1109                 SinceEpoch::advance(Duration::from_secs(10));
1110                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1111
1112                 let hop = RouteHop {
1113                         pubkey: PublicKey::from_slice(target.as_slice()).unwrap(),
1114                         node_features: NodeFeatures::known(),
1115                         short_channel_id: 42,
1116                         channel_features: ChannelFeatures::known(),
1117                         fee_msat: 1,
1118                         cltv_expiry_delta: 18,
1119                 };
1120                 scorer.payment_path_successful(&[&hop]);
1121                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
1122
1123                 SinceEpoch::advance(Duration::from_secs(10));
1124                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
1125         }
1126
1127         #[test]
1128         fn restores_persisted_channel_failure_penalties() {
1129                 let mut scorer = Scorer::new(ScoringParameters {
1130                         base_penalty_msat: 1_000,
1131                         failure_penalty_msat: 512,
1132                         failure_penalty_half_life: Duration::from_secs(10),
1133                         overuse_penalty_start_1024th: 1024,
1134                         overuse_penalty_msat_per_1024th: 0,
1135                 });
1136                 let source = source_node_id();
1137                 let target = target_node_id();
1138
1139                 scorer.payment_path_failed(&[], 42);
1140                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1141
1142                 SinceEpoch::advance(Duration::from_secs(10));
1143                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1144
1145                 scorer.payment_path_failed(&[], 43);
1146                 assert_eq!(scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
1147
1148                 let mut serialized_scorer = Vec::new();
1149                 scorer.write(&mut serialized_scorer).unwrap();
1150
1151                 let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
1152                 assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1153                 assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
1154         }
1155
1156         #[test]
1157         fn decays_persisted_channel_failure_penalties() {
1158                 let mut scorer = Scorer::new(ScoringParameters {
1159                         base_penalty_msat: 1_000,
1160                         failure_penalty_msat: 512,
1161                         failure_penalty_half_life: Duration::from_secs(10),
1162                         overuse_penalty_start_1024th: 1024,
1163                         overuse_penalty_msat_per_1024th: 0,
1164                 });
1165                 let source = source_node_id();
1166                 let target = target_node_id();
1167
1168                 scorer.payment_path_failed(&[], 42);
1169                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1170
1171                 let mut serialized_scorer = Vec::new();
1172                 scorer.write(&mut serialized_scorer).unwrap();
1173
1174                 SinceEpoch::advance(Duration::from_secs(10));
1175
1176                 let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
1177                 assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1178
1179                 SinceEpoch::advance(Duration::from_secs(10));
1180                 assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
1181         }
1182
1183         #[test]
1184         fn charges_per_1024th_penalty() {
1185                 let scorer = Scorer::new(ScoringParameters {
1186                         base_penalty_msat: 0,
1187                         failure_penalty_msat: 0,
1188                         failure_penalty_half_life: Duration::from_secs(0),
1189                         overuse_penalty_start_1024th: 256,
1190                         overuse_penalty_msat_per_1024th: 100,
1191                 });
1192                 let source = source_node_id();
1193                 let target = target_node_id();
1194
1195                 assert_eq!(scorer.channel_penalty_msat(42, 1_000, 1_024_000, &source, &target), 0);
1196                 assert_eq!(scorer.channel_penalty_msat(42, 256_999, 1_024_000, &source, &target), 0);
1197                 assert_eq!(scorer.channel_penalty_msat(42, 257_000, 1_024_000, &source, &target), 100);
1198                 assert_eq!(scorer.channel_penalty_msat(42, 258_000, 1_024_000, &source, &target), 200);
1199                 assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 256 * 100);
1200         }
1201
1202         // `ProbabilisticScorer` tests
1203
1204         fn sender_privkey() -> SecretKey {
1205                 SecretKey::from_slice(&[41; 32]).unwrap()
1206         }
1207
1208         fn recipient_privkey() -> SecretKey {
1209                 SecretKey::from_slice(&[45; 32]).unwrap()
1210         }
1211
1212         fn sender_pubkey() -> PublicKey {
1213                 let secp_ctx = Secp256k1::new();
1214                 PublicKey::from_secret_key(&secp_ctx, &sender_privkey())
1215         }
1216
1217         fn recipient_pubkey() -> PublicKey {
1218                 let secp_ctx = Secp256k1::new();
1219                 PublicKey::from_secret_key(&secp_ctx, &recipient_privkey())
1220         }
1221
1222         fn sender_node_id() -> NodeId {
1223                 NodeId::from_pubkey(&sender_pubkey())
1224         }
1225
1226         fn recipient_node_id() -> NodeId {
1227                 NodeId::from_pubkey(&recipient_pubkey())
1228         }
1229
1230         fn network_graph() -> NetworkGraph {
1231                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1232                 let mut network_graph = NetworkGraph::new(genesis_hash);
1233                 add_channel(&mut network_graph, 41, sender_privkey(), source_privkey());
1234                 add_channel(&mut network_graph, 42, source_privkey(), target_privkey());
1235                 add_channel(&mut network_graph, 43, target_privkey(), recipient_privkey());
1236
1237                 network_graph
1238         }
1239
1240         fn add_channel(
1241                 network_graph: &mut NetworkGraph, short_channel_id: u64, node_1_key: SecretKey,
1242                 node_2_key: SecretKey
1243         ) {
1244                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1245                 let node_1_secret = &SecretKey::from_slice(&[39; 32]).unwrap();
1246                 let node_2_secret = &SecretKey::from_slice(&[40; 32]).unwrap();
1247                 let secp_ctx = Secp256k1::new();
1248                 let unsigned_announcement = UnsignedChannelAnnouncement {
1249                         features: ChannelFeatures::known(),
1250                         chain_hash: genesis_hash,
1251                         short_channel_id,
1252                         node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_key),
1253                         node_id_2: PublicKey::from_secret_key(&secp_ctx, &node_2_key),
1254                         bitcoin_key_1: PublicKey::from_secret_key(&secp_ctx, &node_1_secret),
1255                         bitcoin_key_2: PublicKey::from_secret_key(&secp_ctx, &node_2_secret),
1256                         excess_data: Vec::new(),
1257                 };
1258                 let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_announcement.encode()[..])[..]);
1259                 let signed_announcement = ChannelAnnouncement {
1260                         node_signature_1: secp_ctx.sign(&msghash, &node_1_key),
1261                         node_signature_2: secp_ctx.sign(&msghash, &node_2_key),
1262                         bitcoin_signature_1: secp_ctx.sign(&msghash, &node_1_secret),
1263                         bitcoin_signature_2: secp_ctx.sign(&msghash, &node_2_secret),
1264                         contents: unsigned_announcement,
1265                 };
1266                 let chain_source: Option<&::util::test_utils::TestChainSource> = None;
1267                 network_graph.update_channel_from_announcement(
1268                         &signed_announcement, &chain_source, &secp_ctx).unwrap();
1269                 update_channel(network_graph, short_channel_id, node_1_key, 0);
1270                 update_channel(network_graph, short_channel_id, node_2_key, 1);
1271         }
1272
1273         fn update_channel(
1274                 network_graph: &mut NetworkGraph, short_channel_id: u64, node_key: SecretKey, flags: u8
1275         ) {
1276                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1277                 let secp_ctx = Secp256k1::new();
1278                 let unsigned_update = UnsignedChannelUpdate {
1279                         chain_hash: genesis_hash,
1280                         short_channel_id,
1281                         timestamp: 100,
1282                         flags,
1283                         cltv_expiry_delta: 18,
1284                         htlc_minimum_msat: 0,
1285                         htlc_maximum_msat: OptionalField::Present(1_000),
1286                         fee_base_msat: 1,
1287                         fee_proportional_millionths: 0,
1288                         excess_data: Vec::new(),
1289                 };
1290                 let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_update.encode()[..])[..]);
1291                 let signed_update = ChannelUpdate {
1292                         signature: secp_ctx.sign(&msghash, &node_key),
1293                         contents: unsigned_update,
1294                 };
1295                 network_graph.update_channel(&signed_update, &secp_ctx).unwrap();
1296         }
1297
1298         fn payment_path_for_amount(amount_msat: u64) -> Vec<RouteHop> {
1299                 vec![
1300                         RouteHop {
1301                                 pubkey: source_pubkey(),
1302                                 node_features: NodeFeatures::known(),
1303                                 short_channel_id: 41,
1304                                 channel_features: ChannelFeatures::known(),
1305                                 fee_msat: 1,
1306                                 cltv_expiry_delta: 18,
1307                         },
1308                         RouteHop {
1309                                 pubkey: target_pubkey(),
1310                                 node_features: NodeFeatures::known(),
1311                                 short_channel_id: 42,
1312                                 channel_features: ChannelFeatures::known(),
1313                                 fee_msat: 2,
1314                                 cltv_expiry_delta: 18,
1315                         },
1316                         RouteHop {
1317                                 pubkey: recipient_pubkey(),
1318                                 node_features: NodeFeatures::known(),
1319                                 short_channel_id: 43,
1320                                 channel_features: ChannelFeatures::known(),
1321                                 fee_msat: amount_msat,
1322                                 cltv_expiry_delta: 18,
1323                         },
1324                 ]
1325         }
1326
#[test]
fn liquidity_bounds_directed_from_lowest_node_id() {
	// Both channels store a 700 msat lower offset and a 100 msat upper offset. Read in the
	// direction leaving the lower node id, the bounds are (min_offset, capacity - max_offset);
	// read the other way, the two offsets trade places.
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let stored_offsets = || ChannelLiquidity {
		min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100
	};
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph)
		.with_channel(42, stored_offsets())
		.with_channel(43, stored_offsets());
	let (source, target, recipient) = (source_node_id(), target_node_id(), recipient_node_id());

	// Channel 42: `target` holds the lower node id.
	let channel = scorer.channel_liquidities.get_mut(&42).unwrap();
	assert!(source > target);
	{
		let from_source = channel.as_directed(&source, &target, 1_000);
		assert_eq!(from_source.min_liquidity_msat(), 100);
		assert_eq!(from_source.max_liquidity_msat(), 300);
		let from_target = channel.as_directed(&target, &source, 1_000);
		assert_eq!(from_target.min_liquidity_msat(), 700);
		assert_eq!(from_target.max_liquidity_msat(), 900);
	}

	// Raising the lower bound seen from `source` lowers the upper bound seen from `target`.
	channel.as_directed_mut(&source, &target, 1_000).set_min_liquidity_msat(200);
	{
		let from_source = channel.as_directed(&source, &target, 1_000);
		assert_eq!(from_source.min_liquidity_msat(), 200);
		assert_eq!(from_source.max_liquidity_msat(), 300);
		let from_target = channel.as_directed(&target, &source, 1_000);
		assert_eq!(from_target.min_liquidity_msat(), 700);
		assert_eq!(from_target.max_liquidity_msat(), 800);
	}

	// Channel 43: `target` again holds the lower node id.
	let channel = scorer.channel_liquidities.get_mut(&43).unwrap();
	assert!(target < recipient);
	{
		let from_target = channel.as_directed(&target, &recipient, 1_000);
		assert_eq!(from_target.min_liquidity_msat(), 700);
		assert_eq!(from_target.max_liquidity_msat(), 900);
		let from_recipient = channel.as_directed(&recipient, &target, 1_000);
		assert_eq!(from_recipient.min_liquidity_msat(), 100);
		assert_eq!(from_recipient.max_liquidity_msat(), 300);
	}

	// Dropping the upper bound from `target` below its lower bound resets that lower bound.
	channel.as_directed_mut(&target, &recipient, 1_000).set_max_liquidity_msat(200);
	{
		let from_target = channel.as_directed(&target, &recipient, 1_000);
		assert_eq!(from_target.min_liquidity_msat(), 0);
		assert_eq!(from_target.max_liquidity_msat(), 200);
		let from_recipient = channel.as_directed(&recipient, &target, 1_000);
		assert_eq!(from_recipient.min_liquidity_msat(), 800);
		assert_eq!(from_recipient.max_liquidity_msat(), 1000);
	}
}
1370
#[test]
fn resets_liquidity_upper_bound_when_crossed_by_lower_bound() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph)
		.with_channel(42, ChannelLiquidity {
			min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
		});
	let (source, target) = (source_node_id(), target_node_id());
	assert!(source > target);

	// (min, max) liquidity of a 1_000 msat channel read in the `from` -> `to` direction.
	let bounds = |liquidity: &ChannelLiquidity, from, to| -> (u64, u64) {
		let directed = liquidity.as_directed(from, to, 1_000);
		(directed.min_liquidity_msat(), directed.max_liquidity_msat())
	};

	// Initial bounds in both directions.
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &source, &target), (400, 800));
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &target, &source), (200, 600));

	// From `source`: a lower bound of 900 crosses the upper bound of 800, which resets to capacity.
	scorer.channel_liquidities.get_mut(&42).unwrap()
		.as_directed_mut(&source, &target, 1_000)
		.set_min_liquidity_msat(900);
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &source, &target), (900, 1_000));
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &target, &source), (0, 100));

	// From `target`: a lower bound of 400 crosses the upper bound of 100, which resets to capacity.
	scorer.channel_liquidities.get_mut(&42).unwrap()
		.as_directed_mut(&target, &source, 1_000)
		.set_min_liquidity_msat(400);
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &source, &target), (0, 600));
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &target, &source), (400, 1_000));
}
1425
#[test]
fn resets_liquidity_lower_bound_when_crossed_by_upper_bound() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph)
		.with_channel(42, ChannelLiquidity {
			min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
		});
	let (source, target) = (source_node_id(), target_node_id());
	assert!(source > target);

	// (min, max) liquidity of a 1_000 msat channel read in the `from` -> `to` direction.
	let bounds = |liquidity: &ChannelLiquidity, from, to| -> (u64, u64) {
		let directed = liquidity.as_directed(from, to, 1_000);
		(directed.min_liquidity_msat(), directed.max_liquidity_msat())
	};

	// Initial bounds in both directions.
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &source, &target), (400, 800));
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &target, &source), (200, 600));

	// From `source`: an upper bound of 300 falls below the lower bound of 400, which resets to zero.
	scorer.channel_liquidities.get_mut(&42).unwrap()
		.as_directed_mut(&source, &target, 1_000)
		.set_max_liquidity_msat(300);
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &source, &target), (0, 300));
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &target, &source), (700, 1_000));

	// From `target`: an upper bound of 600 falls below the lower bound of 700, which resets to zero.
	scorer.channel_liquidities.get_mut(&42).unwrap()
		.as_directed_mut(&target, &source, 1_000)
		.set_max_liquidity_msat(600);
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &source, &target), (400, 1_000));
	assert_eq!(bounds(scorer.channel_liquidities.get(&42).unwrap(), &target, &source), (0, 600));
}
1480
#[test]
fn increased_penalty_nearing_liquidity_upper_bound() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph);
	let (source, target) = (source_node_id(), target_node_id());

	// (amount_msat, capacity_msat, expected penalty_msat) for channel 42 with no learned
	// liquidity: the penalty grows as the amount approaches the capacity.
	let expectations: &[(u64, u64, u64)] = &[
		(100, 100_000, 0),
		(1_000, 100_000, 4),
		(10_000, 100_000, 45),
		(100_000, 100_000, 2_000),
		(125, 1_000, 57),
		(250, 1_000, 124),
		(375, 1_000, 203),
		(500, 1_000, 300),
		(625, 1_000, 425),
		(750, 1_000, 600),
		(875, 1_000, 900),
	];
	for &(amount_msat, capacity_msat, penalty_msat) in expectations {
		assert_eq!(
			scorer.channel_penalty_msat(42, amount_msat, capacity_msat, &source, &target),
			penalty_msat
		);
	}
}
1502
#[test]
fn constant_penalty_outside_liquidity_bounds() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph)
		.with_channel(42, ChannelLiquidity { min_liquidity_offset_msat: 40, max_liquidity_offset_msat: 40 });
	let (source, target) = (source_node_id(), target_node_id());

	// Penalty for routing `amount_msat` over channel 42 with a 100 msat capacity.
	let penalty_msat = |amount_msat| scorer.channel_penalty_msat(42, amount_msat, 100, &source, &target);

	// At or below the 40 msat lower bound: no penalty.
	assert_eq!(penalty_msat(39), 0);

	// Between the bounds the penalty is neither zero nor the 2_000 msat charged beyond them.
	let within_bounds_msat = penalty_msat(50);
	assert_ne!(within_bounds_msat, 0);
	assert_ne!(within_bounds_msat, 2_000);

	// Above the 60 msat upper bound: the full 2_000 msat penalty.
	assert_eq!(penalty_msat(61), 2_000);
}
1518
#[test]
fn does_not_penalize_own_channel() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph);
	let (sender, source) = (sender_node_id(), source_node_id());
	let failed_path = payment_path_for_amount(500);
	let successful_path = payment_path_for_amount(200);

	// Channel 41 starts at our own node (the scorer was built for `sender_pubkey()`), so its
	// penalty must stay zero regardless of payment feedback.
	assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);

	let failed_hops: Vec<_> = failed_path.iter().collect();
	scorer.payment_path_failed(&failed_hops, 41);
	assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);

	let successful_hops: Vec<_> = successful_path.iter().collect();
	scorer.payment_path_successful(&successful_hops);
	assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
}
1537
#[test]
fn sets_liquidity_lower_bound_on_downstream_failure() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph);
	let (source, target) = (source_node_id(), target_node_id());
	let path = payment_path_for_amount(500);

	// (amount_msat, expected penalty_msat) for channel 42 with no learned liquidity.
	let before: &[(u64, u64)] = &[(250, 124), (500, 300), (750, 600)];
	for &(amount_msat, penalty_msat) in before {
		assert_eq!(scorer.channel_penalty_msat(42, amount_msat, 1_000, &source, &target), penalty_msat);
	}

	// The 500 msat payment failed further down, at channel 43, so channel 42 carried it:
	// amounts up to 500 msat become free and larger ones get cheaper.
	scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 43);

	let after: &[(u64, u64)] = &[(250, 0), (500, 0), (750, 300)];
	for &(amount_msat, penalty_msat) in after {
		assert_eq!(scorer.channel_penalty_msat(42, amount_msat, 1_000, &source, &target), penalty_msat);
	}
}
1557
#[test]
fn sets_liquidity_upper_bound_on_failure() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph);
	let (source, target) = (source_node_id(), target_node_id());
	let path = payment_path_for_amount(500);

	// (amount_msat, expected penalty_msat) for channel 42 with no learned liquidity.
	let before: &[(u64, u64)] = &[(250, 124), (500, 300), (750, 600)];
	for &(amount_msat, penalty_msat) in before {
		assert_eq!(scorer.channel_penalty_msat(42, amount_msat, 1_000, &source, &target), penalty_msat);
	}

	// The 500 msat payment failed at channel 42 itself, capping its upper bound: 500 msat and
	// above now receive the full 2_000 msat penalty.
	scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 42);

	let after: &[(u64, u64)] = &[(250, 300), (500, 2_000), (750, 2_000)];
	for &(amount_msat, penalty_msat) in after {
		assert_eq!(scorer.channel_penalty_msat(42, amount_msat, 1_000, &source, &target), penalty_msat);
	}
}
1577
#[test]
fn reduces_liquidity_upper_bound_along_path_on_success() {
	let graph = network_graph();
	let params = ProbabilisticScoringParameters::default();
	let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &graph);
	let (sender, source, target, recipient) =
		(sender_node_id(), source_node_id(), target_node_id(), recipient_node_id());
	let path = payment_path_for_amount(500);

	// Every hop of the payment path as (short_channel_id, from, to).
	let hops = [(41, &sender, &source), (42, &source, &target), (43, &target, &recipient)];

	// Penalty for 250 msat over a 1_000 msat channel; our own channel 41 is never penalized.
	let expected_before: [u64; 3] = [0, 124, 124];
	for (&(scid, from, to), &penalty_msat) in hops.iter().zip(expected_before.iter()) {
		assert_eq!(scorer.channel_penalty_msat(scid, 250, 1_000, from, to), penalty_msat);
	}

	scorer.payment_path_successful(&path.iter().collect::<Vec<_>>());

	// After 500 msat went through, the remote hops' upper bounds shrink and the penalty rises.
	let expected_after: [u64; 3] = [0, 300, 300];
	for (&(scid, from, to), &penalty_msat) in hops.iter().zip(expected_after.iter()) {
		assert_eq!(scorer.channel_penalty_msat(scid, 250, 1_000, from, to), penalty_msat);
	}
}
1599 }