Fix overflow in ProbabilisticScorer
[rust-lightning] / lightning / src / routing / scoring.rs
index 22d7012c35c08320289b13909434a43c34254dfa..d9d0bee3ca59c23ddf61898d0d5599962fac9f06 100644 (file)
@@ -673,8 +673,8 @@ impl<L: Deref<Target = u64>, T: Time, U: Deref<Target = T>> DirectedChannelLiqui
                } else if amount_msat <= min_liquidity_msat {
                        0
                } else {
-                       let numerator = max_liquidity_msat + 1 - amount_msat;
-                       let denominator = max_liquidity_msat + 1 - min_liquidity_msat;
+                       let numerator = (max_liquidity_msat - amount_msat).saturating_add(1);
+                       let denominator = (max_liquidity_msat - min_liquidity_msat).saturating_add(1);
                        approx::negative_log10_times_1024(numerator, denominator)
                                .saturating_mul(liquidity_penalty_multiplier_msat) / 1024
                }
@@ -2061,4 +2061,20 @@ mod tests {
                let scorer = ProbabilisticScorer::new(params, &network_graph);
                assert_eq!(scorer.channel_penalty_msat(42, 128, 1_024, &source, &target), 1085);
        }
+
+       #[test]
+       fn calculates_log10_without_overflowing_u64_max_value() {
+               let network_graph = network_graph();
+               let source = source_node_id();
+               let target = target_node_id();
+
+               let params = ProbabilisticScoringParameters {
+                       base_penalty_msat: 0, ..Default::default()
+               };
+               let scorer = ProbabilisticScorer::new(params, &network_graph);
+               assert_eq!(
+                       scorer.channel_penalty_msat(42, u64::max_value(), u64::max_value(), &source, &target),
+                       20_000,
+               );
+       }
 }