git.bitcoin.ninja Git - rust-lightning/commitdiff
Avoid excess multiplies by multiplying in `success_probability`
authorMatt Corallo <git@bluematt.me>
Sat, 16 Dec 2023 01:52:44 +0000 (01:52 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 17 Jul 2024 14:39:29 +0000 (14:39 +0000)
A substantial portion (~12%!) of our scoring time is spent dividing
the bucket pair probability by the `success_probability` divisor.

Here, we avoid this by multiplying the bucket pair probability
in floating point math and using a floating point divide (which can
be faster in some cases). This also avoids the 2^30 multiplies that
are used to avoid rounding errors when converting the float
numerator and denominator to ints.

lightning/src/routing/scoring.rs

index 3460ce655e1c696b827321489897bd18c2bf8e85..7c733c66733dbf51443ffe1caee6978ac1f4350b 100644 (file)
@@ -1076,11 +1076,55 @@ fn three_f64_pow_3(a: f64, b: f64, c: f64) -> (f64, f64, f64) {
        (a * a * a, b * b * b, c * c * c)
 }
 
+#[inline(always)]
+fn linear_success_probability(
+       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64,
+       min_zero_implies_no_successes: bool,
+) -> (u64, u64) {
+       let (numerator, mut denominator) =
+               (max_liquidity_msat - amount_msat,
+                (max_liquidity_msat - min_liquidity_msat).saturating_add(1));
+
+       if min_zero_implies_no_successes && min_liquidity_msat == 0 &&
+               denominator < u64::max_value() / 21
+       {
+               // If we have no knowledge of the channel, scale probability down by ~75%
+               // Note that we prefer to increase the denominator rather than decrease the numerator as
+               // the denominator is more likely to be larger and thus provide greater precision. This is
+               // mostly an overoptimization but makes a large difference in tests.
+               denominator = denominator * 21 / 16
+       }
+
+       (numerator, denominator)
+}
+
+#[inline(always)]
+fn nonlinear_success_probability(
+       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
+) -> (f64, f64) {
+       let capacity = capacity_msat as f64;
+       let min = (min_liquidity_msat as f64) / capacity;
+       let max = (max_liquidity_msat as f64) / capacity;
+       let amount = (amount_msat as f64) / capacity;
+
+       // Assume the channel has a probability density function of (x - 0.5)^2 for values from
+       // 0 to 1 (where 1 is the channel's full capacity). The success probability given some
+       // liquidity bounds is thus the integral under the curve from the amount to maximum
+       // estimated liquidity, divided by the same integral from the minimum to the maximum
+       // estimated liquidity bounds.
+       //
+       // Because the integral from x to y is simply (y - 0.5)^3 - (x - 0.5)^3, we can
+       // calculate the cumulative density function between the min/max bounds trivially. Note
+       // that we don't bother to normalize the CDF to total to 1, as it will come out in the
+       // division of num / den.
+       let (max_pow, amt_pow, min_pow) = three_f64_pow_3(max - 0.5, amount - 0.5, min - 0.5);
+       (max_pow - amt_pow, max_pow - min_pow)
+}
+
+
 /// Given liquidity bounds, calculates the success probability (in the form of a numerator and
 /// denominator) of an HTLC. This is a key assumption in our scoring models.
 ///
-/// Must not return a numerator or denominator greater than 2^31 for arguments less than 2^31.
-///
 /// min_zero_implies_no_successes signals that a `min_liquidity_msat` of 0 means we've not
 /// (recently) seen an HTLC successfully complete over this channel.
 #[inline(always)]
@@ -1092,39 +1136,23 @@ fn success_probability(
        debug_assert!(amount_msat < max_liquidity_msat);
        debug_assert!(max_liquidity_msat <= capacity_msat);
 
-       let (numerator, mut denominator) =
-               if params.linear_success_probability {
-                       (max_liquidity_msat - amount_msat,
-                               (max_liquidity_msat - min_liquidity_msat).saturating_add(1))
-               } else {
-                       let capacity = capacity_msat as f64;
-                       let min = (min_liquidity_msat as f64) / capacity;
-                       let max = (max_liquidity_msat as f64) / capacity;
-                       let amount = (amount_msat as f64) / capacity;
-
-                       // Assume the channel has a probability density function of (x - 0.5)^2 for values from
-                       // 0 to 1 (where 1 is the channel's full capacity). The success probability given some
-                       // liquidity bounds is thus the integral under the curve from the amount to maximum
-                       // estimated liquidity, divided by the same integral from the minimum to the maximum
-                       // estimated liquidity bounds.
-                       //
-                       // Because the integral from x to y is simply (y - 0.5)^3 - (x - 0.5)^3, we can
-                       // calculate the cumulative density function between the min/max bounds trivially. Note
-                       // that we don't bother to normalize the CDF to total to 1, as it will come out in the
-                       // division of num / den.
-                       let (max_pow, amt_pow, min_pow) = three_f64_pow_3(max - 0.5, amount - 0.5, min - 0.5);
-                       let num = max_pow - amt_pow;
-                       let den = max_pow - min_pow;
-
-                       // Because our numerator and denominator max out at 0.5^3 we need to multiply them by
-                       // quite a large factor to get something useful (ideally in the 2^30 range).
-                       const BILLIONISH: f64 = 1024.0 * 1024.0 * 1024.0;
-                       let numerator = (num * BILLIONISH) as u64 + 1;
-                       let denominator = (den * BILLIONISH) as u64 + 1;
-                       debug_assert!(numerator <= 1 << 30, "Got large numerator ({}) from float {}.", numerator, num);
-                       debug_assert!(denominator <= 1 << 30, "Got large denominator ({}) from float {}.", denominator, den);
-                       (numerator, denominator)
-               };
+       if params.linear_success_probability {
+               return linear_success_probability(
+                       amount_msat, min_liquidity_msat, max_liquidity_msat, min_zero_implies_no_successes
+               );
+       }
+
+       let (num, den) = nonlinear_success_probability(
+               amount_msat, min_liquidity_msat, max_liquidity_msat, capacity_msat
+       );
+
+       // Because our numerator and denominator max out at 0.5^3 we need to multiply them by
+       // quite a large factor to get something useful (ideally in the 2^30 range).
+       const BILLIONISH: f64 = 1024.0 * 1024.0 * 1024.0;
+       let numerator = (num * BILLIONISH) as u64 + 1;
+       let mut denominator = (den * BILLIONISH) as u64 + 1;
+       debug_assert!(numerator <= 1 << 30, "Got large numerator ({}) from float {}.", numerator, num);
+       debug_assert!(denominator <= 1 << 30, "Got large denominator ({}) from float {}.", denominator, den);
 
        if min_zero_implies_no_successes && min_liquidity_msat == 0 &&
                denominator < u64::max_value() / 21
@@ -1139,6 +1167,45 @@ fn success_probability(
        (numerator, denominator)
 }
 
+/// Given liquidity bounds, calculates the success probability (times some value) of an HTLC. This
+/// is a key assumption in our scoring models.
+///
+/// min_zero_implies_no_successes signals that a `min_liquidity_msat` of 0 means we've not
+/// (recently) seen an HTLC successfully complete over this channel.
+#[inline(always)]
+fn success_probability_times_value(
+       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
+       params: &ProbabilisticScoringFeeParameters, min_zero_implies_no_successes: bool,
+       value: u32,
+) -> u64 {
+       debug_assert!(min_liquidity_msat <= amount_msat);
+       debug_assert!(amount_msat < max_liquidity_msat);
+       debug_assert!(max_liquidity_msat <= capacity_msat);
+
+       if params.linear_success_probability {
+               let (numerator, denominator) = linear_success_probability(
+                       amount_msat, min_liquidity_msat, max_liquidity_msat, min_zero_implies_no_successes
+               );
+               return (value as u64) * numerator / denominator;
+       }
+
+       let (num, mut den) = nonlinear_success_probability(
+               amount_msat, min_liquidity_msat, max_liquidity_msat, capacity_msat
+       );
+
+       if min_zero_implies_no_successes && min_liquidity_msat == 0 {
+               // If we have no knowledge of the channel, scale probability down by ~75%
+               // Note that we prefer to increase the denominator rather than decrease the numerator as
+               // the denominator is more likely to be larger and thus provide greater precision. This is
+               // mostly an overoptimization but makes a large difference in tests.
+               den = den * 21.0 / 16.0
+       }
+
+       let res = (value as f64) * num / den;
+
+       res as u64
+}
+
 impl<L: Deref<Target = u64>, HT: Deref<Target = HistoricalLiquidityTracker>, T: Deref<Target = Duration>>
 DirectedChannelLiquidity< L, HT, T> {
        /// Returns a liquidity penalty for routing the given HTLC `amount_msat` through the channel in
@@ -1825,13 +1892,14 @@ mod bucketed_history {
                                }
                                let max_bucket_end_pos = BUCKET_START_POS[32 - highest_max_bucket_with_points] - 1;
                                if payment_pos < max_bucket_end_pos {
-                                       let (numerator, denominator) = success_probability(payment_pos as u64, 0,
-                                               max_bucket_end_pos as u64, POSITION_TICKS as u64 - 1, params, true);
                                        let bucket_prob_times_billion =
                                                (min_liquidity_offset_history_buckets[0] as u64) * total_max_points
                                                        * 1024 * 1024 * 1024 / total_valid_points_tracked;
-                                       cumulative_success_prob_times_billion += bucket_prob_times_billion *
-                                               numerator / denominator;
+                                       debug_assert!(bucket_prob_times_billion < u32::max_value() as u64);
+                                       cumulative_success_prob_times_billion += success_probability_times_value(
+                                               payment_pos as u64, 0, max_bucket_end_pos as u64,
+                                               POSITION_TICKS as u64 - 1, params, true, bucket_prob_times_billion as u32
+                                       );
                                }
                        }
 
@@ -1840,32 +1908,33 @@ mod bucketed_history {
                                if payment_pos < min_bucket_start_pos {
                                        for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
                                                let max_bucket_end_pos = BUCKET_START_POS[32 - max_idx] - 1;
-                                               // Note that this multiply can only barely not overflow - two 16 bit ints plus
-                                               // 30 bits is 62 bits.
-                                               let bucket_prob_times_billion = (*min_bucket as u64) * (*max_bucket as u64)
-                                                       * 1024 * 1024 * 1024 / total_valid_points_tracked;
                                                if payment_pos >= max_bucket_end_pos {
                                                        // Success probability 0, the payment amount may be above the max liquidity
                                                        break;
                                                }
+                                               // Note that this multiply can only barely not overflow - two 16 bit ints plus
+                                               // 30 bits is 62 bits.
+                                               let bucket_prob_times_billion = ((*min_bucket as u32) * (*max_bucket as u32)) as u64
+                                                       * 1024 * 1024 * 1024 / total_valid_points_tracked;
+                                               debug_assert!(bucket_prob_times_billion < u32::max_value() as u64);
                                                cumulative_success_prob_times_billion += bucket_prob_times_billion;
                                        }
                                } else {
                                        for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
                                                let max_bucket_end_pos = BUCKET_START_POS[32 - max_idx] - 1;
-                                               // Note that this multiply can only barely not overflow - two 16 bit ints plus
-                                               // 30 bits is 62 bits.
-                                               let bucket_prob_times_billion = (*min_bucket as u64) * (*max_bucket as u64)
-                                                       * 1024 * 1024 * 1024 / total_valid_points_tracked;
                                                if payment_pos >= max_bucket_end_pos {
                                                        // Success probability 0, the payment amount may be above the max liquidity
                                                        break;
                                                }
-                                               let (numerator, denominator) = success_probability(payment_pos as u64,
-                                                       min_bucket_start_pos as u64, max_bucket_end_pos as u64,
-                                                       POSITION_TICKS as u64 - 1, params, true);
-                                               cumulative_success_prob_times_billion += bucket_prob_times_billion *
-                                                       numerator / denominator;
+                                               // Note that this multiply can only barely not overflow - two 16 bit ints plus
+                                               // 30 bits is 62 bits.
+                                               let bucket_prob_times_billion = ((*min_bucket as u32) * (*max_bucket as u32)) as u64
+                                                       * 1024 * 1024 * 1024 / total_valid_points_tracked;
+                                               debug_assert!(bucket_prob_times_billion < u32::max_value() as u64);
+                                               cumulative_success_prob_times_billion += success_probability_times_value(
+                                                       payment_pos as u64, min_bucket_start_pos as u64,
+                                                       max_bucket_end_pos as u64, POSITION_TICKS as u64 - 1, params, true,
+                                                       bucket_prob_times_billion as u32);
                                        }
                                }
                        }