X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Frouting%2Fscoring.rs;h=4d342562bea75afbe091af9244584dcbb2124ffb;hb=3c02e507d67f83469b6e533e551d1f08e3915343;hp=d1ed0f7cb6d55b5206ce791ebc6ec1d10ff1aebc;hpb=10cfe5c973e61251f6bf2180d1fc8d57d5e56ca5;p=rust-lightning diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index d1ed0f7c..4d342562 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -165,8 +165,7 @@ pub trait WriteableScore<'a>: LockableScore<'a> + Writeable {} #[cfg(not(c_bindings))] impl<'a, T> WriteableScore<'a> for T where T: LockableScore<'a> + Writeable {} - -/// (C-not exported) +/// This is not exported to bindings users impl<'a, T: 'a + Score> LockableScore<'a> for Mutex { type Locked = MutexGuard<'a, T>; @@ -244,7 +243,7 @@ impl MultiThreadedLockableScore { } #[cfg(c_bindings)] -/// (C-not exported) +/// This is not exported to bindings users impl<'a, T: Writeable> Writeable for RefMut<'a, T> { fn write(&self, writer: &mut W) -> Result<(), io::Error> { T::write(&**self, writer) @@ -252,7 +251,7 @@ impl<'a, T: Writeable> Writeable for RefMut<'a, T> { } #[cfg(c_bindings)] -/// (C-not exported) +/// This is not exported to bindings users impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> { fn write(&self, writer: &mut W) -> Result<(), io::Error> { S::write(&**self, writer) @@ -363,7 +362,7 @@ pub type ProbabilisticScorer = ProbabilisticScorerUsingTime::>, L: Deref, T: Time> where L::Target: Logger { params: ProbabilisticScoringParameters, @@ -510,7 +509,7 @@ pub struct ProbabilisticScoringParameters { /// node. Note that a manual penalty of `u64::max_value()` means the node would not ever be /// considered during path finding. /// - /// (C-not exported) + /// This is not exported to bindings users pub manual_node_penalties: HashMap, /// This penalty is applied when `htlc_maximum_msat` is equal to or larger than half of the @@ -550,7 +549,7 @@ struct HistoricalBucketRangeTracker { impl HistoricalBucketRangeTracker { fn new() -> Self { Self { buckets: [0; 8] } } - fn track_datapoint(&mut self, bucket_idx: u8) { + fn track_datapoint(&mut self, liquidity_offset_msat: u64, capacity_msat: u64) { // We have 8 leaky buckets for min and max liquidity. Each bucket tracks the amount of time // we spend in each bucket as a 16-bit fixed-point number with a 5 bit fractional part. // @@ -571,6 +570,12 @@ impl HistoricalBucketRangeTracker { // // The constants were picked experimentally, selecting a decay amount that restricts us // from overflowing buckets without having to cap them manually. + + // Ensure the bucket index is in the range [0, 7], even if the liquidity offset is zero or + // the channel's capacity, though the second should generally never happen. + debug_assert!(liquidity_offset_msat <= capacity_msat); + let bucket_idx: u8 = (liquidity_offset_msat * 8 / capacity_msat.saturating_add(1)) + .try_into().unwrap_or(32); // 32 is bogus for 8 buckets, and will be ignored debug_assert!(bucket_idx < 8); if bucket_idx < 8 { for e in self.buckets.iter_mut() { @@ -1028,12 +1033,12 @@ impl, BRT: Deref, if params.historical_liquidity_penalty_multiplier_msat != 0 || params.historical_liquidity_penalty_amount_multiplier_msat != 0 { let payment_amt_64th_bucket = if amount_msat < u64::max_value() / 64 { - amount_msat * 64 / self.capacity_msat + amount_msat * 64 / self.capacity_msat.saturating_add(1) } else { // Only use 128-bit arithmetic when multiplication will overflow to avoid 128-bit // division. This branch should only be hit in fuzz testing since the amount would // need to be over 2.88 million BTC in practice. - ((amount_msat as u128) * 64 / (self.capacity_msat as u128)) + ((amount_msat as u128) * 64 / (self.capacity_msat as u128).saturating_add(1)) .try_into().unwrap_or(65) }; #[cfg(not(fuzzing))] @@ -1123,6 +1128,7 @@ impl, BRT: DerefMut, BRT: DerefMut, BRT: DerefMut, BRT: DerefMut, BRT: DerefMut, BRT: DerefMut = None; network_graph.update_channel_from_announcement( &signed_announcement, &chain_source).unwrap(); - update_channel(network_graph, short_channel_id, node_1_key, 0); - update_channel(network_graph, short_channel_id, node_2_key, 1); + update_channel(network_graph, short_channel_id, node_1_key, 0, 1_000); + update_channel(network_graph, short_channel_id, node_2_key, 1, 0); } fn update_channel( network_graph: &mut NetworkGraph<&TestLogger>, short_channel_id: u64, node_key: SecretKey, - flags: u8 + flags: u8, htlc_maximum_msat: u64 ) { let genesis_hash = genesis_block(Network::Testnet).header.block_hash(); let secp_ctx = Secp256k1::new(); @@ -1831,7 +1833,7 @@ mod tests { flags, cltv_expiry_delta: 18, htlc_minimum_msat: 0, - htlc_maximum_msat: 1_000, + htlc_maximum_msat, fee_base_msat: 1, fee_proportional_millionths: 0, excess_data: Vec::new(), @@ -2751,6 +2753,7 @@ mod tests { let logger = TestLogger::new(); let network_graph = network_graph(&logger); let params = ProbabilisticScoringParameters { + liquidity_offset_half_life: Duration::from_secs(60 * 60), historical_liquidity_penalty_multiplier_msat: 1024, historical_liquidity_penalty_amount_multiplier_msat: 1024, historical_no_updates_half_life: Duration::from_secs(10), @@ -2800,7 +2803,26 @@ mod tests { effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024, htlc_maximum_msat: 1_024 }, }; scorer.payment_path_failed(&payment_path_for_amount(1).iter().collect::>(), 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 2048); + assert_eq!(scorer.channel_penalty_msat(42, &source, &target, usage), 409); + + let usage = ChannelUsage { + amount_msat: 1, + inflight_htlc_msat: 0, + effective_capacity: EffectiveCapacity::MaximumHTLC { amount_msat: 0 }, + }; + assert_eq!(scorer.channel_penalty_msat(42, &target, &source, usage), 2048); + + // Advance to decay all liquidity offsets to zero. + SinceEpoch::advance(Duration::from_secs(60 * 60 * 10)); + + // Use a path in the opposite direction, which have zero for htlc_maximum_msat. This will + // ensure that the effective capacity is zero to test division-by-zero edge cases. + let path = vec![ + path_hop(target_pubkey(), 43, 2), + path_hop(source_pubkey(), 42, 1), + path_hop(sender_pubkey(), 41, 0), + ]; + scorer.payment_path_failed(&path.iter().collect::>(), 42); } #[test]