cf47ab853b93589561e7e362b4791747d613f703
[rust-lightning] / lightning / src / routing / scoring.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Utilities for scoring payment channels.
11 //!
12 //! [`ProbabilisticScorer`] may be given to [`find_route`] to score payment channels during path
13 //! finding when a custom [`Score`] implementation is not needed.
14 //!
15 //! # Example
16 //!
17 //! ```
18 //! # extern crate secp256k1;
19 //! #
20 //! # use lightning::routing::network_graph::NetworkGraph;
21 //! # use lightning::routing::router::{RouteParameters, find_route};
22 //! # use lightning::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters, Scorer, ScoringParameters};
23 //! # use lightning::util::logger::{Logger, Record};
24 //! # use secp256k1::key::PublicKey;
25 //! #
26 //! # struct FakeLogger {};
27 //! # impl Logger for FakeLogger {
28 //! #     fn log(&self, record: &Record) { unimplemented!() }
29 //! # }
30 //! # fn find_scored_route(payer: PublicKey, route_params: RouteParameters, network_graph: NetworkGraph) {
31 //! # let logger = FakeLogger {};
32 //! #
33 //! // Use the default channel penalties.
34 //! let params = ProbabilisticScoringParameters::default();
35 //! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
36 //!
37 //! // Or use custom channel penalties.
38 //! let params = ProbabilisticScoringParameters {
39 //!     liquidity_penalty_multiplier_msat: 2 * 1000,
40 //!     ..ProbabilisticScoringParameters::default()
41 //! };
42 //! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
43 //!
44 //! let route = find_route(&payer, &route_params, &network_graph, None, &logger, &scorer);
45 //! # }
46 //! ```
47 //!
48 //! # Note
49 //!
50 //! Persisting when built with feature `no-std` and restoring without it, or vice versa, uses
51 //! different types and thus is undefined.
52 //!
53 //! [`find_route`]: crate::routing::router::find_route
54
55 use bitcoin::secp256k1::key::PublicKey;
56
57 use ln::msgs::DecodeError;
58 use routing::network_graph::{EffectiveCapacity, NetworkGraph, NodeId};
59 use routing::router::RouteHop;
60 use util::ser::{Readable, ReadableArgs, Writeable, Writer};
61
62 use prelude::*;
63 use core::cell::{RefCell, RefMut};
64 use core::ops::{Deref, DerefMut};
65 use core::time::Duration;
66 use io::{self, Read};
67 use sync::{Mutex, MutexGuard};
68
69 /// We define Score ever-so-slightly differently based on whether we are being built for C bindings
70 /// or not. For users, `LockableScore` must somehow be writeable to disk. For Rust users, this is
71 /// no problem - you move a `Score` that implements `Writeable` into a `Mutex`, lock it, and now
72 /// you have the original, concrete, `Score` type, which presumably implements `Writeable`.
73 ///
74 /// For C users, once you've moved the `Score` into a `LockableScore` all you have after locking it
75 /// is an opaque trait object with an opaque pointer with no type info. Users could take the unsafe
76 /// approach of blindly casting that opaque pointer to a concrete type and calling `Writeable` from
77 /// there, but other languages downstream of the C bindings (e.g. Java) can't even do that.
78 /// Instead, we really want `Score` and `LockableScore` to implement `Writeable` directly, which we
79 /// do here by defining `Score` differently for `cfg(c_bindings)`.
80 macro_rules! define_score { ($($supertrait: path)*) => {
81 /// An interface used to score payment channels for path finding.
82 ///
83 ///     Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
84 pub trait Score $(: $supertrait)* {
85         /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the
86         /// given channel in the direction from `source` to `target`.
87         ///
88         /// The channel's capacity (less any other MPP parts that are also being considered for use in
89         /// the same payment) is given by `capacity_msat`. It may be determined from various sources
90         /// such as a chain data, network gossip, or invoice hints. For invoice hints, a capacity near
91         /// [`u64::max_value`] is given to indicate sufficient capacity for the invoice's full amount.
92         /// Thus, implementations should be overflow-safe.
93         fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64;
94
95         /// Handles updating channel penalties after failing to route through a channel.
96         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64);
97
98         /// Handles updating channel penalties after successfully routing along a path.
99         fn payment_path_successful(&mut self, path: &[&RouteHop]);
100 }
101
102 impl<S: Score, T: DerefMut<Target=S> $(+ $supertrait)*> Score for T {
103         fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, source: &NodeId, target: &NodeId) -> u64 {
104                 self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, capacity_msat, source, target)
105         }
106
107         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
108                 self.deref_mut().payment_path_failed(path, short_channel_id)
109         }
110
111         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
112                 self.deref_mut().payment_path_successful(path)
113         }
114 }
115 } }
116
117 #[cfg(c_bindings)]
118 define_score!(Writeable);
119 #[cfg(not(c_bindings))]
120 define_score!();
121
122 /// A scorer that is accessed under a lock.
123 ///
124 /// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
125 /// having shared ownership of a scorer but without requiring internal locking in [`Score`]
126 /// implementations. Internal locking would be detrimental to route finding performance and could
127 /// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
128 ///
129 /// [`find_route`]: crate::routing::router::find_route
130 pub trait LockableScore<'a> {
131         /// The locked [`Score`] type.
132         type Locked: 'a + Score;
133
134         /// Returns the locked scorer.
135         fn lock(&'a self) -> Self::Locked;
136 }
137
138 /// (C-not exported)
139 impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
140         type Locked = MutexGuard<'a, T>;
141
142         fn lock(&'a self) -> MutexGuard<'a, T> {
143                 Mutex::lock(self).unwrap()
144         }
145 }
146
147 impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
148         type Locked = RefMut<'a, T>;
149
150         fn lock(&'a self) -> RefMut<'a, T> {
151                 self.borrow_mut()
152         }
153 }
154
#[cfg(c_bindings)]
/// A concrete implementation of [`LockableScore`] which supports multi-threading by keeping the
/// underlying [`Score`] behind a mutex.
pub struct MultiThreadedLockableScore<S: Score> {
	score: Mutex<S>,
}
#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, T: Score + 'a> LockableScore<'a> for MultiThreadedLockableScore<T> {
	type Locked = MutexGuard<'a, T>;

	fn lock(&'a self) -> MutexGuard<'a, T> {
		let locked = Mutex::lock(&self.score);
		locked.unwrap()
	}
}

#[cfg(c_bindings)]
impl<T: Score> MultiThreadedLockableScore<T> {
	/// Creates a new [`MultiThreadedLockableScore`] given an underlying [`Score`].
	pub fn new(score: T) -> Self {
		let score = Mutex::new(score);
		Self { score }
	}
}
177
#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, T: Writeable> Writeable for RefMut<'a, T> {
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		// Serialize the borrowed value itself, not the guard around it.
		<T as Writeable>::write(&**self, writer)
	}
}

#[cfg(c_bindings)]
/// (C-not exported)
impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> {
	fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
		// Serialize the locked value itself, not the guard around it.
		<S as Writeable>::write(&**self, writer)
	}
}
193
194 /// [`Score`] implementation that provides reasonable default behavior.
195 ///
196 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
197 /// slightly higher fees are available. Will further penalize channels that fail to relay payments.
198 ///
199 /// See [module-level documentation] for usage.
200 ///
201 /// [module-level documentation]: crate::routing::scoring
202 #[cfg(not(feature = "no-std"))]
203 pub type Scorer = ScorerUsingTime::<std::time::Instant>;
204 /// [`Score`] implementation that provides reasonable default behavior.
205 ///
206 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
207 /// slightly higher fees are available. Will further penalize channels that fail to relay payments.
208 ///
209 /// See [module-level documentation] for usage and [`ScoringParameters`] for customization.
210 ///
211 /// [module-level documentation]: crate::routing::scoring
212 #[cfg(feature = "no-std")]
213 pub type Scorer = ScorerUsingTime::<time::Eternity>;
214
215 // Note that ideally we'd hide ScorerUsingTime from public view by sealing it as well, but rustdoc
216 // doesn't handle this well - instead exposing a `Scorer` which has no trait implementation(s) or
217 // methods at all.
218
219 /// [`Score`] implementation.
220 ///
221 /// See [`Scorer`] for details.
222 ///
223 /// # Note
224 ///
225 /// Mixing the `no-std` feature between serialization and deserialization results in undefined
226 /// behavior.
227 ///
228 /// (C-not exported) generally all users should use the [`Scorer`] type alias.
229 pub struct ScorerUsingTime<T: Time> {
230         params: ScoringParameters,
231         // TODO: Remove entries of closed channels.
232         channel_failures: HashMap<u64, ChannelFailure<T>>,
233 }
234
/// Parameters for configuring [`Scorer`].
pub struct ScoringParameters {
	/// A fixed penalty in msats to apply to each channel.
	///
	/// Default value: 500 msat
	pub base_penalty_msat: u64,

	/// A penalty in msats to apply to a channel upon failing to relay a payment.
	///
	/// This accumulates for each failure but may be reduced over time based on
	/// [`failure_penalty_half_life`] or when successfully routing through a channel.
	///
	/// Default value: 1,024,000 msat
	///
	/// [`failure_penalty_half_life`]: Self::failure_penalty_half_life
	pub failure_penalty_msat: u64,

	/// When the amount being sent over a channel exceeds this many 1024ths of the total channel
	/// capacity, we begin applying [`overuse_penalty_msat_per_1024th`].
	///
	/// Default value: 128 1024ths (i.e. begin penalizing when an HTLC uses more than 1/8th of a
	/// channel)
	///
	/// [`overuse_penalty_msat_per_1024th`]: Self::overuse_penalty_msat_per_1024th
	pub overuse_penalty_start_1024th: u16,

	/// A penalty applied for each whole 1024th of the channel capacity by which the amount being
	/// sent over the channel exceeds [`overuse_penalty_start_1024th`].
	///
	/// Default value: 20 msat (i.e. 2560 msat penalty to use 1/4th of a channel, 7680 msat penalty
	/// to use half a channel, and 12,800 msat penalty to use 3/4ths of a channel)
	///
	/// [`overuse_penalty_start_1024th`]: Self::overuse_penalty_start_1024th
	pub overuse_penalty_msat_per_1024th: u64,

	/// The time required to elapse before any accumulated [`failure_penalty_msat`] penalties are
	/// cut in half.
	///
	/// Successfully routing through a channel will immediately cut the penalty in half as well.
	///
	/// # Note
	///
	/// When built with the `no-std` feature, time will never elapse. Therefore, this penalty will
	/// never decay.
	///
	/// Decay is computed in whole seconds, so a half-life shorter than one second (including zero)
	/// causes accumulated failure penalties to be ignored entirely.
	///
	/// [`failure_penalty_msat`]: Self::failure_penalty_msat
	pub failure_penalty_half_life: Duration,
}
282
// Serialization for `ScoringParameters`. The overuse-penalty fields (types 1 and 3) declare
// `default_value`s equal to their `ScoringParameters::default()` values (128 and 20).
// NOTE(review): presumably these defaults apply when the field is absent, so that parameters
// persisted before the overuse fields existed still decode; confirm against the TLV macros.
impl_writeable_tlv_based!(ScoringParameters, {
	(0, base_penalty_msat, required),
	(1, overuse_penalty_start_1024th, (default_value, 128)),
	(2, failure_penalty_msat, required),
	(3, overuse_penalty_msat_per_1024th, (default_value, 20)),
	(4, failure_penalty_half_life, required),
});
290
291 /// Accounting for penalties against a channel for failing to relay any payments.
292 ///
293 /// Penalties decay over time, though accumulate as more failures occur.
294 struct ChannelFailure<T: Time> {
295         /// Accumulated penalty in msats for the channel as of `last_updated`.
296         undecayed_penalty_msat: u64,
297
298         /// Last time the channel either failed to route or successfully routed a payment. Used to decay
299         /// `undecayed_penalty_msat`.
300         last_updated: T,
301 }
302
303 impl<T: Time> ScorerUsingTime<T> {
304         /// Creates a new scorer using the given scoring parameters.
305         pub fn new(params: ScoringParameters) -> Self {
306                 Self {
307                         params,
308                         channel_failures: HashMap::new(),
309                 }
310         }
311
312         /// Creates a new scorer using `penalty_msat` as a fixed channel penalty.
313         #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
314         pub fn with_fixed_penalty(penalty_msat: u64) -> Self {
315                 Self::new(ScoringParameters {
316                         base_penalty_msat: penalty_msat,
317                         failure_penalty_msat: 0,
318                         failure_penalty_half_life: Duration::from_secs(0),
319                         overuse_penalty_start_1024th: 1024,
320                         overuse_penalty_msat_per_1024th: 0,
321                 })
322         }
323 }
324
325 impl<T: Time> ChannelFailure<T> {
326         fn new(failure_penalty_msat: u64) -> Self {
327                 Self {
328                         undecayed_penalty_msat: failure_penalty_msat,
329                         last_updated: T::now(),
330                 }
331         }
332
333         fn add_penalty(&mut self, failure_penalty_msat: u64, half_life: Duration) {
334                 self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) + failure_penalty_msat;
335                 self.last_updated = T::now();
336         }
337
338         fn reduce_penalty(&mut self, half_life: Duration) {
339                 self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) >> 1;
340                 self.last_updated = T::now();
341         }
342
343         fn decayed_penalty_msat(&self, half_life: Duration) -> u64 {
344                 self.last_updated.elapsed().as_secs()
345                         .checked_div(half_life.as_secs())
346                         .and_then(|decays| self.undecayed_penalty_msat.checked_shr(decays as u32))
347                         .unwrap_or(0)
348         }
349 }
350
351 impl<T: Time> Default for ScorerUsingTime<T> {
352         fn default() -> Self {
353                 Self::new(ScoringParameters::default())
354         }
355 }
356
357 impl Default for ScoringParameters {
358         fn default() -> Self {
359                 Self {
360                         base_penalty_msat: 500,
361                         failure_penalty_msat: 1024 * 1000,
362                         failure_penalty_half_life: Duration::from_secs(3600),
363                         overuse_penalty_start_1024th: 1024 / 8,
364                         overuse_penalty_msat_per_1024th: 20,
365                 }
366         }
367 }
368
369 impl<T: Time> Score for ScorerUsingTime<T> {
370         fn channel_penalty_msat(
371                 &self, short_channel_id: u64, send_amt_msat: u64, capacity_msat: u64, _source: &NodeId, _target: &NodeId
372         ) -> u64 {
373                 let failure_penalty_msat = self.channel_failures
374                         .get(&short_channel_id)
375                         .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life));
376
377                 let mut penalty_msat = self.params.base_penalty_msat + failure_penalty_msat;
378                 let send_1024ths = send_amt_msat.checked_mul(1024).unwrap_or(u64::max_value()) / capacity_msat;
379                 if send_1024ths > self.params.overuse_penalty_start_1024th as u64 {
380                         penalty_msat = penalty_msat.checked_add(
381                                         (send_1024ths - self.params.overuse_penalty_start_1024th as u64)
382                                         .checked_mul(self.params.overuse_penalty_msat_per_1024th).unwrap_or(u64::max_value()))
383                                 .unwrap_or(u64::max_value());
384                 }
385
386                 penalty_msat
387         }
388
389         fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) {
390                 let failure_penalty_msat = self.params.failure_penalty_msat;
391                 let half_life = self.params.failure_penalty_half_life;
392                 self.channel_failures
393                         .entry(short_channel_id)
394                         .and_modify(|failure| failure.add_penalty(failure_penalty_msat, half_life))
395                         .or_insert_with(|| ChannelFailure::new(failure_penalty_msat));
396         }
397
398         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
399                 let half_life = self.params.failure_penalty_half_life;
400                 for hop in path.iter() {
401                         self.channel_failures
402                                 .entry(hop.short_channel_id)
403                                 .and_modify(|failure| failure.reduce_penalty(half_life));
404                 }
405         }
406 }
407
408 impl<T: Time> Writeable for ScorerUsingTime<T> {
409         #[inline]
410         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
411                 self.params.write(w)?;
412                 self.channel_failures.write(w)?;
413                 write_tlv_fields!(w, {});
414                 Ok(())
415         }
416 }
417
418 impl<T: Time> Readable for ScorerUsingTime<T> {
419         #[inline]
420         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
421                 let res = Ok(Self {
422                         params: Readable::read(r)?,
423                         channel_failures: Readable::read(r)?,
424                 });
425                 read_tlv_fields!(r, {});
426                 res
427         }
428 }
429
430 impl<T: Time> Writeable for ChannelFailure<T> {
431         #[inline]
432         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
433                 let duration_since_epoch = T::duration_since_epoch() - self.last_updated.elapsed();
434                 write_tlv_fields!(w, {
435                         (0, self.undecayed_penalty_msat, required),
436                         (2, duration_since_epoch, required),
437                 });
438                 Ok(())
439         }
440 }
441
442 impl<T: Time> Readable for ChannelFailure<T> {
443         #[inline]
444         fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
445                 let mut undecayed_penalty_msat = 0;
446                 let mut duration_since_epoch = Duration::from_secs(0);
447                 read_tlv_fields!(r, {
448                         (0, undecayed_penalty_msat, required),
449                         (2, duration_since_epoch, required),
450                 });
451                 Ok(Self {
452                         undecayed_penalty_msat,
453                         last_updated: T::now() - (T::duration_since_epoch() - duration_since_epoch),
454                 })
455         }
456 }
457
458 /// [`Score`] implementation using channel success probability distributions.
459 ///
460 /// Based on *Optimally Reliable & Cheap Payment Flows on the Lightning Network* by Rene Pickhardt
461 /// and Stefan Richter [[1]]. Given the uncertainty of channel liquidity balances, probability
462 /// distributions are defined based on knowledge learned from successful and unsuccessful attempts.
463 /// Then the negative log of the success probability is used to determine the cost of routing a
464 /// specific HTLC amount through a channel.
465 ///
466 /// [1]: https://arxiv.org/abs/2107.05322
467 pub struct ProbabilisticScorer<G: Deref<Target = NetworkGraph>> {
468         params: ProbabilisticScoringParameters,
469         node_id: NodeId,
470         network_graph: G,
471         // TODO: Remove entries of closed channels.
472         channel_liquidities: HashMap<u64, ChannelLiquidity>,
473 }
474
/// Parameters for configuring [`ProbabilisticScorer`].
pub struct ProbabilisticScoringParameters {
	/// Multiplier, in msats, applied to the negative base-10 log of a channel's success
	/// probability for a payment to obtain that channel's penalty.
	///
	/// The success probability is derived from the effective channel capacity, the payment amount,
	/// and what was learned from earlier successful and unsuccessful payments. It is never taken to
	/// be lower than 0.01, which effectively keeps the penalty within
	/// `0..=2*liquidity_penalty_multiplier_msat`.
	///
	/// Default value: 1,000 msat
	pub liquidity_penalty_multiplier_msat: u64,
}
488
// Serialization for `ProbabilisticScoringParameters`: a single required TLV field (type 0).
impl_writeable_tlv_based!(ProbabilisticScoringParameters, {
	(0, liquidity_penalty_multiplier_msat, required),
});
492
/// Accounting for channel liquidity balance uncertainty.
///
/// Direction is defined in terms of [`NodeId`] partial ordering: the source node is the first of
/// the channel's two counterparties in that ordering. Swapping the two offset fields therefore
/// yields the opposite direction.
struct ChannelLiquidity {
	/// Lower bound of the channel's liquidity, expressed as an offset from zero.
	min_liquidity_offset_msat: u64,

	/// Upper bound of the channel's liquidity, expressed as an offset below the effective capacity.
	max_liquidity_offset_msat: u64,
}
505
/// A [`ChannelLiquidity`] seen from one direction, paired with an assumed channel capacity.
///
/// `L` is `&u64` for read-only views and `&mut u64` for views that update the bounds.
struct DirectedChannelLiquidity<L: Deref<Target = u64>> {
	/// Lower liquidity bound in this direction, as an offset from zero.
	min_liquidity_offset_msat: L,
	/// Upper liquidity bound in this direction, as an offset below `capacity_msat`.
	max_liquidity_offset_msat: L,
	/// Capacity that the offsets are measured against.
	capacity_msat: u64,
}
512
513 impl<G: Deref<Target = NetworkGraph>> ProbabilisticScorer<G> {
514         /// Creates a new scorer using the given scoring parameters for sending payments from a node
515         /// through a network graph.
516         pub fn new(
517                 params: ProbabilisticScoringParameters, node_pubkey: &PublicKey, network_graph: G
518         ) -> Self {
519                 Self {
520                         params,
521                         node_id: NodeId::from_pubkey(node_pubkey),
522                         network_graph,
523                         channel_liquidities: HashMap::new(),
524                 }
525         }
526
527         #[cfg(test)]
528         fn with_channel(mut self, short_channel_id: u64, liquidity: ChannelLiquidity) -> Self {
529                 assert!(self.channel_liquidities.insert(short_channel_id, liquidity).is_none());
530                 self
531         }
532 }
533
534 impl Default for ProbabilisticScoringParameters {
535         fn default() -> Self {
536                 Self {
537                         liquidity_penalty_multiplier_msat: 1000,
538                 }
539         }
540 }
541
542 impl ChannelLiquidity {
543         #[inline]
544         fn new() -> Self {
545                 Self {
546                         min_liquidity_offset_msat: 0,
547                         max_liquidity_offset_msat: 0,
548                 }
549         }
550
551         /// Returns a view of the channel liquidity directed from `source` to `target` assuming
552         /// `capacity_msat`.
553         fn as_directed(
554                 &self, source: &NodeId, target: &NodeId, capacity_msat: u64
555         ) -> DirectedChannelLiquidity<&u64> {
556                 let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target {
557                         (&self.min_liquidity_offset_msat, &self.max_liquidity_offset_msat)
558                 } else {
559                         (&self.max_liquidity_offset_msat, &self.min_liquidity_offset_msat)
560                 };
561
562                 DirectedChannelLiquidity {
563                         min_liquidity_offset_msat,
564                         max_liquidity_offset_msat,
565                         capacity_msat,
566                 }
567         }
568
569         /// Returns a mutable view of the channel liquidity directed from `source` to `target` assuming
570         /// `capacity_msat`.
571         fn as_directed_mut(
572                 &mut self, source: &NodeId, target: &NodeId, capacity_msat: u64
573         ) -> DirectedChannelLiquidity<&mut u64> {
574                 let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target {
575                         (&mut self.min_liquidity_offset_msat, &mut self.max_liquidity_offset_msat)
576                 } else {
577                         (&mut self.max_liquidity_offset_msat, &mut self.min_liquidity_offset_msat)
578                 };
579
580                 DirectedChannelLiquidity {
581                         min_liquidity_offset_msat,
582                         max_liquidity_offset_msat,
583                         capacity_msat,
584                 }
585         }
586 }
587
588 impl<L: Deref<Target = u64>> DirectedChannelLiquidity<L> {
589         /// Returns the success probability of routing the given HTLC `amount_msat` through the channel
590         /// in this direction.
591         fn success_probability(&self, amount_msat: u64) -> f64 {
592                 let max_liquidity_msat = self.max_liquidity_msat();
593                 let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat);
594                 if amount_msat > max_liquidity_msat {
595                         0.0
596                 } else if amount_msat <= min_liquidity_msat {
597                         1.0
598                 } else {
599                         let numerator = max_liquidity_msat + 1 - amount_msat;
600                         let denominator = max_liquidity_msat + 1 - min_liquidity_msat;
601                         numerator as f64 / denominator as f64
602                 }.max(0.01) // Lower bound the success probability to ensure some channel is selected.
603         }
604
605         /// Returns the lower bound of the channel liquidity balance in this direction.
606         fn min_liquidity_msat(&self) -> u64 {
607                 *self.min_liquidity_offset_msat
608         }
609
610         /// Returns the upper bound of the channel liquidity balance in this direction.
611         fn max_liquidity_msat(&self) -> u64 {
612                 self.capacity_msat.checked_sub(*self.max_liquidity_offset_msat).unwrap_or(0)
613         }
614 }
615
616 impl<L: DerefMut<Target = u64>> DirectedChannelLiquidity<L> {
617         /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat`.
618         fn failed_at_channel(&mut self, amount_msat: u64) {
619                 if amount_msat < self.max_liquidity_msat() {
620                         self.set_max_liquidity_msat(amount_msat);
621                 }
622         }
623
624         /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat` downstream.
625         fn failed_downstream(&mut self, amount_msat: u64) {
626                 if amount_msat > self.min_liquidity_msat() {
627                         self.set_min_liquidity_msat(amount_msat);
628                 }
629         }
630
631         /// Adjusts the channel liquidity balance bounds when successfully routing `amount_msat`.
632         fn successful(&mut self, amount_msat: u64) {
633                 let max_liquidity_msat = self.max_liquidity_msat().checked_sub(amount_msat).unwrap_or(0);
634                 self.set_max_liquidity_msat(max_liquidity_msat);
635         }
636
637         /// Adjusts the lower bound of the channel liquidity balance in this direction.
638         fn set_min_liquidity_msat(&mut self, amount_msat: u64) {
639                 *self.min_liquidity_offset_msat = amount_msat;
640
641                 if amount_msat > self.max_liquidity_msat() {
642                         *self.max_liquidity_offset_msat = 0;
643                 }
644         }
645
646         /// Adjusts the upper bound of the channel liquidity balance in this direction.
647         fn set_max_liquidity_msat(&mut self, amount_msat: u64) {
648                 *self.max_liquidity_offset_msat = self.capacity_msat.checked_sub(amount_msat).unwrap_or(0);
649
650                 if amount_msat < self.min_liquidity_msat() {
651                         *self.min_liquidity_offset_msat = 0;
652                 }
653         }
654 }
655
656 impl<G: Deref<Target = NetworkGraph>> Score for ProbabilisticScorer<G> {
657         #[allow(clippy::float_cmp)]
658         fn channel_penalty_msat(
659                 &self, short_channel_id: u64, amount_msat: u64, capacity_msat: u64, source: &NodeId,
660                 target: &NodeId
661         ) -> u64 {
662                 if *source == self.node_id || *target == self.node_id {
663                         return 0;
664                 }
665
666                 let liquidity_penalty_multiplier_msat = self.params.liquidity_penalty_multiplier_msat;
667                 let success_probability = self.channel_liquidities
668                         .get(&short_channel_id)
669                         .unwrap_or(&ChannelLiquidity::new())
670                         .as_directed(source, target, capacity_msat)
671                         .success_probability(amount_msat);
672                 if success_probability == 0.0 {
673                         u64::max_value()
674                 } else if success_probability == 1.0 {
675                         0
676                 } else {
677                         (-(success_probability.log10()) * liquidity_penalty_multiplier_msat as f64) as u64
678                 }
679         }
680
681         fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
682                 let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
683                 let network_graph = self.network_graph.read_only();
684                 let hop_sources = core::iter::once(self.node_id)
685                         .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey)));
686                 for (source, hop) in hop_sources.zip(path.iter()) {
687                         let target = NodeId::from_pubkey(&hop.pubkey);
688                         if source == self.node_id || target == self.node_id {
689                                 continue;
690                         }
691
692                         let capacity_msat = network_graph.channels()
693                                 .get(&hop.short_channel_id)
694                                 .and_then(|channel| channel.as_directed_to(&target).map(|(d, _)| d.effective_capacity()))
695                                 .unwrap_or(EffectiveCapacity::Unknown)
696                                 .as_msat();
697
698                         if hop.short_channel_id == short_channel_id {
699                                 self.channel_liquidities
700                                         .entry(hop.short_channel_id)
701                                         .or_insert_with(ChannelLiquidity::new)
702                                         .as_directed_mut(&source, &target, capacity_msat)
703                                         .failed_at_channel(amount_msat);
704                                 break;
705                         }
706
707                         self.channel_liquidities
708                                 .entry(hop.short_channel_id)
709                                 .or_insert_with(ChannelLiquidity::new)
710                                 .as_directed_mut(&source, &target, capacity_msat)
711                                 .failed_downstream(amount_msat);
712                 }
713         }
714
715         fn payment_path_successful(&mut self, path: &[&RouteHop]) {
716                 let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
717                 let network_graph = self.network_graph.read_only();
718                 let hop_sources = core::iter::once(self.node_id)
719                         .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey)));
720                 for (source, hop) in hop_sources.zip(path.iter()) {
721                         let target = NodeId::from_pubkey(&hop.pubkey);
722                         if source == self.node_id || target == self.node_id {
723                                 continue;
724                         }
725
726                         let capacity_msat = network_graph.channels()
727                                 .get(&hop.short_channel_id)
728                                 .and_then(|channel| channel.as_directed_to(&target).map(|(d, _)| d.effective_capacity()))
729                                 .unwrap_or(EffectiveCapacity::Unknown)
730                                 .as_msat();
731
732                         self.channel_liquidities
733                                 .entry(hop.short_channel_id)
734                                 .or_insert_with(ChannelLiquidity::new)
735                                 .as_directed_mut(&source, &target, capacity_msat)
736                                 .successful(amount_msat);
737                 }
738         }
739 }
740
741 impl<G: Deref<Target = NetworkGraph>> Writeable for ProbabilisticScorer<G> {
742         #[inline]
743         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
744                 self.channel_liquidities.write(w)?;
745                 write_tlv_fields!(w, {});
746                 Ok(())
747         }
748 }
749
750 impl<G: Deref<Target = NetworkGraph>> ReadableArgs<(ProbabilisticScoringParameters, &PublicKey, G)>
751 for ProbabilisticScorer<G> {
752         #[inline]
753         fn read<R: Read>(
754                 r: &mut R, args: (ProbabilisticScoringParameters, &PublicKey, G)
755         ) -> Result<Self, DecodeError> {
756                 let (params, node_pubkey, network_graph) = args;
757                 let res = Ok(Self {
758                         params,
759                         node_id: NodeId::from_pubkey(node_pubkey),
760                         network_graph,
761                         channel_liquidities: Readable::read(r)?,
762                 });
763                 read_tlv_fields!(r, {});
764                 res
765         }
766 }
767
impl Writeable for ChannelLiquidity {
        /// Writes both liquidity offsets as a TLV stream.
        ///
        /// TLV type 0 holds `min_liquidity_offset_msat` and type 2 holds `max_liquidity_offset_msat`.
        /// Both are marked `required`, matching the `Readable` impl for `ChannelLiquidity`.
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                write_tlv_fields!(w, {
                        (0, self.min_liquidity_offset_msat, required),
                        (2, self.max_liquidity_offset_msat, required),
                });
                Ok(())
        }
}
778
impl Readable for ChannelLiquidity {
        /// Reads the TLV stream produced by the `Writeable` impl for `ChannelLiquidity`.
        ///
        /// TLV type 0 is `min_liquidity_offset_msat` and type 2 is `max_liquidity_offset_msat`.
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                // Placeholders that `read_tlv_fields!` overwrites. Both fields are marked `required`,
                // so presumably a missing field yields a `DecodeError` rather than leaving these at
                // zero (TODO confirm against the macro definition).
                let mut min_liquidity_offset_msat = 0;
                let mut max_liquidity_offset_msat = 0;
                read_tlv_fields!(r, {
                        (0, min_liquidity_offset_msat, required),
                        (2, max_liquidity_offset_msat, required),
                });
                Ok(Self {
                        min_liquidity_offset_msat,
                        max_liquidity_offset_msat
                })
        }
}
794
pub(crate) mod time {
        use core::ops::Sub;
        use core::time::Duration;

        /// A measurement of time.
        ///
        /// Any implementor can be moved backwards by subtracting a [`Duration`], which yields another
        /// value of the same type.
        pub trait Time: Sub<Duration, Output = Self> where Self: Sized {
                /// Returns an instance corresponding to the current moment.
                fn now() -> Self;

                /// Returns the amount of time elapsed since `self` was created.
                fn elapsed(&self) -> Duration;

                /// Returns the time elapsed between a fixed reference point and the current moment.
                ///
                /// For `std::time::Instant` the reference point is the UNIX epoch; for [`Eternity`] the
                /// result is always zero. Used during (de-)serialization.
                fn duration_since_epoch() -> Duration;
        }

        /// A state in which time has no meaning: no time ever elapses, and subtracting a duration
        /// changes nothing.
        #[derive(Debug, PartialEq, Eq)]
        pub struct Eternity;

        impl Time for Eternity {
                fn now() -> Self {
                        Eternity
                }

                fn elapsed(&self) -> Duration {
                        Self::duration_since_epoch()
                }

                fn duration_since_epoch() -> Duration {
                        Duration::new(0, 0)
                }
        }

        impl Sub<Duration> for Eternity {
                type Output = Self;

                /// Subtracting any duration from an eternity leaves it unchanged.
                fn sub(self, _rhs: Duration) -> Self {
                        self
                }
        }

        #[cfg(not(feature = "no-std"))]
        impl Time for std::time::Instant {
                fn now() -> Self {
                        std::time::Instant::now()
                }

                fn elapsed(&self) -> Duration {
                        // Fully qualified so the inherent `Instant::elapsed` is the one called.
                        std::time::Instant::elapsed(self)
                }

                /// Panics if the system clock reports a time before the UNIX epoch.
                fn duration_since_epoch() -> Duration {
                        use std::time::SystemTime;
                        SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap()
                }
        }
}
854
855 pub(crate) use self::time::Time;
856
857 #[cfg(test)]
858 mod tests {
859         use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorer, ScoringParameters, ScorerUsingTime, Time};
860         use super::time::Eternity;
861
862         use ln::features::{ChannelFeatures, NodeFeatures};
863         use ln::msgs::{ChannelAnnouncement, ChannelUpdate, OptionalField, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
864         use routing::scoring::Score;
865         use routing::network_graph::{NetworkGraph, NodeId};
866         use routing::router::RouteHop;
867         use util::ser::{Readable, Writeable};
868
869         use bitcoin::blockdata::constants::genesis_block;
870         use bitcoin::hashes::Hash;
871         use bitcoin::hashes::sha256d::Hash as Sha256dHash;
872         use bitcoin::network::constants::Network;
873         use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
874         use core::cell::Cell;
875         use core::ops::Sub;
876         use core::time::Duration;
877         use io;
878
879         // `Time` tests
880
881         /// Time that can be advanced manually in tests.
882         #[derive(Debug, PartialEq, Eq)]
883         struct SinceEpoch(Duration);
884
885         impl SinceEpoch {
886                 thread_local! {
887                         static ELAPSED: Cell<Duration> = core::cell::Cell::new(Duration::from_secs(0));
888                 }
889
890                 fn advance(duration: Duration) {
891                         Self::ELAPSED.with(|elapsed| elapsed.set(elapsed.get() + duration))
892                 }
893         }
894
895         impl Time for SinceEpoch {
896                 fn now() -> Self {
897                         Self(Self::duration_since_epoch())
898                 }
899
900                 fn duration_since_epoch() -> Duration {
901                         Self::ELAPSED.with(|elapsed| elapsed.get())
902                 }
903
904                 fn elapsed(&self) -> Duration {
905                         Self::duration_since_epoch() - self.0
906                 }
907         }
908
909         impl Sub<Duration> for SinceEpoch {
910                 type Output = Self;
911
912                 fn sub(self, other: Duration) -> Self {
913                         Self(self.0 - other)
914                 }
915         }
916
        /// Advancing `SinceEpoch` is reflected by `elapsed`, and subtracting the elapsed time from a
        /// later instant recovers the earlier one.
        #[test]
        fn time_passes_when_advanced() {
                let now = SinceEpoch::now();
                assert_eq!(now.elapsed(), Duration::from_secs(0));

                SinceEpoch::advance(Duration::from_secs(1));
                SinceEpoch::advance(Duration::from_secs(1));

                let elapsed = now.elapsed();
                let later = SinceEpoch::now();

                assert_eq!(elapsed, Duration::from_secs(2));
                assert_eq!(later - elapsed, now);
        }

        /// `Eternity` never reports elapsed time, and subtracting from it is a no-op.
        #[test]
        fn time_never_passes_in_an_eternity() {
                let now = Eternity::now();
                let elapsed = now.elapsed();
                let later = Eternity::now();

                assert_eq!(now.elapsed(), Duration::from_secs(0));
                assert_eq!(later - elapsed, now);
        }
941
942         // `Scorer` tests
943
944         /// A scorer for testing with time that can be manually advanced.
945         type Scorer = ScorerUsingTime::<SinceEpoch>;
946
947         fn source_privkey() -> SecretKey {
948                 SecretKey::from_slice(&[42; 32]).unwrap()
949         }
950
951         fn target_privkey() -> SecretKey {
952                 SecretKey::from_slice(&[43; 32]).unwrap()
953         }
954
955         fn source_pubkey() -> PublicKey {
956                 let secp_ctx = Secp256k1::new();
957                 PublicKey::from_secret_key(&secp_ctx, &source_privkey())
958         }
959
960         fn target_pubkey() -> PublicKey {
961                 let secp_ctx = Secp256k1::new();
962                 PublicKey::from_secret_key(&secp_ctx, &target_privkey())
963         }
964
965         fn source_node_id() -> NodeId {
966                 NodeId::from_pubkey(&source_pubkey())
967         }
968
969         fn target_node_id() -> NodeId {
970                 NodeId::from_pubkey(&target_pubkey())
971         }
972
        /// With no recorded failures, the penalty is just `base_penalty_msat`, and letting time pass
        /// does not change it.
        #[test]
        fn penalizes_without_channel_failures() {
                // `overuse_penalty_msat_per_1024th: 0` keeps the overuse penalty out of the result.
                let scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(1),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                SinceEpoch::advance(Duration::from_secs(1));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
        }
989
        /// Each failure adds another `failure_penalty_msat` on top of the existing penalty.
        #[test]
        fn accumulates_channel_failure_penalties() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 64,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                // An empty path is enough to record a failure against channel 42.
                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_192);
        }
1012
        /// The failure penalty halves once per full half-life and eventually decays to zero.
        #[test]
        fn decays_channel_failure_penalties_over_time() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                // Less than one half-life: no decay yet.
                SinceEpoch::advance(Duration::from_secs(9));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                // One half-life: 512 -> 256.
                SinceEpoch::advance(Duration::from_secs(1));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);

                // Nine half-lives: 512 >> 9 == 1.
                SinceEpoch::advance(Duration::from_secs(10 * 8));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_001);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
        }
1044
        /// Decaying across 64 or more half-lives yields a zero failure penalty instead of overflowing.
        #[test]
        fn decays_channel_failure_penalties_without_shift_overflow() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                // An unchecked right shift 64 bits or more in ChannelFailure::decayed_penalty_msat would
                // cause an overflow.
                SinceEpoch::advance(Duration::from_secs(10 * 64));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);
        }
1069
        /// A new failure is added on top of the already-decayed penalty, and the sum then keeps
        /// decaying.
        #[test]
        fn accumulates_channel_failure_penalties_after_decay() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);

                // 256 (decayed) + 512 (new failure) = 768.
                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_768);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_384);
        }
1095
        /// A successful payment through the channel halves its remaining failure penalty.
        #[test]
        fn reduces_channel_failure_penalties_after_success() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_000);

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);

                // A single-hop path over channel 42 to `target`.
                let hop = RouteHop {
                        pubkey: PublicKey::from_slice(target.as_slice()).unwrap(),
                        node_features: NodeFeatures::known(),
                        short_channel_id: 42,
                        channel_features: ChannelFeatures::known(),
                        fee_msat: 1,
                        cltv_expiry_delta: 18,
                };
                // 256 -> 128 after the success.
                scorer.payment_path_successful(&[&hop]);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_064);
        }
1129
        /// Failure penalties for several channels survive a write/read round-trip.
        #[test]
        fn restores_persisted_channel_failure_penalties() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);

                scorer.payment_path_failed(&[], 43);
                assert_eq!(scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);

                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();

                // Both the decayed (42) and the fresh (43) penalties must be restored unchanged.
                let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
                assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);
                assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, 1, &source, &target), 1_512);
        }
1158
        /// Time that passes between serialization and deserialization still decays the penalty.
        #[test]
        fn decays_persisted_channel_failure_penalties() {
                let mut scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 1_000,
                        failure_penalty_msat: 512,
                        failure_penalty_half_life: Duration::from_secs(10),
                        overuse_penalty_start_1024th: 1024,
                        overuse_penalty_msat_per_1024th: 0,
                });
                let source = source_node_id();
                let target = target_node_id();

                scorer.payment_path_failed(&[], 42);
                assert_eq!(scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_512);

                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();

                // One half-life passes while the scorer is "on disk".
                SinceEpoch::advance(Duration::from_secs(10));

                let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
                assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_256);

                SinceEpoch::advance(Duration::from_secs(10));
                assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, 1, &source, &target), 1_128);
        }
1185
        /// Amounts above `overuse_penalty_start_1024th` of capacity are charged
        /// `overuse_penalty_msat_per_1024th` for each whole 1024th of capacity beyond that threshold.
        #[test]
        fn charges_per_1024th_penalty() {
                let scorer = Scorer::new(ScoringParameters {
                        base_penalty_msat: 0,
                        failure_penalty_msat: 0,
                        failure_penalty_half_life: Duration::from_secs(0),
                        overuse_penalty_start_1024th: 256,
                        overuse_penalty_msat_per_1024th: 100,
                });
                let source = source_node_id();
                let target = target_node_id();

                // With a 1_024_000 msat capacity, one 1024th is 1_000 msat and the threshold is
                // 256_000 msat; partial 1024ths are not charged.
                assert_eq!(scorer.channel_penalty_msat(42, 1_000, 1_024_000, &source, &target), 0);
                assert_eq!(scorer.channel_penalty_msat(42, 256_999, 1_024_000, &source, &target), 0);
                assert_eq!(scorer.channel_penalty_msat(42, 257_000, 1_024_000, &source, &target), 100);
                assert_eq!(scorer.channel_penalty_msat(42, 258_000, 1_024_000, &source, &target), 200);
                assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 256 * 100);
        }
1204
1205         // `ProbabilisticScorer` tests
1206
1207         fn sender_privkey() -> SecretKey {
1208                 SecretKey::from_slice(&[41; 32]).unwrap()
1209         }
1210
1211         fn recipient_privkey() -> SecretKey {
1212                 SecretKey::from_slice(&[45; 32]).unwrap()
1213         }
1214
1215         fn sender_pubkey() -> PublicKey {
1216                 let secp_ctx = Secp256k1::new();
1217                 PublicKey::from_secret_key(&secp_ctx, &sender_privkey())
1218         }
1219
1220         fn recipient_pubkey() -> PublicKey {
1221                 let secp_ctx = Secp256k1::new();
1222                 PublicKey::from_secret_key(&secp_ctx, &recipient_privkey())
1223         }
1224
1225         fn sender_node_id() -> NodeId {
1226                 NodeId::from_pubkey(&sender_pubkey())
1227         }
1228
1229         fn recipient_node_id() -> NodeId {
1230                 NodeId::from_pubkey(&recipient_pubkey())
1231         }
1232
1233         fn network_graph() -> NetworkGraph {
1234                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1235                 let mut network_graph = NetworkGraph::new(genesis_hash);
1236                 add_channel(&mut network_graph, 41, sender_privkey(), source_privkey());
1237                 add_channel(&mut network_graph, 42, source_privkey(), target_privkey());
1238                 add_channel(&mut network_graph, 43, target_privkey(), recipient_privkey());
1239
1240                 network_graph
1241         }
1242
1243         fn add_channel(
1244                 network_graph: &mut NetworkGraph, short_channel_id: u64, node_1_key: SecretKey,
1245                 node_2_key: SecretKey
1246         ) {
1247                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1248                 let node_1_secret = &SecretKey::from_slice(&[39; 32]).unwrap();
1249                 let node_2_secret = &SecretKey::from_slice(&[40; 32]).unwrap();
1250                 let secp_ctx = Secp256k1::new();
1251                 let unsigned_announcement = UnsignedChannelAnnouncement {
1252                         features: ChannelFeatures::known(),
1253                         chain_hash: genesis_hash,
1254                         short_channel_id,
1255                         node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_key),
1256                         node_id_2: PublicKey::from_secret_key(&secp_ctx, &node_2_key),
1257                         bitcoin_key_1: PublicKey::from_secret_key(&secp_ctx, &node_1_secret),
1258                         bitcoin_key_2: PublicKey::from_secret_key(&secp_ctx, &node_2_secret),
1259                         excess_data: Vec::new(),
1260                 };
1261                 let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_announcement.encode()[..])[..]);
1262                 let signed_announcement = ChannelAnnouncement {
1263                         node_signature_1: secp_ctx.sign(&msghash, &node_1_key),
1264                         node_signature_2: secp_ctx.sign(&msghash, &node_2_key),
1265                         bitcoin_signature_1: secp_ctx.sign(&msghash, &node_1_secret),
1266                         bitcoin_signature_2: secp_ctx.sign(&msghash, &node_2_secret),
1267                         contents: unsigned_announcement,
1268                 };
1269                 let chain_source: Option<&::util::test_utils::TestChainSource> = None;
1270                 network_graph.update_channel_from_announcement(
1271                         &signed_announcement, &chain_source, &secp_ctx).unwrap();
1272                 update_channel(network_graph, short_channel_id, node_1_key, 0);
1273                 update_channel(network_graph, short_channel_id, node_2_key, 1);
1274         }
1275
1276         fn update_channel(
1277                 network_graph: &mut NetworkGraph, short_channel_id: u64, node_key: SecretKey, flags: u8
1278         ) {
1279                 let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
1280                 let secp_ctx = Secp256k1::new();
1281                 let unsigned_update = UnsignedChannelUpdate {
1282                         chain_hash: genesis_hash,
1283                         short_channel_id,
1284                         timestamp: 100,
1285                         flags,
1286                         cltv_expiry_delta: 18,
1287                         htlc_minimum_msat: 0,
1288                         htlc_maximum_msat: OptionalField::Present(1_000),
1289                         fee_base_msat: 1,
1290                         fee_proportional_millionths: 0,
1291                         excess_data: Vec::new(),
1292                 };
1293                 let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_update.encode()[..])[..]);
1294                 let signed_update = ChannelUpdate {
1295                         signature: secp_ctx.sign(&msghash, &node_key),
1296                         contents: unsigned_update,
1297                 };
1298                 network_graph.update_channel(&signed_update, &secp_ctx).unwrap();
1299         }
1300
1301         fn payment_path_for_amount(amount_msat: u64) -> Vec<RouteHop> {
1302                 vec![
1303                         RouteHop {
1304                                 pubkey: source_pubkey(),
1305                                 node_features: NodeFeatures::known(),
1306                                 short_channel_id: 41,
1307                                 channel_features: ChannelFeatures::known(),
1308                                 fee_msat: 1,
1309                                 cltv_expiry_delta: 18,
1310                         },
1311                         RouteHop {
1312                                 pubkey: target_pubkey(),
1313                                 node_features: NodeFeatures::known(),
1314                                 short_channel_id: 42,
1315                                 channel_features: ChannelFeatures::known(),
1316                                 fee_msat: 2,
1317                                 cltv_expiry_delta: 18,
1318                         },
1319                         RouteHop {
1320                                 pubkey: recipient_pubkey(),
1321                                 node_features: NodeFeatures::known(),
1322                                 short_channel_id: 43,
1323                                 channel_features: ChannelFeatures::known(),
1324                                 fee_msat: amount_msat,
1325                                 cltv_expiry_delta: 18,
1326                         },
1327                 ]
1328         }
1329
        /// Stored liquidity offsets describe the direction from the lower node id to the higher one.
        /// The reverse direction sees them mirrored against the capacity.
        #[test]
        fn liquidity_bounds_directed_from_lowest_node_id() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
                let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
                        .with_channel(42,
                                ChannelLiquidity {
                                        min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100
                                })
                        .with_channel(43,
                                ChannelLiquidity {
                                        min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100
                                });
                let source = source_node_id();
                let target = target_node_id();
                let recipient = recipient_node_id();

                // Channel 42: `target` is the lower node id, so target -> source reads the offsets
                // directly (min = 700, max = 1_000 - 100) and source -> target sees them mirrored.
                let liquidity = scorer.channel_liquidities.get_mut(&42).unwrap();
                assert!(source > target);
                assert_eq!(liquidity.as_directed(&source, &target, 1_000).min_liquidity_msat(), 100);
                assert_eq!(liquidity.as_directed(&source, &target, 1_000).max_liquidity_msat(), 300);
                assert_eq!(liquidity.as_directed(&target, &source, 1_000).min_liquidity_msat(), 700);
                assert_eq!(liquidity.as_directed(&target, &source, 1_000).max_liquidity_msat(), 900);

                // Raising one direction's lower bound lowers the opposite direction's upper bound.
                liquidity.as_directed_mut(&source, &target, 1_000).set_min_liquidity_msat(200);
                assert_eq!(liquidity.as_directed(&source, &target, 1_000).min_liquidity_msat(), 200);
                assert_eq!(liquidity.as_directed(&source, &target, 1_000).max_liquidity_msat(), 300);
                assert_eq!(liquidity.as_directed(&target, &source, 1_000).min_liquidity_msat(), 700);
                assert_eq!(liquidity.as_directed(&target, &source, 1_000).max_liquidity_msat(), 800);

                // Channel 43: `target` is again the lower node id, so target -> recipient reads the
                // offsets directly.
                let liquidity = scorer.channel_liquidities.get_mut(&43).unwrap();
                assert!(target < recipient);
                assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).min_liquidity_msat(), 700);
                assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).max_liquidity_msat(), 900);
                assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).min_liquidity_msat(), 100);
                assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).max_liquidity_msat(), 300);

                // Setting the upper bound (200) below the current lower bound (700) resets the lower
                // bound to zero.
                liquidity.as_directed_mut(&target, &recipient, 1_000).set_max_liquidity_msat(200);
                assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).min_liquidity_msat(), 0);
                assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).max_liquidity_msat(), 200);
                assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).min_liquidity_msat(), 800);
                assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).max_liquidity_msat(), 1000);
        }
1373
1374         #[test]
1375         fn resets_liquidity_upper_bound_when_crossed_by_lower_bound() {
1376                 let network_graph = network_graph();
1377                 let params = ProbabilisticScoringParameters::default();
1378                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1379                         .with_channel(42,
1380                                 ChannelLiquidity {
1381                                         min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
1382                                 });
1383                 let source = source_node_id();
1384                 let target = target_node_id();
1385                 assert!(source > target);
1386
1387                 // Check initial bounds.
1388                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1389                         .as_directed(&source, &target, 1_000);
1390                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1391                 assert_eq!(liquidity.max_liquidity_msat(), 800);
1392
1393                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1394                         .as_directed(&target, &source, 1_000);
1395                 assert_eq!(liquidity.min_liquidity_msat(), 200);
1396                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1397
1398                 // Reset from source to target.
1399                 scorer.channel_liquidities.get_mut(&42).unwrap()
1400                         .as_directed_mut(&source, &target, 1_000)
1401                         .set_min_liquidity_msat(900);
1402
1403                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1404                         .as_directed(&source, &target, 1_000);
1405                 assert_eq!(liquidity.min_liquidity_msat(), 900);
1406                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1407
1408                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1409                         .as_directed(&target, &source, 1_000);
1410                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1411                 assert_eq!(liquidity.max_liquidity_msat(), 100);
1412
1413                 // Reset from target to source.
1414                 scorer.channel_liquidities.get_mut(&42).unwrap()
1415                         .as_directed_mut(&target, &source, 1_000)
1416                         .set_min_liquidity_msat(400);
1417
1418                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1419                         .as_directed(&source, &target, 1_000);
1420                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1421                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1422
1423                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1424                         .as_directed(&target, &source, 1_000);
1425                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1426                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1427         }
1428
1429         #[test]
1430         fn resets_liquidity_lower_bound_when_crossed_by_upper_bound() {
1431                 let network_graph = network_graph();
1432                 let params = ProbabilisticScoringParameters::default();
1433                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1434                         .with_channel(42,
1435                                 ChannelLiquidity {
1436                                         min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
1437                                 });
1438                 let source = source_node_id();
1439                 let target = target_node_id();
1440                 assert!(source > target);
1441
1442                 // Check initial bounds.
1443                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1444                         .as_directed(&source, &target, 1_000);
1445                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1446                 assert_eq!(liquidity.max_liquidity_msat(), 800);
1447
1448                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1449                         .as_directed(&target, &source, 1_000);
1450                 assert_eq!(liquidity.min_liquidity_msat(), 200);
1451                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1452
1453                 // Reset from source to target.
1454                 scorer.channel_liquidities.get_mut(&42).unwrap()
1455                         .as_directed_mut(&source, &target, 1_000)
1456                         .set_max_liquidity_msat(300);
1457
1458                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1459                         .as_directed(&source, &target, 1_000);
1460                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1461                 assert_eq!(liquidity.max_liquidity_msat(), 300);
1462
1463                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1464                         .as_directed(&target, &source, 1_000);
1465                 assert_eq!(liquidity.min_liquidity_msat(), 700);
1466                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1467
1468                 // Reset from target to source.
1469                 scorer.channel_liquidities.get_mut(&42).unwrap()
1470                         .as_directed_mut(&target, &source, 1_000)
1471                         .set_max_liquidity_msat(600);
1472
1473                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1474                         .as_directed(&source, &target, 1_000);
1475                 assert_eq!(liquidity.min_liquidity_msat(), 400);
1476                 assert_eq!(liquidity.max_liquidity_msat(), 1_000);
1477
1478                 let liquidity = scorer.channel_liquidities.get(&42).unwrap()
1479                         .as_directed(&target, &source, 1_000);
1480                 assert_eq!(liquidity.min_liquidity_msat(), 0);
1481                 assert_eq!(liquidity.max_liquidity_msat(), 600);
1482         }
1483
1484         #[test]
1485         fn increased_penalty_nearing_liquidity_upper_bound() {
1486                 let network_graph = network_graph();
1487                 let params = ProbabilisticScoringParameters::default();
1488                 let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1489                 let source = source_node_id();
1490                 let target = target_node_id();
1491
1492                 assert_eq!(scorer.channel_penalty_msat(42, 100, 100_000, &source, &target), 0);
1493                 assert_eq!(scorer.channel_penalty_msat(42, 1_000, 100_000, &source, &target), 4);
1494                 assert_eq!(scorer.channel_penalty_msat(42, 10_000, 100_000, &source, &target), 45);
1495                 assert_eq!(scorer.channel_penalty_msat(42, 100_000, 100_000, &source, &target), 2_000);
1496
1497                 assert_eq!(scorer.channel_penalty_msat(42, 125, 1_000, &source, &target), 57);
1498                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1499                 assert_eq!(scorer.channel_penalty_msat(42, 375, 1_000, &source, &target), 203);
1500                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
1501                 assert_eq!(scorer.channel_penalty_msat(42, 625, 1_000, &source, &target), 425);
1502                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 600);
1503                 assert_eq!(scorer.channel_penalty_msat(42, 875, 1_000, &source, &target), 900);
1504         }
1505
1506         #[test]
1507         fn constant_penalty_outside_liquidity_bounds() {
1508                 let network_graph = network_graph();
1509                 let params = ProbabilisticScoringParameters::default();
1510                 let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
1511                         .with_channel(42,
1512                                 ChannelLiquidity { min_liquidity_offset_msat: 40, max_liquidity_offset_msat: 40 });
1513                 let source = source_node_id();
1514                 let target = target_node_id();
1515
1516                 assert_eq!(scorer.channel_penalty_msat(42, 39, 100, &source, &target), 0);
1517                 assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), 0);
1518                 assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), 2_000);
1519                 assert_eq!(scorer.channel_penalty_msat(42, 61, 100, &source, &target), 2_000);
1520         }
1521
1522         #[test]
1523         fn does_not_penalize_own_channel() {
1524                 let network_graph = network_graph();
1525                 let params = ProbabilisticScoringParameters::default();
1526                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1527                 let sender = sender_node_id();
1528                 let source = source_node_id();
1529                 let failed_path = payment_path_for_amount(500);
1530                 let successful_path = payment_path_for_amount(200);
1531
1532                 assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
1533
1534                 scorer.payment_path_failed(&failed_path.iter().collect::<Vec<_>>(), 41);
1535                 assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
1536
1537                 scorer.payment_path_successful(&successful_path.iter().collect::<Vec<_>>());
1538                 assert_eq!(scorer.channel_penalty_msat(41, 500, 1_000, &sender, &source), 0);
1539         }
1540
1541         #[test]
1542         fn sets_liquidity_lower_bound_on_downstream_failure() {
1543                 let network_graph = network_graph();
1544                 let params = ProbabilisticScoringParameters::default();
1545                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1546                 let source = source_node_id();
1547                 let target = target_node_id();
1548                 let path = payment_path_for_amount(500);
1549
1550                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1551                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
1552                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 600);
1553
1554                 scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 43);
1555
1556                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 0);
1557                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 0);
1558                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 300);
1559         }
1560
1561         #[test]
1562         fn sets_liquidity_upper_bound_on_failure() {
1563                 let network_graph = network_graph();
1564                 let params = ProbabilisticScoringParameters::default();
1565                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1566                 let source = source_node_id();
1567                 let target = target_node_id();
1568                 let path = payment_path_for_amount(500);
1569
1570                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1571                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 300);
1572                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 600);
1573
1574                 scorer.payment_path_failed(&path.iter().collect::<Vec<_>>(), 42);
1575
1576                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 300);
1577                 assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 2_000);
1578                 assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 2_000);
1579         }
1580
1581         #[test]
1582         fn reduces_liquidity_upper_bound_along_path_on_success() {
1583                 let network_graph = network_graph();
1584                 let params = ProbabilisticScoringParameters::default();
1585                 let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
1586                 let sender = sender_node_id();
1587                 let source = source_node_id();
1588                 let target = target_node_id();
1589                 let recipient = recipient_node_id();
1590                 let path = payment_path_for_amount(500);
1591
1592                 assert_eq!(scorer.channel_penalty_msat(41, 250, 1_000, &sender, &source), 0);
1593                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 124);
1594                 assert_eq!(scorer.channel_penalty_msat(43, 250, 1_000, &target, &recipient), 124);
1595
1596                 scorer.payment_path_successful(&path.iter().collect::<Vec<_>>());
1597
1598                 assert_eq!(scorer.channel_penalty_msat(41, 250, 1_000, &sender, &source), 0);
1599                 assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 300);
1600                 assert_eq!(scorer.channel_penalty_msat(43, 250, 1_000, &target, &recipient), 300);
1601         }
1602 }