From: Jeffrey Czyz Date: Mon, 3 Jan 2022 14:35:19 +0000 (-0600) Subject: Probabilistic channel scoring X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=4f1b31314f3b9c5bd86ca0b83e7403155c60c5db;p=rust-lightning Probabilistic channel scoring Add a Score implementation based on "Optimally Reliable & Cheap Payment Flows on the Lightning Network" by Rene Pickhardt and Stefan Richter[1]. Given the uncertainty of channel liquidity balances, probability distributions are defined based on knowledge learned from successful and unsuccessful attempts. Then the negative log of the success probability is used to determine the cost of routing a specific HTLC amount through a channel. [1]: https://arxiv.org/abs/2107.05322 --- diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index 810fd2020..173a33cdd 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -51,14 +51,16 @@ //! //! [`find_route`]: crate::routing::router::find_route +use bitcoin::secp256k1::key::PublicKey; + use ln::msgs::DecodeError; -use routing::network_graph::NodeId; +use routing::network_graph::{EffectiveCapacity, NetworkGraph, NodeId}; use routing::router::RouteHop; -use util::ser::{Readable, Writeable, Writer}; +use util::ser::{Readable, ReadableArgs, Writeable, Writer}; use prelude::*; use core::cell::{RefCell, RefMut}; -use core::ops::DerefMut; +use core::ops::{Deref, DerefMut}; use core::time::Duration; use io::{self, Read}; use sync::{Mutex, MutexGuard}; @@ -451,6 +453,334 @@ impl Readable for ChannelFailure { } } +/// [`Score`] implementation using channel success probability distributions. +/// +/// Based on *Optimally Reliable & Cheap Payment Flows on the Lightning Network* by Rene Pickhardt +/// and Stefan Richter [[1]]. Given the uncertainty of channel liquidity balances, probability +/// distributions are defined based on knowledge learned from successful and unsuccessful attempts. +/// Then the negative log of the success probability is used to determine the cost of routing a +/// specific HTLC amount through a channel. +/// +/// [1]: https://arxiv.org/abs/2107.05322 +pub struct ProbabilisticScorer> { + _params: ProbabilisticScoringParameters, + node_id: NodeId, + network_graph: G, + // TODO: Remove entries of closed channels. + channel_liquidities: HashMap, +} + +/// Parameters for configuring [`ProbabilisticScorer`]. +pub struct ProbabilisticScoringParameters; + +impl_writeable_tlv_based!(ProbabilisticScoringParameters, { +}); + +/// Accounting for channel liquidity balance uncertainty. +/// +/// Direction is defined in terms of [`NodeId`] partial ordering, where the source node is the +/// first node in the ordering of the channel's counterparties. Thus, swapping the two liquidity +/// offset fields gives the opposite direction. +struct ChannelLiquidity { + min_liquidity_offset_msat: u64, + max_liquidity_offset_msat: u64, +} + +/// A view of [`ChannelLiquidity`] in one direction assuming a certain channel capacity. +struct DirectedChannelLiquidity> { + min_liquidity_offset_msat: L, + max_liquidity_offset_msat: L, + capacity_msat: u64, +} + +/// The likelihood of an event occurring. +enum Probability { + Zero, + One, + Ratio { numerator: u64, denominator: u64 }, +} + +impl> ProbabilisticScorer { + /// Creates a new scorer using the given scoring parameters for sending payments from a node + /// through a network graph. + pub fn new( + _params: ProbabilisticScoringParameters, node_pubkey: PublicKey, network_graph: G + ) -> Self { + Self { + _params, + node_id: NodeId::from_pubkey(&node_pubkey), + network_graph, + channel_liquidities: HashMap::new(), + } + } + + #[cfg(test)] + fn with_channel(mut self, short_channel_id: u64, liquidity: ChannelLiquidity) -> Self { + assert!(self.channel_liquidities.insert(short_channel_id, liquidity).is_none()); + self + } +} + +impl Default for ProbabilisticScoringParameters { + fn default() -> Self { + Self + } +} + +impl ChannelLiquidity { + #[inline] + fn new() -> Self { + Self { + min_liquidity_offset_msat: 0, + max_liquidity_offset_msat: 0, + } + } + + /// Returns a view of the channel liquidity directed from `source` to `target` assuming + /// `capacity_msat`. + fn as_directed( + &self, source: &NodeId, target: &NodeId, capacity_msat: u64 + ) -> DirectedChannelLiquidity<&u64> { + let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target { + (&self.min_liquidity_offset_msat, &self.max_liquidity_offset_msat) + } else { + (&self.max_liquidity_offset_msat, &self.min_liquidity_offset_msat) + }; + + DirectedChannelLiquidity { + min_liquidity_offset_msat, + max_liquidity_offset_msat, + capacity_msat, + } + } + + /// Returns a mutable view of the channel liquidity directed from `source` to `target` assuming + /// `capacity_msat`. + fn as_directed_mut( + &mut self, source: &NodeId, target: &NodeId, capacity_msat: u64 + ) -> DirectedChannelLiquidity<&mut u64> { + let (min_liquidity_offset_msat, max_liquidity_offset_msat) = if source < target { + (&mut self.min_liquidity_offset_msat, &mut self.max_liquidity_offset_msat) + } else { + (&mut self.max_liquidity_offset_msat, &mut self.min_liquidity_offset_msat) + }; + + DirectedChannelLiquidity { + min_liquidity_offset_msat, + max_liquidity_offset_msat, + capacity_msat, + } + } +} + +impl> DirectedChannelLiquidity { + /// Returns the success probability of routing the given HTLC `amount_msat` through the channel + /// in this direction. + fn success_probability(&self, amount_msat: u64) -> Probability { + let max_liquidity_msat = self.max_liquidity_msat(); + let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat); + if amount_msat > max_liquidity_msat { + Probability::Zero + } else if amount_msat < min_liquidity_msat { + Probability::One + } else { + let numerator = max_liquidity_msat + 1 - amount_msat; + let denominator = max_liquidity_msat + 1 - min_liquidity_msat; + if numerator == denominator { + Probability::One + } else { + Probability::Ratio { numerator, denominator } + } + } + } + + /// Returns the lower bound of the channel liquidity balance in this direction. + fn min_liquidity_msat(&self) -> u64 { + *self.min_liquidity_offset_msat + } + + /// Returns the upper bound of the channel liquidity balance in this direction. + fn max_liquidity_msat(&self) -> u64 { + self.capacity_msat.checked_sub(*self.max_liquidity_offset_msat).unwrap_or(0) + } +} + +impl> DirectedChannelLiquidity { + /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat`. + fn failed_at_channel(&mut self, amount_msat: u64) { + if amount_msat < self.max_liquidity_msat() { + self.set_max_liquidity_msat(amount_msat); + } + } + + /// Adjusts the channel liquidity balance bounds when failing to route `amount_msat` downstream. + fn failed_downstream(&mut self, amount_msat: u64) { + if amount_msat > self.min_liquidity_msat() { + self.set_min_liquidity_msat(amount_msat); + } + } + + /// Adjusts the channel liquidity balance bounds when successfully routing `amount_msat`. + fn successful(&mut self, amount_msat: u64) { + let max_liquidity_msat = self.max_liquidity_msat().checked_sub(amount_msat).unwrap_or(0); + self.set_max_liquidity_msat(max_liquidity_msat); + } + + /// Adjusts the lower bound of the channel liquidity balance in this direction. + fn set_min_liquidity_msat(&mut self, amount_msat: u64) { + *self.min_liquidity_offset_msat = amount_msat; + + if amount_msat > self.max_liquidity_msat() { + *self.max_liquidity_offset_msat = 0; + } + } + + /// Adjusts the upper bound of the channel liquidity balance in this direction. + fn set_max_liquidity_msat(&mut self, amount_msat: u64) { + *self.max_liquidity_offset_msat = self.capacity_msat.checked_sub(amount_msat).unwrap_or(0); + + if amount_msat < self.min_liquidity_msat() { + *self.min_liquidity_offset_msat = 0; + } + } +} + +impl> Score for ProbabilisticScorer { + fn channel_penalty_msat( + &self, short_channel_id: u64, amount_msat: u64, capacity_msat: u64, source: &NodeId, + target: &NodeId + ) -> u64 { + if *source == self.node_id || *target == self.node_id { + return 0; + } + + let success_probability = self.channel_liquidities + .get(&short_channel_id) + .unwrap_or(&ChannelLiquidity::new()) + .as_directed(source, target, capacity_msat) + .success_probability(amount_msat); + match success_probability { + Probability::Zero => u64::max_value(), + Probability::One => 0, + Probability::Ratio { numerator, denominator } => { + let success_probability = numerator as f64 / denominator as f64; + (-(success_probability.log10()) * amount_msat as f64) as u64 + }, + } + } + + fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) { + let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0); + let network_graph = self.network_graph.read_only(); + let hop_sources = core::iter::once(self.node_id) + .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey))); + for (source, hop) in hop_sources.zip(path.iter()) { + let target = NodeId::from_pubkey(&hop.pubkey); + if source == self.node_id || target == self.node_id { + continue; + } + + let capacity_msat = network_graph.channels() + .get(&hop.short_channel_id) + .and_then(|channel| channel.as_directed_to(&target).map(|d| d.effective_capacity())) + .unwrap_or(EffectiveCapacity::Unknown) + .as_msat(); + + if hop.short_channel_id == short_channel_id { + self.channel_liquidities + .entry(hop.short_channel_id) + .or_insert_with(|| ChannelLiquidity::new()) + .as_directed_mut(&source, &target, capacity_msat) + .failed_at_channel(amount_msat); + break; + } + + self.channel_liquidities + .entry(hop.short_channel_id) + .or_insert_with(|| ChannelLiquidity::new()) + .as_directed_mut(&source, &target, capacity_msat) + .failed_downstream(amount_msat); + } + } + + fn payment_path_successful(&mut self, path: &[&RouteHop]) { + let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0); + let network_graph = self.network_graph.read_only(); + let hop_sources = core::iter::once(self.node_id) + .chain(path.iter().map(|hop| NodeId::from_pubkey(&hop.pubkey))); + for (source, hop) in hop_sources.zip(path.iter()) { + let target = NodeId::from_pubkey(&hop.pubkey); + if source == self.node_id || target == self.node_id { + continue; + } + + let capacity_msat = network_graph.channels() + .get(&hop.short_channel_id) + .and_then(|channel| channel.as_directed_to(&target).map(|d| d.effective_capacity())) + .unwrap_or(EffectiveCapacity::Unknown) + .as_msat(); + + self.channel_liquidities + .entry(hop.short_channel_id) + .or_insert_with(|| ChannelLiquidity::new()) + .as_directed_mut(&source, &target, capacity_msat) + .successful(amount_msat); + } + } +} + +impl> Writeable for ProbabilisticScorer { + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + self._params.write(w)?; + self.node_id.write(w)?; + self.channel_liquidities.write(w)?; + write_tlv_fields!(w, {}); + Ok(()) + } +} + +impl> ReadableArgs for ProbabilisticScorer { + #[inline] + fn read(r: &mut R, args: G) -> Result { + let res = Ok(Self { + _params: Readable::read(r)?, + node_id: Readable::read(r)?, + network_graph: args, + channel_liquidities: Readable::read(r)?, + }); + read_tlv_fields!(r, {}); + res + } +} + +impl Writeable for ChannelLiquidity { + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + write_tlv_fields!(w, { + (0, self.min_liquidity_offset_msat, required), + (2, self.max_liquidity_offset_msat, required), + }); + Ok(()) + } +} + +impl Readable for ChannelLiquidity { + #[inline] + fn read(r: &mut R) -> Result { + let mut min_liquidity_offset_msat = 0; + let mut max_liquidity_offset_msat = 0; + read_tlv_fields!(r, { + (0, min_liquidity_offset_msat, required), + (2, max_liquidity_offset_msat, required), + }); + Ok(Self { + min_liquidity_offset_msat, + max_liquidity_offset_msat + }) + } +} + pub(crate) mod time { use core::ops::Sub; use core::time::Duration; @@ -515,21 +845,28 @@ pub(crate) use self::time::Time; #[cfg(test)] mod tests { - use super::{ScoringParameters, ScorerUsingTime, Time}; + use super::{ChannelLiquidity, ProbabilisticScoringParameters, ProbabilisticScorer, ScoringParameters, ScorerUsingTime, Time}; use super::time::Eternity; use ln::features::{ChannelFeatures, NodeFeatures}; + use ln::msgs::{ChannelAnnouncement, ChannelUpdate, OptionalField, UnsignedChannelAnnouncement, UnsignedChannelUpdate}; use routing::scoring::Score; - use routing::network_graph::NodeId; + use routing::network_graph::{NetworkGraph, NodeId}; use routing::router::RouteHop; use util::ser::{Readable, Writeable}; - use bitcoin::secp256k1::PublicKey; + use bitcoin::blockdata::constants::genesis_block; + use bitcoin::hashes::Hash; + use bitcoin::hashes::sha256d::Hash as Sha256dHash; + use bitcoin::network::constants::Network; + use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; use core::cell::Cell; use core::ops::Sub; use core::time::Duration; use io; + // `Time` tests + /// Time that can be advanced manually in tests. #[derive(Debug, PartialEq, Eq)] struct SinceEpoch(Duration); @@ -591,15 +928,35 @@ mod tests { assert_eq!(later - elapsed, now); } + // `Scorer` tests + /// A scorer for testing with time that can be manually advanced. type Scorer = ScorerUsingTime::; + fn source_privkey() -> SecretKey { + SecretKey::from_slice(&[42; 32]).unwrap() + } + + fn target_privkey() -> SecretKey { + SecretKey::from_slice(&[43; 32]).unwrap() + } + + fn source_pubkey() -> PublicKey { + let secp_ctx = Secp256k1::new(); + PublicKey::from_secret_key(&secp_ctx, &source_privkey()) + } + + fn target_pubkey() -> PublicKey { + let secp_ctx = Secp256k1::new(); + PublicKey::from_secret_key(&secp_ctx, &target_privkey()) + } + fn source_node_id() -> NodeId { - NodeId::from_pubkey(&PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap()) + NodeId::from_pubkey(&source_pubkey()) } fn target_node_id() -> NodeId { - NodeId::from_pubkey(&PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap()) + NodeId::from_pubkey(&target_pubkey()) } #[test] @@ -833,4 +1190,239 @@ mod tests { assert_eq!(scorer.channel_penalty_msat(42, 258_000, 1_024_000, &source, &target), 200); assert_eq!(scorer.channel_penalty_msat(42, 512_000, 1_024_000, &source, &target), 256 * 100); } + + // `ProbabilisticScorer` tests + + fn sender_privkey() -> SecretKey { + SecretKey::from_slice(&[41; 32]).unwrap() + } + + fn recipient_privkey() -> SecretKey { + SecretKey::from_slice(&[45; 32]).unwrap() + } + + fn sender_pubkey() -> PublicKey { + let secp_ctx = Secp256k1::new(); + PublicKey::from_secret_key(&secp_ctx, &sender_privkey()) + } + + fn recipient_pubkey() -> PublicKey { + let secp_ctx = Secp256k1::new(); + PublicKey::from_secret_key(&secp_ctx, &recipient_privkey()) + } + + fn sender_node_id() -> NodeId { + NodeId::from_pubkey(&sender_pubkey()) + } + + fn recipient_node_id() -> NodeId { + NodeId::from_pubkey(&recipient_pubkey()) + } + + fn network_graph() -> NetworkGraph { + let genesis_hash = genesis_block(Network::Testnet).header.block_hash(); + let mut network_graph = NetworkGraph::new(genesis_hash); + add_channel(&mut network_graph, 41, sender_privkey(), source_privkey()); + add_channel(&mut network_graph, 42, source_privkey(), target_privkey()); + add_channel(&mut network_graph, 43, target_privkey(), recipient_privkey()); + + network_graph + } + + fn add_channel( + network_graph: &mut NetworkGraph, short_channel_id: u64, node_1_key: SecretKey, + node_2_key: SecretKey + ) { + let genesis_hash = genesis_block(Network::Testnet).header.block_hash(); + let node_1_secret = &SecretKey::from_slice(&[39; 32]).unwrap(); + let node_2_secret = &SecretKey::from_slice(&[40; 32]).unwrap(); + let secp_ctx = Secp256k1::new(); + let unsigned_announcement = UnsignedChannelAnnouncement { + features: ChannelFeatures::known(), + chain_hash: genesis_hash, + short_channel_id, + node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_key), + node_id_2: PublicKey::from_secret_key(&secp_ctx, &node_2_key), + bitcoin_key_1: PublicKey::from_secret_key(&secp_ctx, &node_1_secret), + bitcoin_key_2: PublicKey::from_secret_key(&secp_ctx, &node_2_secret), + excess_data: Vec::new(), + }; + let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_announcement.encode()[..])[..]); + let signed_announcement = ChannelAnnouncement { + node_signature_1: secp_ctx.sign(&msghash, &node_1_key), + node_signature_2: secp_ctx.sign(&msghash, &node_2_key), + bitcoin_signature_1: secp_ctx.sign(&msghash, &node_1_secret), + bitcoin_signature_2: secp_ctx.sign(&msghash, &node_2_secret), + contents: unsigned_announcement, + }; + let chain_source: Option<&::util::test_utils::TestChainSource> = None; + network_graph.update_channel_from_announcement( + &signed_announcement, &chain_source, &secp_ctx).unwrap(); + update_channel(network_graph, short_channel_id, node_1_key, 0); + update_channel(network_graph, short_channel_id, node_2_key, 1); + } + + fn update_channel( + network_graph: &mut NetworkGraph, short_channel_id: u64, node_key: SecretKey, flags: u8 + ) { + let genesis_hash = genesis_block(Network::Testnet).header.block_hash(); + let secp_ctx = Secp256k1::new(); + let unsigned_update = UnsignedChannelUpdate { + chain_hash: genesis_hash, + short_channel_id, + timestamp: 100, + flags, + cltv_expiry_delta: 18, + htlc_minimum_msat: 0, + htlc_maximum_msat: OptionalField::Present(1_000), + fee_base_msat: 1, + fee_proportional_millionths: 0, + excess_data: Vec::new(), + }; + let msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_update.encode()[..])[..]); + let signed_update = ChannelUpdate { + signature: secp_ctx.sign(&msghash, &node_key), + contents: unsigned_update, + }; + network_graph.update_channel(&signed_update, &secp_ctx).unwrap(); + } + + fn payment_path(amount_msat: u64) -> Vec { + vec![ + RouteHop { + pubkey: source_pubkey(), + node_features: NodeFeatures::known(), + short_channel_id: 41, + channel_features: ChannelFeatures::known(), + fee_msat: 1, + cltv_expiry_delta: 18, + }, + RouteHop { + pubkey: target_pubkey(), + node_features: NodeFeatures::known(), + short_channel_id: 42, + channel_features: ChannelFeatures::known(), + fee_msat: 2, + cltv_expiry_delta: 18, + }, + RouteHop { + pubkey: recipient_pubkey(), + node_features: NodeFeatures::known(), + short_channel_id: 43, + channel_features: ChannelFeatures::known(), + fee_msat: amount_msat, + cltv_expiry_delta: 18, + }, + ] + } + + #[test] + fn liquidity_bounds_directed_from_lowest_node_id() { + let network_graph = network_graph(); + let params = ProbabilisticScoringParameters::default(); + let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph) + .with_channel(42, + ChannelLiquidity { + min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100 + }) + .with_channel(43, + ChannelLiquidity { + min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100 + }); + let source = source_node_id(); + let target = target_node_id(); + let recipient = recipient_node_id(); + + let liquidity = scorer.channel_liquidities.get_mut(&42).unwrap(); + assert!(source > target); + assert_eq!(liquidity.as_directed(&source, &target, 1_000).min_liquidity_msat(), 100); + assert_eq!(liquidity.as_directed(&source, &target, 1_000).max_liquidity_msat(), 300); + assert_eq!(liquidity.as_directed(&target, &source, 1_000).min_liquidity_msat(), 700); + assert_eq!(liquidity.as_directed(&target, &source, 1_000).max_liquidity_msat(), 900); + + liquidity.as_directed_mut(&source, &target, 1_000).set_min_liquidity_msat(200); + assert_eq!(liquidity.as_directed(&source, &target, 1_000).min_liquidity_msat(), 200); + assert_eq!(liquidity.as_directed(&source, &target, 1_000).max_liquidity_msat(), 300); + assert_eq!(liquidity.as_directed(&target, &source, 1_000).min_liquidity_msat(), 700); + assert_eq!(liquidity.as_directed(&target, &source, 1_000).max_liquidity_msat(), 800); + + let liquidity = scorer.channel_liquidities.get_mut(&43).unwrap(); + assert!(target < recipient); + assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).min_liquidity_msat(), 700); + assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).max_liquidity_msat(), 900); + assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).min_liquidity_msat(), 100); + assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).max_liquidity_msat(), 300); + + liquidity.as_directed_mut(&target, &recipient, 1_000).set_max_liquidity_msat(200); + assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).min_liquidity_msat(), 0); + assert_eq!(liquidity.as_directed(&target, &recipient, 1_000).max_liquidity_msat(), 200); + assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).min_liquidity_msat(), 800); + assert_eq!(liquidity.as_directed(&recipient, &target, 1_000).max_liquidity_msat(), 1000); + } + + #[test] + fn increased_penalty_nearing_liquidity_upper_bound() { + let network_graph = network_graph(); + let params = ProbabilisticScoringParameters::default(); + let scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph); + let source = source_node_id(); + let target = target_node_id(); + + assert_eq!(scorer.channel_penalty_msat(42, 100, 100_000, &source, &target), 0); + assert_eq!(scorer.channel_penalty_msat(42, 1_000, 100_000, &source, &target), 4); + assert_eq!(scorer.channel_penalty_msat(42, 10_000, 100_000, &source, &target), 457); + assert_eq!(scorer.channel_penalty_msat(42, 100_000, 100_000, &source, &target), 500_000); + + assert_eq!(scorer.channel_penalty_msat(42, 125, 1_000, &source, &target), 7); + assert_eq!(scorer.channel_penalty_msat(42, 250, 1_000, &source, &target), 31); + assert_eq!(scorer.channel_penalty_msat(42, 375, 1_000, &source, &target), 76); + assert_eq!(scorer.channel_penalty_msat(42, 500, 1_000, &source, &target), 150); + assert_eq!(scorer.channel_penalty_msat(42, 625, 1_000, &source, &target), 265); + assert_eq!(scorer.channel_penalty_msat(42, 750, 1_000, &source, &target), 450); + assert_eq!(scorer.channel_penalty_msat(42, 875, 1_000, &source, &target), 787); + } + + #[test] + fn constant_penalty_outside_liquidity_bounds() { + let network_graph = network_graph(); + let params = ProbabilisticScoringParameters::default(); + let scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph) + .with_channel(42, + ChannelLiquidity { min_liquidity_offset_msat: 40, max_liquidity_offset_msat: 40 }); + let source = source_node_id(); + let target = target_node_id(); + + assert_eq!(scorer.channel_penalty_msat(42, 39, 100, &source, &target), 0); + assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), 0); + assert_ne!(scorer.channel_penalty_msat(42, 50, 100, &source, &target), u64::max_value()); + assert_eq!(scorer.channel_penalty_msat(42, 61, 100, &source, &target), u64::max_value()); + } + + #[test] + fn reduces_liquidity_upper_bound_on_success() { + let network_graph = network_graph(); + let params = ProbabilisticScoringParameters::default(); + let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph) + .with_channel(42, + ChannelLiquidity { min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 0 }) + .with_channel(43, + ChannelLiquidity { min_liquidity_offset_msat: 0, max_liquidity_offset_msat: 400 }); + let sender = sender_node_id(); + let source = source_node_id(); + let target = target_node_id(); + let recipient = recipient_node_id(); + let path = payment_path(200); + + assert_eq!(scorer.channel_penalty_msat(41, 200, 1_000, &sender, &source), 0); + assert_eq!(scorer.channel_penalty_msat(42, 200, 1_000, &source, &target), 94); + assert_eq!(scorer.channel_penalty_msat(43, 200, 1_000, &target, &recipient), 35); + + scorer.payment_path_successful(&path.iter().collect::>()); + + assert_eq!(scorer.channel_penalty_msat(41, 200, 1_000, &sender, &source), 0); + assert_eq!(scorer.channel_penalty_msat(42, 200, 1_000, &source, &target), u64::max_value()); + assert_eq!(scorer.channel_penalty_msat(43, 200, 1_000, &target, &recipient), 59); + } + + // TODO: Add more test coverage }