5faee4125b2db0c47de9d27bb498eea6419e63f7
[rust-lightning] / lightning / src / routing / scoring.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Utilities for scoring payment channels.
11 //!
12 //! [`ProbabilisticScorer`] may be given to [`find_route`] to score payment channels during path
13 //! finding when a custom [`Score`] implementation is not needed.
14 //!
15 //! # Example
16 //!
17 //! ```
18 //! # extern crate secp256k1;
19 //! #
20 //! # use lightning::routing::network_graph::NetworkGraph;
21 //! # use lightning::routing::router::{RouteParameters, find_route};
22 //! # use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters, Scorer, ScoringParameters};
23 //! # use lightning::util::logger::{Logger, Record};
24 //! # use secp256k1::key::PublicKey;
25 //! #
26 //! # struct FakeLogger {};
27 //! # impl Logger for FakeLogger {
28 //! #     fn log(&self, record: &Record) { unimplemented!() }
29 //! # }
30 //! # fn find_scored_route(payer: PublicKey, route_params: RouteParameters, network_graph: NetworkGraph) {
31 //! # let logger = FakeLogger {};
32 //! #
33 //! // Use the default channel penalties.
34 //! let params = ProbabilisticScoringParameters::default();
35 //! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
36 //!
37 //! // Or use custom channel penalties.
38 //! let params = ProbabilisticScoringParameters {
39 //!     liquidity_penalty_multiplier_msat: 2 * 1000,
40 //!     ..ProbabilisticScoringParameters::default()
41 //! };
42 //! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
43 //!
44 //! let route = find_route(&payer, &route_params, &network_graph, None, &logger, &scorer);
45 //! # }
46 //! ```
47 //!
48 //! # Note
49 //!
50 //! Persisting when built with feature `no-std` and restoring without it, or vice versa, uses
51 //! different types and thus is undefined.
52 //!
53 //! [`find_route`]: crate::routing::router::find_route
54
55 use bitcoin::secp256k1::key::PublicKey;
56
57 use ln::msgs::DecodeError;
58 use routing::network_graph::{EffectiveCapacity, NetworkGraph, NodeId};
59 use routing::router::RouteHop;
60 use util::ser::{Readable, ReadableArgs, Writeable, Writer};
61
62 use prelude::*;
63 use core::cell::{RefCell, RefMut};
64 use core::ops::{Deref, DerefMut};
65 use core::time::Duration;
66 use io::{self, Read};
67 use sync::{Mutex, MutexGuard};
68
69 /// We define Score ever-so-slightly differently based on whether we are being built for C bindings
70 /// or not. For users, `LockableScore` must somehow be writeable to disk. For Rust users, this is
71 /// no problem - you move a `Score` that implements `Writeable` into a `Mutex`, lock it, and now
72 /// you have the original, concrete, `Score` type, which presumably implements `Writeable`.
73 ///
74 /// For C users, once you've moved the `Score` into a `LockableScore` all you have after locking it
75 /// is an opaque trait object with an opaque pointer with no type info. Users could take the unsafe
76 /// approach of blindly casting that opaque pointer to a concrete type and calling `Writeable` from
77 /// there, but other languages downstream of the C bindings (e.g. Java) can't even do that.
78 /// Instead, we really want `Score` and `LockableScore` to implement `Writeable` directly, which we
79 /// do here by defining `Score` differently for `cfg(c_bindings)`.
80 macro_rules! define_score { ($($supertrait: path)*) => {
81 /// An interface used to score payment channels for path finding.
82 ///
83 ///     Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
84 pub trait Score $(: $supertrait)* {
85         /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the
86         /// given channel in the direction from `source` to `target`.
87         ///
88         /// The channel's capacity (less any other MPP parts that are also being considered for use in
89         /// the same payment) is given by `capacity_msat`. It may be determined from various sources
90         /// such as a chain data, network gossip, or invoice hints. For invoice hints, a capacity near
91         /// [`u64::max_value`] is given to indicate sufficient capacity for the invoice's full amount.
92         /// Thus, implementations should be overflow-safe.
93         fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64;
94
95         /// Handles updating channel penalties after failing to route through a channel.
96         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64);
97
98         /// Handles updating channel penalties after successfully routing along a path.
99         fn payment_path_successful(&mut self, path: &[&RouteHop]);
100 }
101
102 impl<S: Score, T: DerefMut<Target=S> $(+ $supertrait)*> Score for T {
103         fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64 {
104                 self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, capacity_msat, source, target)
105         }
106
107         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
108                 self.deref_mut().payment_path_failed(path, short_channel_id)
109         }
110
111         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
112                 self.deref_mut().payment_path_successful(path)
113         }
114 }
115 } }
116
117 #[cfg(c_bindings)]
118 define_score!(Writeable);
119 #[cfg(not(c_bindings))]
120 define_score!();
121
122 /// A scorer that is accessed under a lock.
123 ///
124 /// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
125 /// having shared ownership of a scorer but without requiring internal locking in [`Score`]
126 /// implementations. Internal locking would be detrimental to route finding performance and could
127 /// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
128 ///
129 /// [`find_route`]: crate::routing::router::find_route
130 pub trait LockableScore<'a> {
131         /// The locked [`Score`] type.
132         type Locked: 'a + Score;
133
134         /// Returns the locked scorer.
135         fn lock(&'a self) -> Self::Locked;
136 }
137
138 /// (C-not exported)
139 impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
140         type Locked = MutexGuard<'a, T>;
141
142         fn lock(&'a self) -> MutexGuard<'a, T> {
143                 Mutex::lock(self).unwrap()
144         }
145 }
146
147 impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
148         type Locked = RefMut<'a, T>;
149
150         fn lock(&'a self) -> RefMut<'a, T> {
151                 self.borrow_mut()
152         }
153 }
154
#[cfg(c_bindings)]
/// A concrete implementation of [`LockableScore`] which supports multi-threading.
pub struct MultiThreadedLockableScore<S: Score> {
	// The wrapped scorer, guarded by a mutex.
	score: Mutex<S>,
}
#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
	type Locked = MutexGuard<'a, T>;

	/// Locks the inner mutex; panics if acquiring the lock returns an error.
	fn lock(&'a self) -> Self::Locked {
		let inner = &self.score;
		Mutex::lock(inner).unwrap()
	}
}

#[cfg(c_bindings)]
impl<T: Score> MultiThreadedLockableScore<T> {
	/// Creates a new [`MultiThreadedLockableScore`] given an underlying [`Score`].
	pub fn new(score: T) -> Self {
		let score = Mutex::new(score);
		Self { score }
	}
}
177
#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
	/// Serializes the borrowed value exactly as the value itself would be serialized.
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		let inner: &T = &**self;
		Writeable::write(inner, writer)
	}
}

#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> {
	/// Serializes the guarded value exactly as the value itself would be serialized.
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		let inner: &S = &**self;
		Writeable::write(inner, writer)
	}
}
193
194 /// [`Score`] implementation that provides reasonable default behavior.
195 ///
196 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
197 /// slightly higher fees are available. Will further penalize channels that fail to relay payments.
198 ///
199 /// See [module-level documentation] for usage.
200 ///
201 /// [module-level documentation]: crate::routing::scoring
202 #[cfg(not(feature = "no-std"))]
203 pub type Scorer = ScorerUsingTime::<std::time::Instant>;
204 /// [`Score`] implementation that provides reasonable default behavior.
205 ///
206 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
207 /// slightly higher fees are available. Will further penalize channels that fail to relay payments.
208 ///
209 /// See [module-level documentation] for usage and [`ScoringParameters`] for customization.
210 ///
211 /// [module-level documentation]: crate::routing::scoring
212 #[cfg(feature = "no-std")]
213 pub type Scorer = ScorerUsingTime::<time::Eternity>;
214
215 // Note that ideally we'd hide ScorerUsingTime from public view by sealing it as well, but rustdoc
216 // doesn't handle this well - instead exposing a `Scorer` which has no trait implementation(s) or
217 // methods at all.
218
219 /// [`Score`] implementation.
220 ///
221 /// See [`Scorer`] for details.
222 ///
223 /// # Note
224 ///
225 /// Mixing the `no-std` feature between serialization and deserialization results in undefined
226 /// behavior.
227 ///
228 /// (C-not exported) generally all users should use the [`Scorer`] type alias.
229 pub struct ScorerUsingTime<T: Time> {
230         params: ScoringParameters,
231         // TODO: Remove entries of closed channels.
232         channel_failures: HashMap<u64, ChannelFailure<T>>,
233 }
234
/// Parameters for configuring [`Scorer`].
pub struct ScoringParameters {
	/// A fixed penalty in msats to apply to each channel.
	///
	/// Default value: 500 msat
	pub base_penalty_msat: u64,

	/// A penalty in msats to apply to a channel upon failing to relay a payment.
	///
	/// This accumulates for each failure but may be reduced over time based on
	/// [`failure_penalty_half_life`] or when successfully routing through a channel.
	///
	/// Default value: 1,024,000 msat
	///
	/// [`failure_penalty_half_life`]: Self::failure_penalty_half_life
	pub failure_penalty_msat: u64,

	/// When the amount being sent over a channel exceeds this many 1024ths of the total channel
	/// capacity, we begin applying [`overuse_penalty_msat_per_1024th`].
	///
	/// Default value: 128 1024ths (i.e. begin penalizing when an HTLC uses 1/8th of a channel)
	///
	/// [`overuse_penalty_msat_per_1024th`]: Self::overuse_penalty_msat_per_1024th
	pub overuse_penalty_start_1024th: u16,

	/// A penalty applied, per whole 1024ths of the channel capacity which the amount being sent
	/// over the channel exceeds [`overuse_penalty_start_1024th`] by.
	///
	/// Default value: 20 msat (i.e. 2,560 msat penalty to use 1/4th of a channel, 7,680 msat
	/// penalty to use half a channel, and 12,800 msat penalty to use 3/4ths of a channel)
	///
	/// [`overuse_penalty_start_1024th`]: Self::overuse_penalty_start_1024th
	pub overuse_penalty_msat_per_1024th: u64,

	/// The time required to elapse before any accumulated [`failure_penalty_msat`] penalties are
	/// cut in half.
	///
	/// Successfully routing through a channel will immediately cut the penalty in half as well.
	///
	/// # Note
	///
	/// When built with the `no-std` feature, time will never elapse. Therefore, this penalty will
	/// never decay.
	///
	/// [`failure_penalty_msat`]: Self::failure_penalty_msat
	pub failure_penalty_half_life: Duration,
}
282
// Serialization of `ScoringParameters` as TLV fields. Types 1 and 3 carry `default_value` entries
// (128 and 20), which match the `Default` impl's overuse settings.
// NOTE(review): presumably those defaults are used when the fields are absent from the serialized
// data (e.g. data written before they existed) — confirm against `impl_writeable_tlv_based!`.
impl_writeable_tlv_based!(ScoringParameters, {
	(0, base_penalty_msat, required),
	(1, overuse_penalty_start_1024th, (default_value, 128)),
	(2, failure_penalty_msat, required),
	(3, overuse_penalty_msat_per_1024th, (default_value, 20)),
	(4, failure_penalty_half_life, required),
});
290
291 /// Accounting for penalties against a channel for failing to relay any payments.
292 ///
293 /// Penalties decay over time, though accumulate as more failures occur.
294 struct ChannelFailure<T: Time> {
295         /// Accumulated penalty in msats for the channel as of `last_updated`.
296         undecayed_penalty_msat: u64,
297
298         /// Last time the channel either failed to route or successfully routed a payment. Used to decay
299         /// `undecayed_penalty_msat`.
300         last_updated: T,
301 }
302
303 impl<T: Time> ScorerUsingTime<T> {
304         /// Creates a new scorer using the given scoring parameters.
305         pub fn new(params: ScoringParameters) -> Self {
306                 Self {
307                         params,
308                         channel_failures: HashMap::new(),
309                 }
310         }
311
312         /// Creates a new scorer using `penalty_msat` as a fixed channel penalty.
313         #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
314         pub fn with_fixed_penalty(penalty_msat: u64) -> Self {
315                 Self::new(ScoringParameters {
316                         base_penalty_msat: penalty_msat,
317                         failure_penalty_msat: 0,
318                         failure_penalty_half_life: Duration::from_secs(0),
319                         overuse_penalty_start_1024th: 1024,
320                         overuse_penalty_msat_per_1024th: 0,
321                 })
322         }
323 }
324
325 impl<T: Time> ChannelFailure<T> {
326         fn new(failure_penalty_msat: u64) -> Self {
327                 Self {
328                         undecayed_penalty_msat: failure_penalty_msat,
329                         last_updated: T::now(),
330                 }
331         }
332
333         fn add_penalty(&mut self, failure_penalty_msat: u64, half_life: Duration) {
334                 self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) + failure_penalty_msat;
335                 self.last_updated = T::now();
336         }
337
338         fn reduce_penalty(&mut self, half_life: Duration) {
339                 self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) >> 1;
340                 self.last_updated = T::now();
341         }
342
343         fn decayed_penalty_msat(&self, half_life: Duration) -> u64 {
344                 self.last_updated.elapsed().as_secs()
345                         .checked_div(half_life.as_secs())
346                         .and_then(|decays| self.undecayed_penalty_msat.checked_shr(decays as u32))
347                         .unwrap_or(0)
348         }
349 }
350
351 impl<T: Time> Default for ScorerUsingTime<T> {
352         fn default() -> Self {
353                 Self::new(ScoringParameters::default())
354         }
355 }
356
357 impl Default for ScoringParameters {
358         fn default() -> Self {
359                 Self {
360                         base_penalty_msat: 500,
361                         failure_penalty_msat: 1024 * 1000,
362                         failure_penalty_half_life: Duration::from_secs(3600),
363                         overuse_penalty_start_1024th: 1024 / 8,
364                         overuse_penalty_msat_per_1024th: 20,
365                 }
366         }
367 }
368
369 impl<T: Time> Score for ScorerUsingTime<T> {
370         fn channel_penalty_msat(
371                 &self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, _source: &NodeId, _target: &NodeId
372         ) -> u64 {
373                 let failure_penalty_msat = self.channel_failures
374                         .get(&short_channel_id)
375                         .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life));
376
377                 let mut penalty_msat = self.params.base_penalty_msat + failure_penalty_msat;
378                 let send_1024ths = send_amt_msat.checked_mul(1024).unwrap_or(u64::max_value()) / capacity_msat;
379                 if send_1024ths > self.params.overuse_penalty_start_1024th as u64 {
380                         penalty_msat = penalty_msat.checked_add(
381                                         (send_1024ths - self.params.overuse_penalty_start_1024th as u64)
382                                         .checked_mul(self.params.overuse_penalty_msat_per_1024th).unwrap_or(u64::max_value()))
383                                 .unwrap_or(u64::max_value());
384                 }
385
386                 penalty_msat
387         }
388
389         fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) {
390                 let failure_penalty_msat = self.params.failure_penalty_msat;
391                 let half_life = self.params.failure_penalty_half_life;
392                 self.channel_failures
393                         .entry(short_channel_id)
394                         .and_modify(|failure| failure.add_penalty(failure_penalty_msat, half_life))
395                         .or_insert_with(|| ChannelFailure::new(failure_penalty_msat));
396         }
397
398         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
399                 let half_life = self.params.failure_penalty_half_life;
400                 for hop in path.iter() {
401                         self.channel_failures
402                                 .entry(hop.short_channel_id)
403                                 .and_modify(|failure| failure.reduce_penalty(half_life));
404                 }
405         }
406 }
407
408 impl<T: Time> Writeable for ScorerUsingTime<T> {
409         #[inline]
410         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
411                 self.params.write(w)?;
412                 self.channel_failures.write(w)?;
413                 write_tlv_fields!(w, {});
414                 Ok(())
415         }
416 }
417
418 impl<T: Time> Readable for ScorerUsingTime<T> {
419         #[inline]
420         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
421                 let res = Ok(Self {
422                         params: Readable::read(r)?,
423                         channel_failures: Readable::read(r)?,
424                 });
425                 read_tlv_fields!(r, {});
426                 res
427         }
428 }
429
430 impl<T: Time> Writeable for ChannelFailure<T> {
431         #[inline]
432         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
433                 let duration_since_epoch = T::duration_since_epoch() - self.last_updated.elapsed();
434                 write_tlv_fields!(w, {
435                         (0, self.undecayed_penalty_msat, required),
436                         (2, duration_since_epoch, required),
437                 });
438                 Ok(())
439         }
440 }
441
442 impl<T: Time> Readable for ChannelFailure<T> {
443         #[inline]
444         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
445                 let mut undecayed_penalty_msat = 0;
446                 let mut duration_since_epoch = Duration::from_secs(0);
447                 read_tlv_fields!(r, {
448                         (0, undecayed_penalty_msat, required),
449                         (2, duration_since_epoch, required),
450                 });
451                 Ok(Self {
452                         undecayed_penalty_msat,
453                         last_updated: T::now() - (T::duration_since_epoch() - duration_since_epoch),
454                 })
455         }
456 }
457
458 /// [`Score`] implementation using channel success probability distributions.
459 ///
460 /// Based on *Optimally Reliable & Cheap Payment Flows on the Lightning Network* by Rene Pickhardt
461 /// and Stefan Richter [[1]]. Given the uncertainty of channel liquidity balances, probability
462 /// distributions are defined based on knowledge learned from successful and unsuccessful attempts.
463 /// Then the negative log of the success probability is used to determine the cost of routing a
464 /// specific HTLC amount through a channel.
465 ///
466 /// [1]: https://arxiv.org/abs/2107.05322
467 pub struct ProbabilisticScorer<G: Deref<Target = NetworkGraph>> {
468         params: ProbabilisticScoringParameters,
469         node_id: NodeId,
470         network_graph: G,
471         // TODO: Remove entries of closed channels.
472         channel_liquidities: HashMap<u64, ChannelLiquidity>,
473 }
474
/// Parameters for configuring [`ProbabilisticScorer`].
pub struct ProbabilisticScoringParameters {
	/// Multiplier, in msats, applied to the negative base-10 log of a channel's success
	/// probability for a payment to obtain the penalty.
	///
	/// The success probability depends on the effective channel capacity, the payment amount, and
	/// what prior successful and unsuccessful payments revealed. It is never taken below 0.01,
	/// which keeps the penalty within `0..=2*liquidity_penalty_multiplier_msat`.
	///
	/// Default value: 1,000 msat
	pub liquidity_penalty_multiplier_msat: u64,
}
488
// Serialization of `ProbabilisticScoringParameters` as TLV fields; the multiplier is type 0, marked required.
impl_writeable_tlv_based!(ProbabilisticScoringParameters, {
	(0, liquidity_penalty_multiplier_msat, required),
});
492
/// Learned bounds on a channel's liquidity balance, stored as offsets.
///
/// The stored fields describe the direction whose source is the lesser of the channel's two
/// counterparties in [`NodeId`] ordering. Exchanging the two offsets gives the opposite direction.
struct ChannelLiquidity {
	/// In the stored direction: the lower bound on the balance.
	min_liquidity_offset_msat: u64,
	/// In the stored direction: how far below the capacity the upper bound on the balance lies.
	max_liquidity_offset_msat: u64,
}
502
/// A single-direction view over a [`ChannelLiquidity`], paired with an assumed channel capacity.
///
/// `L` is `&u64` for read-only views and `&mut u64` for views that update the bounds.
struct DirectedChannelLiquidity<L: Deref<Target = u64>> {
	min_liquidity_offset_msat: L,
	max_liquidity_offset_msat: L,
	capacity_msat: u64,
}
509
510 impl<G: Deref<Target = NetworkGraph>> ProbabilisticScorer<G> {
511         /// Creates a new scorer using the given scoring parameters for sending payments from a node
512         /// through a network graph.
513         pub fn new(
514                 params: ProbabilisticScoringParameters, node_pubkey: &PublicKey, network_graph: G
515         ) -> Self {
516                 Self {
517                         params,
518                         node_id: NodeId::from_pubkey(node_pubkey),
519                         network_graph,
520                         channel_liquidities: HashMap::new(),
521                 }
522         }
523
524         #[cfg(test)]
525         fn with_channel(mut self, short_channel_id: u64, liquidity: ChannelLiquidity) -> Self {
526                 assert!(self.channel_liquidities.insert(short_channel_id, liquidity).is_none());
527                 self
528         }
529 }
530
531 impl Default for ProbabilisticScoringParameters {
532         fn default() -> Self {
533                 Self {
534                         liquidity_penalty_multiplier_msat: 1000,
535                 }
536         }
537 }
538
539 impl ChannelLiquidity {
540         #[inline]
541         fn new() -> Self {
542                 Self {
543                         min_liquidity_offset_msat: 0,
544                         max_liquidity_offset_msat: 0,
545                 }
546         }
547
548         /// Returns a view of the channel liquidity directed from `source` to `target` assuming
549         /// `capacity_msat`.
550         fn as_directed(
551                 &self, source: &NodeId, target: &NodeId, capacity_msat: u64
552         ) -> DirectedChannelLiquidity<&u64> {
553                 let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target {
554                         (&self.min_liquidity_offset_msat, &self.max_liquidity_offset_msat)
555                 } else {
556                         (&self.max_liquidity_offset_msat, &self.min_liquidity_offset_msat)
557                 };
558
559                 DirectedChannelLiquidity {
560                         min_liquidity_offset_msat,
561                         max_liquidity_offset_msat,
562                         capacity_msat,
563                 }
564         }
565
566         /// Returns a mutable view of the channel liquidity directed from `source` to `target` assuming
567         /// `capacity_msat`.
568         fn as_directed_mut(
569                 &mut self, source: &NodeId, target: &NodeId, capacity_msat: u64
570         ) -> DirectedChannelLiquidity<&mut u64> {
571                 let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target {
572                         (&mut self.min_liquidity_offset_msat, &mut self.max_liquidity_offset_msat)
573                 } else {
574                         (&mut self.max_liquidity_offset_msat, &mut self.min_liquidity_offset_msat)
575                 };
576
577                 DirectedChannelLiquidity {
578                         min_liquidity_offset_msat,
579                         max_liquidity_offset_msat,
580                         capacity_msat,
581                 }
582         }
583 }
584
585 impl<L: Deref<Target = u64>> DirectedChannelLiquidity<L> {
586         /// Returns the success probability of routing the given HTLC `amount_msat` through the channel
587         /// in this direction.
588         fn success_probability(&self, amount_msat: u64) -> f64 {
589                 let max_liquidity_msat = self.max_liquidity_msat();
590                 let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat);
591                 if amount_msat > max_liquidity_msat {
592                         0.0
593                 } else if amount_msat <= min_liquidity_msat {
594                         1.0
595                 } else {
596                         let numerator = max_liquidity_msat + 1 - amount_msat;
597                         let denominator = max_liquidity_msat + 1 - min_liquidity_msat;
598                         numerator as f64 / denominator as f64
599                 }.max(0.01) // Lower bound the success probability to ensure some channel is selected.
600         }
601
602         /// Returns the lower bound of the channel liquidity balance in this direction.
603         fn min_liquidity_msat(&self) -> u64 {
604                 *self.min_liquidity_offset_msat
605         }
606
607         /// Returns the upper bound of the channel liquidity balance in this direction.
608         fn max_liquidity_msat(&self) -> u64 {
609                 self.capacity_msat.checked_sub(*self.max_liquidity_offset_msat).unwrap_or(0)
610         }
611 }
612
613 impl<L: DerefMut<Target = u64>> DirectedChannelLiquidity<L> {
614         /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat`.
615         fn failed_at_channel(&mut self, amount_msat: u64) {
616                 if amount_msat < self.max_liquidity_msat() {
617                         self.set_max_liquidity_msat(amount_msat);
618                 }
619         }
620
621         /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat` downstream.
622         fn failed_downstream(&mut self, amount_msat: u64) {
623                 if amount_msat > self.min_liquidity_msat() {
624                         self.set_min_liquidity_msat(amount_msat);
625                 }
626         }
627
628         /// Adjusts the channel liquidity balance bounds when successfully routing `amount_msat`.
629         fn successful(&mut self, amount_msat: u64) {
630                 let max_liquidity_msat = self.max_liquidity_msat().checked_sub(amount_msat).unwrap_or(0);
631                 self.set_max_liquidity_msat(max_liquidity_msat);
632         }
633
634         /// Adjusts the lower bound of the channel liquidity balance in this direction.
635         fn set_min_liquidity_msat(&mut self, amount_msat: u64) {
636                 *self.min_liquidity_offset_msat = amount_msat;
637
638                 if amount_msat > self.max_liquidity_msat() {
639                         *self.max_liquidity_offset_msat = 0;
640                 }
641         }
642
643         /// Adjusts the upper bound of the channel liquidity balance in this direction.
644         fn set_max_liquidity_msat(&mut self, amount_msat: u64) {
645                 *self.max_liquidity_offset_msat = self.capacity_msat.checked_sub(amount_msat).unwrap_or(0);
646
647                 if amount_msat < self.min_liquidity_msat() {
648                         *self.min_liquidity_offset_msat = 0;
649                 }
650         }
651 }
652
653 impl<G: Deref<Target = NetworkGraph>> Score for ProbabilisticScorer<G> {
654         fn channel_penalty_msat(
655                 &self, short_channel_id: u64, amount_msat: u64, capacity_msat: u64, source: &NodeId,
656                 target: &NodeId
657         ) -> u64 {
658                 if *source == self.node_id || *target == self.node_id {
659                         return 0;
660                 }
661
662                 let liquidity_penalty_multiplier_msat = self.params.liquidity_penalty_multiplier_msat;
663                 let success_probability = self.channel_liquidities
664                         .get(&short_channel_id)
665                         .unwrap_or(&ChannelLiquidity::new())
666                         .as_directed(source, target, capacity_msat)
667                         .success_probability(amount_msat);
668                 if success_probability == 0.0 {
669                         u64::max_value()
670                 } else if success_probability == 1.0 {
671                         0
672                 } else {
673                         (-(success_probability.log10()) * liquidity_penalty_multiplier_msat as f64) as u64
674                 }
675         }
676
677         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
678                 let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
679                 let network_graph = self.network_graph.read_only();
680                 let hop_sources = core::iter::once(self.node_id)
681                         .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey)));
682                 for (source, hop) in hop_sources.zip(path.iter()) {
683                         let target = NodeId::from_pubkey(&hop.pubkey);
684                         if source == self.node_id || target == self.node_id {
685                                 continue;
686                         }
687
688                         let capacity_msat = network_graph.channels()
689                                 .get(&hop.short_channel_id)
690                                 .and_then(|channel| channel.as_directed_to(&target).map(|(d, _)| d.effective_capacity()))
691                                 .unwrap_or(EffectiveCapacity::Unknown)
692                                 .as_msat();
693
694                         if hop.short_channel_id == short_channel_id {
695                                 self.channel_liquidities
696                                         .entry(hop.short_channel_id)
697                                         .or_insert_with(|| ChannelLiquidity::new())
698                                         .as_directed_mut(&source, &target, capacity_msat)
699                                         .failed_at_channel(amount_msat);
700                                 break;
701                         }
702
703                         self.channel_liquidities
704                                 .entry(hop.short_channel_id)
705                                 .or_insert_with(|| ChannelLiquidity::new())
706                                 .as_directed_mut(&source, &target, capacity_msat)
707                                 .failed_downstream(amount_msat);
708                 }
709         }
710
711         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
712                 let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
713                 let network_graph = self.network_graph.read_only();
714                 let hop_sources = core::iter::once(self.node_id)
715                         .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey)));
716                 for (source, hop) in hop_sources.zip(path.iter()) {
717                         let target = NodeId::from_pubkey(&hop.pubkey);
718                         if source == self.node_id || target == self.node_id {
719                                 continue;
720                         }
721
722                         let capacity_msat = network_graph.channels()
723                                 .get(&hop.short_channel_id)
724                                 .and_then(|channel| channel.as_directed_to(&target).map(|(d, _)| d.effective_capacity()))
725                                 .unwrap_or(EffectiveCapacity::Unknown)
726                                 .as_msat();
727
728                         self.channel_liquidities
729                                 .entry(hop.short_channel_id)
730                                 .or_insert_with(|| ChannelLiquidity::new())
731                                 .as_directed_mut(&source, &target, capacity_msat)
732                                 .successful(amount_msat);
733                 }
734         }
735 }
736
737 impl<G: Deref<Target = NetworkGraph>> Writeable for ProbabilisticScorer<G> {
738         #[inline]
739         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
740                 self.params.write(w)?;
741                 self.channel_liquidities.write(w)?;
742                 write_tlv_fields!(w, {});
743                 Ok(())
744         }
745 }
746
747 impl<G: Deref<Target = NetworkGraph>> ReadableArgs<(&PublicKey, G)> for ProbabilisticScorer<G> {
748         #[inline]
749         fn read<R: Read>(r: &mut R, args: (&PublicKey, G)) -> Result<Self, DecodeError> {
750                 let (node_pubkey, network_graph) = args;
751                 let res = Ok(Self {
752                         params: Readable::read(r)?,
753                         node_id: NodeId::from_pubkey(node_pubkey),
754                         network_graph,
755                         channel_liquidities: Readable::read(r)?,
756                 });
757                 read_tlv_fields!(r, {});
758                 res
759         }
760 }
761
762 impl Writeable for ChannelLiquidity {
763         #[inline]
764         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
765                 write_tlv_fields!(w, {
766                         (0, self.min_liquidity_offset_msat, required),
767                         (2, self.max_liquidity_offset_msat, required),
768                 });
769                 Ok(())
770         }
771 }
772
773 impl Readable for ChannelLiquidity {
774         #[inline]
775         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
776                 let mut min_liquidity_offset_msat = 0;
777                 let mut max_liquidity_offset_msat = 0;
778                 read_tlv_fields!(r, {
779                         (0, min_liquidity_offset_msat, required),
780                         (2, max_liquidity_offset_msat, required),
781                 });
782                 Ok(Self {
783                         min_liquidity_offset_msat,
784                         max_liquidity_offset_msat
785                 })
786         }
787 }
788
pub(crate) mod time {
        use core::ops::Sub;
        use core::time::Duration;

        /// A clock abstraction.
        ///
        /// It lets a scorer run on real time, on a manually advanced test clock, or on a clock that
        /// never moves.
        pub trait Time: Sub<Duration, Output = Self> where Self: Sized {
                /// Captures the current moment.
                fn now() -> Self;

                /// Measures how much time has passed since `self` was captured.
                fn elapsed(&self) -> Duration;

                /// Reports how much time has passed since this clock's epoch.
                ///
                /// Used during (de-)serialization, where a moment has to be expressed as a plain
                /// duration.
                fn duration_since_epoch() -> Duration;
        }

        /// A clock that never advances.
        ///
        /// Every elapsed time and epoch offset is zero, and subtracting a duration is a no-op.
        #[derive(Debug, PartialEq, Eq)]
        pub struct Eternity;

        #[cfg(not(feature = "no-std"))]
        impl Time for std::time::Instant {
                fn now() -> Self {
                        // Resolves to the inherent `Instant::now`, which takes precedence over this trait fn.
                        std::time::Instant::now()
                }

                fn elapsed(&self) -> Duration {
                        std::time::Instant::elapsed(self)
                }

                fn duration_since_epoch() -> Duration {
                        // Wall-clock time since the UNIX epoch; `unwrap` panics if the system clock
                        // reports a time before the epoch.
                        use std::time::SystemTime;
                        SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap()
                }
        }

        impl Time for Eternity {
                fn now() -> Self {
                        Eternity
                }

                fn elapsed(&self) -> Duration {
                        Duration::from_secs(0)
                }

                fn duration_since_epoch() -> Duration {
                        Duration::from_secs(0)
                }
        }

        impl Sub<Duration> for Eternity {
                type Output = Self;

                /// Returns `self` unchanged: an eternity cannot be moved back in time.
                fn sub(self, _rhs: Duration) -> Self {
                        self
                }
        }
}
848
849 pub(crate) use self::time::Time;
850
851 #[cfg(test)]
852 mod tests {
853         use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorer, ScoringParameters, ScorerUsingTime, Time};
854         use super::time::Eternity;
855
856         use ln::features::{ChannelFeatures, NodeFeatures};
857         use ln::msgs::{ChannelAnnouncement, ChannelUpdate, OptionalField, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
858         use routing::scoring::Score;
859         use routing::network_graph::{NetworkGraph, NodeId};
860         use routing::router::RouteHop;
861         use util::ser::{Readable, Writeable};
862
863         use bitcoin::blockdata::constants::genesis_block;
864         use bitcoin::hashes::Hash;
865         use bitcoin::hashes::sha256d::Hash as Sha256dHash;
866         use bitcoin::network::constants::Network;
867         use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
868         use core::cell::Cell;
869         use core::ops::Sub;
870         use core::time::Duration;
871         use io;
872
873         // `Time` tests
874
875         /// Time that can be advanced manually in tests.
876         #[derive(Debug, PartialEq, Eq)]
877         struct SinceEpoch(Duration);
878
879         impl SinceEpoch {
880                 thread_local! {
881                         static ELAPSED: Cell<Duration> = core::cell::Cell::new(Duration::from_secs(0));
882                 }
883
884                 fn advance(duration: Duration) {
885                         Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration))
886                 }
887         }
888
889         impl Time for SinceEpoch {
890                 fn now() -> Self {
891                         Self(Self::duration_since_epoch())
892                 }
893
894                 fn duration_since_epoch() -> Duration {
895                         Self::ELAPSED.with(|elapsed| elapsed.get())
896                 }
897
898                 fn elapsed(&self) -> Duration {
899                         Self::duration_since_epoch() - self.0
900                 }
901         }
902
903         impl Sub<Duration> for SinceEpoch {
904                 type Output = Self;
905
906                 fn sub(self, other: Duration) -> Self {
907                         Self(self.0 - other)
908                 }
909         }
910
911         #[test]
912         fn time_passes_when_advanced() {
913                 let now = SinceEpoch::now();
914                 assert_eq!(now.elapsed(), Duration::from_secs(0));
915
916                 SinceEpoch::advance(Duration::from_secs(1));
917                 SinceEpoch::advance(Duration::from_secs(1));
918
919                 let elapsed = now.elapsed();
920                 let later = SinceEpoch::now();
921
922                 assert_eq!(elapsed, Duration::from_secs(2));
923                 assert_eq!(later - elapsed, now);
924         }
925
926         #[test]
927         fn time_never_passes_in_an_eternity() {
928                 let now = Eternity::now();
929                 let elapsed = now.elapsed();
930                 let later = Eternity::now();
931
932                 assert_eq!(now.elapsed(), Duration::from_secs(0));
933                 assert_eq!(later - elapsed, now);
934         }
935
936         // `Scorer` tests
937
938         /// A scorer for testing with time that can be manually advanced.
939         type Scorer = ScorerUsingTime::<SinceEpoch>;
940
941         fn source_privkey() -> SecretKey {
942                 SecretKey::from_slice(&[42; 32]).unwrap()
943         }
944
945         fn target_privkey() -> SecretKey {
946                 SecretKey::from_slice(&[43; 32]).unwrap()
947         }
948
949         fn source_pubkey() -> PublicKey {
950                 let secp_ctx = Secp256k1::new();
951                 PublicKey::from_secret_key(&secp_ctx, &source_privkey())
952         }
953
954         fn target_pubkey() -> PublicKey {
955                 let secp_ctx = Secp256k1::new();
956                 PublicKey::from_secret_key(&secp_ctx, &target_privkey())
957         }
958
959         fn source_node_id() -> NodeId {
960                 NodeId::from_pubkey(&source_pubkey())
961         }
962
963         fn target_node_id() -> NodeId {
964                 NodeId::from_pubkey(&target_pubkey())
965         }
966
967         #[test]
968         fn penalizes_without_channel_failures() {
969                 let scorer = Scorer::new(ScoringParameters {
970                         base_penalty_msat: 1_000,
971                         failure_penalty_msat: 512,
972                         failure_penalty_half_life: Duration::from_secs(1),
973                         overuse_penalty_start_1024th: 1024,
974                         overuse_penalty_msat_per_1024th: 0,
975                 });
976                 let source = source_node_id();
977                 let target = target_node_id();
978                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
979
980                 SinceEpoch::advance(Duration::from_secs(1));
981                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
982         }
983
984         #[test]
985         fn accumulates_channel_failure_penalties() {
986                 let mut scorer = Scorer::new(ScoringParameters {
987                         base_penalty_msat: 1_000,
988                         failure_penalty_msat: 64,
989                         failure_penalty_half_life: Duration::from_secs(10),
990                         overuse_penalty_start_1024th: 1024,
991                         overuse_penalty_msat_per_1024th: 0,
992                 });
993                 let source = source_node_id();
994                 let target = target_node_id();
995                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
996
997                 scorer.payment_path_failed(&[], 42);
998                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
999
1000                 scorer.payment_path_failed(&[], 42);
1001                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
1002
1003                 scorer.payment_path_failed(&[], 42);
1004                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_192);
1005         }
1006
1007         #[test]
1008         fn decays_channel_failure_penalties_over_time() {
1009                 let mut scorer = Scorer::new(ScoringParameters {
1010                         base_penalty_msat: 1_000,
1011                         failure_penalty_msat: 512,
1012                         failure_penalty_half_life: Duration::from_secs(10),
1013                         overuse_penalty_start_1024th: 1024,
1014                         overuse_penalty_msat_per_1024th: 0,
1015                 });
1016                 let source = source_node_id();
1017                 let target = target_node_id();
1018                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1019
1020                 scorer.payment_path_failed(&[], 42);
1021                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1022
1023                 SinceEpoch::advance(Duration::from_secs(9));
1024                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1025
1026                 SinceEpoch::advance(Duration::from_secs(1));
1027                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1028
1029                 SinceEpoch::advance(Duration::from_secs(10 * 8));
1030                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_001);
1031
1032                 SinceEpoch::advance(Duration::from_secs(10));
1033                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1034
1035                 SinceEpoch::advance(Duration::from_secs(10));
1036                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1037         }
1038
1039         #[test]
1040         fn decays_channel_failure_penalties_without_shift_overflow() {
1041                 let mut scorer = Scorer::new(ScoringParameters {
1042                         base_penalty_msat: 1_000,
1043                         failure_penalty_msat: 512,
1044                         failure_penalty_half_life: Duration::from_secs(10),
1045                         overuse_penalty_start_1024th: 1024,
1046                         overuse_penalty_msat_per_1024th: 0,
1047                 });
1048                 let source = source_node_id();
1049                 let target = target_node_id();
1050                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1051
1052                 scorer.payment_path_failed(&[], 42);
1053                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1054
1055                 // An unchecked right shift 64 bits or more in ChannelFailure::decayed_penalty_msat would
1056                 // cause an overflow.
1057                 SinceEpoch::advance(Duration::from_secs(10 * 64));
1058                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1059
1060                 SinceEpoch::advance(Duration::from_secs(10));
1061                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1062         }
1063
1064         #[test]
1065         fn accumulates_channel_failure_penalties_after_decay() {
1066                 let mut scorer = Scorer::new(ScoringParameters {
1067                         base_penalty_msat: 1_000,
1068                         failure_penalty_msat: 512,
1069                         failure_penalty_half_life: Duration::from_secs(10),
1070                         overuse_penalty_start_1024th: 1024,
1071                         overuse_penalty_msat_per_1024th: 0,
1072                 });
1073                 let source = source_node_id();
1074                 let target = target_node_id();
1075                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1076
1077                 scorer.payment_path_failed(&[], 42);
1078                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1079
1080                 SinceEpoch::advance(Duration::from_secs(10));
1081                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1082
1083                 scorer.payment_path_failed(&[], 42);
1084                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_768);
1085
1086                 SinceEpoch::advance(Duration::from_secs(10));
1087                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_384);
1088         }
1089
1090         #[test]
1091         fn reduces_channel_failure_penalties_after_success() {
1092                 let mut scorer = Scorer::new(ScoringParameters {
1093                         base_penalty_msat: 1_000,
1094                         failure_penalty_msat: 512,
1095                         failure_penalty_half_life: Duration::from_secs(10),
1096                         overuse_penalty_start_1024th: 1024,
1097                         overuse_penalty_msat_per_1024th: 0,
1098                 });
1099                 let source = source_node_id();
1100                 let target = target_node_id();
1101                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
1102
1103                 scorer.payment_path_failed(&[], 42);
1104                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1105
1106                 SinceEpoch::advance(Duration::from_secs(10));
1107                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1108
1109                 let hop = RouteHop {
1110                         pubkey: PublicKey::from_slice(target.as_slice()).unwrap(),
1111                         node_features: NodeFeatures::known(),
1112                         short_channel_id: 42,
1113                         channel_features: ChannelFeatures::known(),
1114                         fee_msat: 1,
1115                         cltv_expiry_delta: 18,
1116                 };
1117                 scorer.payment_path_successful(&[&hop]);
1118                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
1119
1120                 SinceEpoch::advance(Duration::from_secs(10));
1121                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
1122         }
1123
1124         #[test]
1125         fn restores_persisted_channel_failure_penalties() {
1126                 let mut scorer = Scorer::new(ScoringParameters {
1127                         base_penalty_msat: 1_000,
1128                         failure_penalty_msat: 512,
1129                         failure_penalty_half_life: Duration::from_secs(10),
1130                         overuse_penalty_start_1024th: 1024,
1131                         overuse_penalty_msat_per_1024th: 0,
1132                 });
1133                 let source = source_node_id();
1134                 let target = target_node_id();
1135
1136                 scorer.payment_path_failed(&[], 42);
1137                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1138
1139                 SinceEpoch::advance(Duration::from_secs(10));
1140                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1141
1142                 scorer.payment_path_failed(&[], 43);
1143                 assert_eq!(scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
1144
1145                 let mut serialized_scorer = Vec::new();
1146                 scorer.write(&mut serialized_scorer).unwrap();
1147
1148                 let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
1149                 assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1150                 assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
1151         }
1152
1153         #[test]
1154         fn decays_persisted_channel_failure_penalties() {
1155                 let mut scorer = Scorer::new(ScoringParameters {
1156                         base_penalty_msat: 1_000,
1157                         failure_penalty_msat: 512,
1158                         failure_penalty_half_life: Duration::from_secs(10),
1159                         overuse_penalty_start_1024th: 1024,
1160                         overuse_penalty_msat_per_1024th: 0,
1161                 });
1162                 let source = source_node_id();
1163                 let target = target_node_id();
1164
1165                 scorer.payment_path_failed(&[], 42);
1166                 assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);
1167
1168                 let mut serialized_scorer = Vec::new();
1169                 scorer.write(&mut serialized_scorer).unwrap();
1170
1171                 SinceEpoch::advance(Duration::from_secs(10));
1172
1173                 let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
1174                 assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
1175
1176                 SinceEpoch::advance(Duration::from_secs(10));
1177                 assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
1178         }
1179
1180         #[test]
1181         fn charges_per_1024th_penalty() {
1182                 let scorer = Scorer::new(ScoringParameters {
1183                         base_penalty_msat: 0,
1184                         failure_penalty_msat: 0,
1185                         failure_penalty_half_life: Duration::from_secs(0),
1186                         overuse_penalty_start_1024th: 256,
1187                         overuse_penalty_msat_per_1024th: 100,
1188                 });
1189                 let source = source_node_id();
1190                 let target = target_node_id();
1191
1192                 assert_eq!(scorer.channel_penalty_msat(42, 1_000, 1_024_000, &source, &target), 0);
1193                 assert_eq!(scorer.channel_penalty_msat(42, 256_999, 1_024_000, &source, &target), 0);
1194                 assert_eq!(scorer.channel_penalty_msat(42, 257_000, 1_024_000, &source, &target), 100);
1195                 assert_eq!(scorer.channel_penalty_msat(42, 258_000, 1_024_000, &source, &target), 200);
1196                 assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 256 * 100);
1197         }
1198
1199         // `ProbabilisticScorer` tests
1200
1201         fn sender_privkey() -> SecretKey {
1202                 SecretKey::from_slice(&[41; 32]).unwrap()
1203         }
1204
1205         fn recipient_privkey() -> SecretKey {
1206                 SecretKey::from_slice(&[45; 32]).unwrap()
1207         }
1208
1209         fn sender_pubkey() -> PublicKey {
1210                 let secp_ctx = Secp256k1::new();
1211                 PublicKey::from_secret_key(&secp_ctx, &sender_privkey())
1212         }
1213
1214         fn recipient_pubkey() -> PublicKey {
1215                 let secp_ctx = Secp256k1::new();
1216                 PublicKey::from_secret_key(&secp_ctx, &recipient_privkey())
1217         }
1218
1219         fn sender_node_id() -> NodeId {
1220                 NodeId::from_pubkey(&sender_pubkey())
1221         }
1222
1223         fn recipient_node_id() -> NodeId {
1224                 NodeId::from_pubkey(&recipient_pubkey())
1225         }
1226
1227         fn network_graph() -> NetworkGraph {
1228                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1229                 let mut network_graph = NetworkGraph::new(genesis_hash);
1230                 add_channel(&mut network_graph, 41, sender_privkey(), source_privkey());
1231                 add_channel(&mut network_graph, 42, source_privkey(), target_privkey());
1232                 add_channel(&mut network_graph, 43, target_privkey(), recipient_privkey());
1233
1234                 network_graph
1235         }
1236
1237         fn add_channel(
1238                 network_graph: &mut NetworkGraph, short_channel_id: u64, node_1_key: SecretKey,
1239                 node_2_key: SecretKey
1240         ) {
1241                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1242                 let node_1_secret = &SecretKey::from_slice(&[39; 32]).unwrap();
1243                 let node_2_secret = &SecretKey::from_slice(&[40; 32]).unwrap();
1244                 let secp_ctx = Secp256k1::new();
1245                 let unsigned_announcement = UnsignedChannelAnnouncement {
1246                         features: ChannelFeatures::known(),
1247                         chain_hash: genesis_hash,
1248                         short_channel_id,
1249                         node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_key),
1250                         node_id_2: PublicKey::from_secret_key(&secp_ctx, &node_2_key),
1251                         bitcoin_key_1: PublicKey::from_secret_key(&secp_ctx, &node_1_secret),
1252                         bitcoin_key_2: PublicKey::from_secret_key(&secp_ctx, &node_2_secret),
1253                         excess_data: Vec::new(),
1254                 };
1255                 let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_announcement.encode()[..])[..]);
1256                 let signed_announcement = ChannelAnnouncement {
1257                         node_signature_1: secp_ctx.sign(&msghash, &node_1_key),
1258                         node_signature_2: secp_ctx.sign(&msghash, &node_2_key),
1259                         bitcoin_signature_1: secp_ctx.sign(&msghash, &node_1_secret),
1260                         bitcoin_signature_2: secp_ctx.sign(&msghash, &node_2_secret),
1261                         contents: unsigned_announcement,
1262                 };
1263                 let chain_source: Option<&::util::test_utils::TestChainSource> = None;
1264                 network_graph.update_channel_from_announcement(
1265                         &signed_announcement, &chain_source, &secp_ctx).unwrap();
1266                 update_channel(network_graph, short_channel_id, node_1_key, 0);
1267                 update_channel(network_graph, short_channel_id, node_2_key, 1);
1268         }
1269
1270         fn update_channel(
1271                 network_graph: &mut NetworkGraph, short_channel_id: u64, node_key: SecretKey, flags: u8
1272         ) {
1273                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1274                 let secp_ctx = Secp256k1::new();
1275                 let unsigned_update = UnsignedChannelUpdate {
1276                         chain_hash: genesis_hash,
1277                         short_channel_id,
1278                         timestamp: 100,
1279                         flags,
1280                         cltv_expiry_delta: 18,
1281                         htlc_minimum_msat: 0,
1282                         htlc_maximum_msat: OptionalField::Present(1_000),
1283                         fee_base_msat: 1,
1284                         fee_proportional_millionths: 0,
1285                         excess_data: Vec::new(),
1286                 };
1287                 let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_update.encode()[..])[..]);
1288                 let signed_update = ChannelUpdate {
1289                         signature: secp_ctx.sign(&msghash, &node_key),
1290                         contents: unsigned_update,
1291                 };
1292                 network_graph.update_channel(&signed_update, &secp_ctx).unwrap();
1293         }
1294
1295         fn payment_path_for_amount(amount_msat: u64) -> Vec<RouteHop> {
1296                 vec![
1297                         RouteHop {
1298                                 pubkey: source_pubkey(),
1299                                 node_features: NodeFeatures::known(),
1300                                 short_channel_id: 41,
1301                                 channel_features: ChannelFeatures::known(),
1302                                 fee_msat: 1,
1303                                 cltv_expiry_delta: 18,
1304                         },
1305                         RouteHop {
1306                                 pubkey: target_pubkey(),
1307                                 node_features: NodeFeatures::known(),
1308                                 short_channel_id: 42,
1309                                 channel_features: ChannelFeatures::known(),
1310                                 fee_msat: 2,
1311                                 cltv_expiry_delta: 18,
1312                         },
1313                         RouteHop {
1314                                 pubkey: recipient_pubkey(),
1315                                 node_features: NodeFeatures::known(),
1316                                 short_channel_id: 43,
1317                                 channel_features: ChannelFeatures::known(),
1318                                 fee_msat: amount_msat,
1319                                 cltv_expiry_delta: 18,
1320                         },
1321                 ]
1322         }
1323
1324         #[test]
1325         fn liquidity_bounds_directed_from_lowest_node_id() {
1326                 let network_graph = network_graph();
1327                 let params = ProbabilisticScoringParameters::default();
1328                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1329                         .with_channel(42,
1330                                 ChannelLiquidity {
1331                                         min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100
1332                                 })
1333                         .with_channel(43,
1334                                 ChannelLiquidity {
1335                                         min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100
1336                                 });
1337                 let source = source_node_id();
1338                 let target = target_node_id();
1339                 let recipient = recipient_node_id();
1340
1341                 let liquidity = scorer.channel_liquidities.get_mut(&42).unwrap();
1342                 assert!(source > target);
1343                 assert_eq!(liquidity.as_directed(&source, &target, 1_000).min_liquidity_msat(), 100);
1344                 assert_eq!(liquidity.as_directed(&source, &target, 1_000).max_liquidity_msat(), 300);
1345                 assert_eq!(liquidity.as_directed(&target, &source, 1_000).min_liquidity_msat(), 700);
1346                 assert_eq!(liquidity.as_directed(&target, &source, 1_000).max_liquidity_msat(), 900);
1347
1348                 liquidity.as_directed_mut(&source, &target, 1_000).set_min_liquidity_msat(200);
1349                 assert_eq!(liquidity.as_directed(&source, &target, 1_000).min_liquidity_msat(), 200);
1350                 assert_eq!(liquidity.as_directed(&source, &target, 1_000).max_liquidity_msat(), 300);
1351                 assert_eq!(liquidity.as_directed(&target, &source, 1_000).min_liquidity_msat(), 700);
1352                 assert_eq!(liquidity.as_directed(&target, &source, 1_000).max_liquidity_msat(), 800);
1353
1354                 let liquidity = scorer.channel_liquidities.get_mut(&43).unwrap();
1355                 assert!(target < recipient);
1356                 assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).min_liquidity_msat(), 700);
1357                 assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).max_liquidity_msat(), 900);
1358                 assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).min_liquidity_msat(), 100);
1359                 assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).max_liquidity_msat(), 300);
1360
1361                 liquidity.as_directed_mut(&target, &recipient, 1_000).set_max_liquidity_msat(200);
1362                 assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).min_liquidity_msat(), 0);
1363                 assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).max_liquidity_msat(), 200);
1364                 assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).min_liquidity_msat(), 800);
1365                 assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).max_liquidity_msat(), 1000);
1366         }
1367
1368         #[test]
1369         fn resets_liquidity_upper_bound_when_crossed_by_lower_bound() {
1370                 let network_graph = network_graph();
1371                 let params = ProbabilisticScoringParameters::default();
1372                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1373                         .with_channel(42,
1374                                 ChannelLiquidity {
1375                                         min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
1376                                 });
1377                 let source = source_node_id();
1378                 let target = target_node_id();
1379                 assert!(source > target);
1380
1381                 // Check initial bounds.
1382                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1383                         .as_directed(&source, &target, 1_000);
1384                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1385                 assert_eq!(liquidity.max_liquidity_msat(), 800);
1386
1387                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1388                         .as_directed(&target, &source, 1_000);
1389                 assert_eq!(liquidity.min_liquidity_msat(), 200);
1390                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1391
1392                 // Reset from source to target.
1393                 scorer.channel_liquidities.get_mut(&42).unwrap()
1394                         .as_directed_mut(&source, &target, 1_000)
1395                         .set_min_liquidity_msat(900);
1396
1397                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1398                         .as_directed(&source, &target, 1_000);
1399                 assert_eq!(liquidity.min_liquidity_msat(), 900);
1400                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1401
1402                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1403                         .as_directed(&target, &source, 1_000);
1404                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1405                 assert_eq!(liquidity.max_liquidity_msat(), 100);
1406
1407                 // Reset from target to source.
1408                 scorer.channel_liquidities.get_mut(&42).unwrap()
1409                         .as_directed_mut(&target, &source, 1_000)
1410                         .set_min_liquidity_msat(400);
1411
1412                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1413                         .as_directed(&source, &target, 1_000);
1414                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1415                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1416
1417                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1418                         .as_directed(&target, &source, 1_000);
1419                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1420                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1421         }
1422
1423         #[test]
1424         fn resets_liquidity_lower_bound_when_crossed_by_upper_bound() {
1425                 let network_graph = network_graph();
1426                 let params = ProbabilisticScoringParameters::default();
1427                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1428                         .with_channel(42,
1429                                 ChannelLiquidity {
1430                                         min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
1431                                 });
1432                 let source = source_node_id();
1433                 let target = target_node_id();
1434                 assert!(source > target);
1435
1436                 // Check initial bounds.
1437                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1438                         .as_directed(&source, &target, 1_000);
1439                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1440                 assert_eq!(liquidity.max_liquidity_msat(), 800);
1441
1442                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1443                         .as_directed(&target, &source, 1_000);
1444                 assert_eq!(liquidity.min_liquidity_msat(), 200);
1445                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1446
1447                 // Reset from source to target.
1448                 scorer.channel_liquidities.get_mut(&42).unwrap()
1449                         .as_directed_mut(&source, &target, 1_000)
1450                         .set_max_liquidity_msat(300);
1451
1452                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1453                         .as_directed(&source, &target, 1_000);
1454                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1455                 assert_eq!(liquidity.max_liquidity_msat(), 300);
1456
1457                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1458                         .as_directed(&target, &source, 1_000);
1459                 assert_eq!(liquidity.min_liquidity_msat(), 700);
1460                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1461
1462                 // Reset from target to source.
1463                 scorer.channel_liquidities.get_mut(&42).unwrap()
1464                         .as_directed_mut(&target, &source, 1_000)
1465                         .set_max_liquidity_msat(600);
1466
1467                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1468                         .as_directed(&source, &target, 1_000);
1469                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1470                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1471
1472                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1473                         .as_directed(&target, &source, 1_000);
1474                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1475                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1476         }
1477
1478         #[test]
1479         fn increased_penalty_nearing_liquidity_upper_bound() {
1480                 let network_graph = network_graph();
1481                 let params = ProbabilisticScoringParameters::default();
1482                 let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1483                 let source = source_node_id();
1484                 let target = target_node_id();
1485
1486                 assert_eq!(scorer.channel_penalty_msat(42, 100, 100_000, &source, &target), 0);
1487                 assert_eq!(scorer.channel_penalty_msat(42, 1_000, 100_000, &source, &target), 4);
1488                 assert_eq!(scorer.channel_penalty_msat(42, 10_000, 100_000, &source, &target), 45);
1489                 assert_eq!(scorer.channel_penalty_msat(42, 100_000, 100_000, &source, &target), 2_000);
1490
1491                 assert_eq!(scorer.channel_penalty_msat(42, 125, 1_000, &source, &target), 57);
1492                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1493                 assert_eq!(scorer.channel_penalty_msat(42, 375, 1_000, &source, &target), 203);
1494                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
1495                 assert_eq!(scorer.channel_penalty_msat(42, 625, 1_000, &source, &target), 425);
1496                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 600);
1497                 assert_eq!(scorer.channel_penalty_msat(42, 875, 1_000, &source, &target), 900);
1498         }
1499
1500         #[test]
1501         fn constant_penalty_outside_liquidity_bounds() {
1502                 let network_graph = network_graph();
1503                 let params = ProbabilisticScoringParameters::default();
1504                 let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1505                         .with_channel(42,
1506                                 ChannelLiquidity { min_liquidity_offset_msat: 40, max_liquidity_offset_msat: 40 });
1507                 let source = source_node_id();
1508                 let target = target_node_id();
1509
1510                 assert_eq!(scorer.channel_penalty_msat(42, 39, 100, &source, &target), 0);
1511                 assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), 0);
1512                 assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), 2_000);
1513                 assert_eq!(scorer.channel_penalty_msat(42, 61, 100, &source, &target), 2_000);
1514         }
1515
1516         #[test]
1517         fn does_not_penalize_own_channel() {
1518                 let network_graph = network_graph();
1519                 let params = ProbabilisticScoringParameters::default();
1520                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1521                 let sender = sender_node_id();
1522                 let source = source_node_id();
1523                 let failed_path = payment_path_for_amount(500);
1524                 let successful_path = payment_path_for_amount(200);
1525
1526                 assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
1527
1528                 scorer.payment_path_failed(&failed_path.iter().collect::<Vec<_>>(), 41);
1529                 assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
1530
1531                 scorer.payment_path_successful(&successful_path.iter().collect::<Vec<_>>());
1532                 assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
1533         }
1534
1535         #[test]
1536         fn sets_liquidity_lower_bound_on_downstream_failure() {
1537                 let network_graph = network_graph();
1538                 let params = ProbabilisticScoringParameters::default();
1539                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1540                 let source = source_node_id();
1541                 let target = target_node_id();
1542                 let path = payment_path_for_amount(500);
1543
1544                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1545                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
1546                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 600);
1547
1548                 scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 43);
1549
1550                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 0);
1551                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 0);
1552                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 300);
1553         }
1554
1555         #[test]
1556         fn sets_liquidity_upper_bound_on_failure() {
1557                 let network_graph = network_graph();
1558                 let params = ProbabilisticScoringParameters::default();
1559                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1560                 let source = source_node_id();
1561                 let target = target_node_id();
1562                 let path = payment_path_for_amount(500);
1563
1564                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1565                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
1566                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 600);
1567
1568                 scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 42);
1569
1570                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 300);
1571                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 2_000);
1572                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 2_000);
1573         }
1574
1575         #[test]
1576         fn reduces_liquidity_upper_bound_along_path_on_success() {
1577                 let network_graph = network_graph();
1578                 let params = ProbabilisticScoringParameters::default();
1579                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1580                 let sender = sender_node_id();
1581                 let source = source_node_id();
1582                 let target = target_node_id();
1583                 let recipient = recipient_node_id();
1584                 let path = payment_path_for_amount(500);
1585
1586                 assert_eq!(scorer.channel_penalty_msat(41, 250, 1_000, &sender, &source), 0);
1587                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1588                 assert_eq!(scorer.channel_penalty_msat(43, 250, 1_000, &target, &recipient), 124);
1589
1590                 scorer.payment_path_successful(&path.iter().collect::<Vec<_>>());
1591
1592                 assert_eq!(scorer.channel_penalty_msat(41, 250, 1_000, &sender, &source), 0);
1593                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 300);
1594                 assert_eq!(scorer.channel_penalty_msat(43, 250, 1_000, &target, &recipient), 300);
1595         }
1596 }