]> git.bitcoin.ninja Git - rust-lightning/commitdiff
vectorize
authorMatt Corallo <git@bluematt.me>
Sat, 16 Dec 2023 04:39:36 +0000 (04:39 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 17 Jul 2024 14:38:40 +0000 (14:38 +0000)
lightning/src/routing/scoring.rs

index 9f0d903ba249b9c2de1f083dc729f4d6aa5d1767..9a9e37b5338b2235e72b19da6807ae0fbea8ee59 100644 (file)
@@ -74,6 +74,7 @@ use {
        core::cell::{RefCell, RefMut, Ref},
        crate::sync::{Mutex, MutexGuard},
 };
+use crate::util::simd_f32::*;
 
 /// We define Score ever-so-slightly differently based on whether we are being built for C bindings
 /// or not. For users, `LockableScore` must somehow be writeable to disk. For Rust users, this is
@@ -1102,37 +1103,81 @@ fn linear_success_probability(
 fn nonlinear_success_probability(
        amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
        min_zero_implies_no_successes: bool,
-) -> (f64, f64) {
-       debug_assert!(min_liquidity_msat <= amount_msat);
-       debug_assert!(amount_msat < max_liquidity_msat);
-       debug_assert!(max_liquidity_msat <= capacity_msat);
+) -> (f32, f32) {
+       let min_max_amt_max_msat = FourF32::new(
+               min_liquidity_msat as f32,
+               max_liquidity_msat as f32,
+               amount_msat as f32,
+               max_liquidity_msat as f32,
+       );
+
+       let capacity = capacity_msat as f32;
+       let cap_cap_cap_cap = FourF32::new(capacity, capacity, capacity, capacity);
+
+       let zero_zero_den_num = nonlinear_success_probability_finish(
+               min_max_amt_max_msat, cap_cap_cap_cap,
+               min_zero_implies_no_successes && min_liquidity_msat == 0,
+       );
+
+       let res = zero_zero_den_num.dump();
+       (res.3, res.2)
+}
 
-       let capacity = capacity_msat as f64;
-       let min = (min_liquidity_msat as f64) / capacity;
-       let max = (max_liquidity_msat as f64) / capacity;
-       let amount = (amount_msat as f64) / capacity;
+#[inline(always)]
+fn bucket_nonlinear_success_probability(
+       amount_msat: u16, min_liquidity_msat: u16, max_liquidity_msat: u16, capacity_msat: u16,
+       min_zero_implies_no_successes: bool,
+) -> FourF32 {
+       let min_max_amt_max_msat = FourF32::from_ints(
+               min_liquidity_msat,
+               max_liquidity_msat,
+               amount_msat,
+               max_liquidity_msat,
+       );
+
+       let capacity = capacity_msat as f32;
+       let cap_cap_cap_cap = FourF32::new(capacity, capacity, capacity, capacity);
+
+       nonlinear_success_probability_finish(
+               min_max_amt_max_msat, cap_cap_cap_cap,
+               min_zero_implies_no_successes && min_liquidity_msat == 0,
+       )
+}
 
-       // Assume the channel has a probability density function of (x - 0.5)^2 for values from
-       // 0 to 1 (where 1 is the channel's full capacity). The success probability given some
-       // liquidity bounds is thus the integral under the curve from the amount to maximum
-       // estimated liquidity, divided by the same integral from the minimum to the maximum
-       // estimated liquidity bounds.
-       //
-       // Because the integral from x to y is simply (y - 0.5)^3 - (x - 0.5)^3, we can
-       // calculate the cumulative density function between the min/max bounds trivially. Note
-       // that we don't bother to normalize the CDF to total to 1, as it will come out in the
-       // division of num / den.
-       let (max_pow, amt_pow, min_pow) = three_f64_pow_3(max - 0.5, amount - 0.5, min - 0.5);
-       let mut num = max_pow - amt_pow;
-       let mut den = max_pow - min_pow;
-
-       if min_zero_implies_no_successes && min_liquidity_msat == 0 {
-               // If we have no knowledge of the channel, scale probability down by ~75%
-               den *= 21.0;
-               num *= 16.0;
+#[inline(always)]
+fn nonlinear_success_probability_finish(
+       min_max_amt_max_msat: FourF32, cap_cap_cap_cap: FourF32, times_16_on_21: bool,
+) -> FourF32 {
+       #[cfg(debug_assertions)] {
+               let args = min_max_amt_max_msat.dump();
+               debug_assert_eq!(args.1, args.3);
+
+               let cap = cap_cap_cap_cap.dump();
+               debug_assert_eq!(cap.0, cap.1);
+               debug_assert_eq!(cap.0, cap.2);
+               debug_assert_eq!(cap.0, cap.3);
+
+               debug_assert!(args.0 <= args.2);
+               debug_assert!(args.2 < args.3);
+               debug_assert!(args.3 <= cap.0);
        }
 
-       (num, den)
+       let min_max_amt_max = min_max_amt_max_msat / cap_cap_cap_cap;
+
+       let mhalf_mhalf_mhalf_mhalf = FourF32::new(-0.5f32, -0.5f32, -0.5f32, -0.5f32);
+       let min_max_amt_max_offset = min_max_amt_max + mhalf_mhalf_mhalf_mhalf;
+
+       let min_max_amt_max_sq = min_max_amt_max_offset * min_max_amt_max_offset;
+       let min_max_amt_max_pow = min_max_amt_max_sq * min_max_amt_max_offset;
+
+       let mut zero_zero_den_num = min_max_amt_max_pow.hsub();
+
+       if times_16_on_21 {
+               let zero_zero_twentyone_sixteen = FourF32::new(0.0f32, 0.0f32, 21.0f32, 16.0f32);
+               zero_zero_den_num = zero_zero_den_num * zero_zero_twentyone_sixteen;
+       }
+
+       zero_zero_den_num
 }
 
 
@@ -1162,7 +1207,7 @@ fn success_probability(
                );
                // Because our numerator and denominator max out at 0.5^3 we need to multiply them by
                // quite a large factor to get something useful (ideally in the 2^30 range).
-               const MILLIONISH: f64 = 1024.0 * 1024.0 * 32.0;
+               const MILLIONISH: f32 = 1024.0 * 1024.0 * 32.0;
                let numerator = (num * MILLIONISH) as u64 + 1;
                let mut denominator = (den * MILLIONISH) as u64 + 1;
                debug_assert!(numerator <= 1 << 30, "Got large numerator ({}) from float {}.", numerator, num);
@@ -1838,7 +1883,7 @@ mod bucketed_history {
                        if total_valid_points_tracked < FULLY_DECAYED.into() {
                                return None;
                        }
-                       let total_points_tracked_float = total_valid_points_tracked as f64;
+                       let total_points_tracked_float = total_valid_points_tracked as f32;
 
                        let mut cumulative_success_prob_times_billion = 0;
                        let mut cumulative_success_prob_float = 0.0;
@@ -1863,11 +1908,11 @@ mod bucketed_history {
                                                }
                                                let max_bucket_end_pos = BUCKET_START_POS[32 - highest_max_bucket_with_points] - 1;
                                                if payment_pos < max_bucket_end_pos {
-                                                       let (numerator, denominator) = $success_probability(payment_pos as u64, 0,
-                                                               max_bucket_end_pos as u64, POSITION_TICKS as u64 - 1, true);
+                                                       let success_probability = $success_probability(payment_pos, 0,
+                                                               max_bucket_end_pos , POSITION_TICKS - 1, true);
                                                        let bucket_points =
                                                                (min_liquidity_offset_history_buckets[0] as u64) * total_max_points;
-                                                       $accumulate_prob(numerator, denominator, bucket_points);
+                                                       $accumulate_prob(success_probability, bucket_points);
                                                }
                                        }
 
@@ -1889,13 +1934,13 @@ mod bucketed_history {
                                                                        // Success probability 0, the payment amount may be above the max liquidity
                                                                        break;
                                                                }
-                                                               let (numerator, denominator) = $success_probability(payment_pos as u64,
-                                                                       min_bucket_start_pos as u64, max_bucket_end_pos as u64,
-                                                                       POSITION_TICKS as u64 - 1, true);
+                                                               let success_probability = $success_probability(payment_pos,
+                                                                       min_bucket_start_pos, max_bucket_end_pos,
+                                                                       POSITION_TICKS - 1, true);
                                                                // Note that this multiply can only barely not overflow - two 16 bit ints plus
                                                                // 30 bits is 62 bits.
                                                                let bucket_points = (*min_bucket as u32) * (*max_bucket as u32);
-                                                               $accumulate_prob(numerator, denominator, bucket_points as u64);
+                                                               $accumulate_prob(success_probability, bucket_points as u64);
                                                        }
                                                }
                                        }
@@ -1903,18 +1948,24 @@ mod bucketed_history {
                        }
 
                        if params.linear_success_probability {
-                               let mut int_success_prob = |numerator: u64, denominator: u64, bucket_points: u64| {
+                               let success_prob_u64s = |a, b, c, d, e: bool| {
+                                       linear_success_probability(a as u64, b as u64, c as u64, d as u64, e)
+                               };
+                               let mut int_success_prob = |(numerator, denominator), bucket_points: u64| {
                                        cumulative_success_prob_times_billion += bucket_points
                                                * 1024 * 1024 * 1024 / total_valid_points_tracked
                                                * numerator / denominator;
+0.0
                                };
-                               calculate_probability!(linear_success_probability, int_success_prob);
+                               calculate_probability!(success_prob_u64s, int_success_prob);
                        } else {
-                               let mut float_success_prob = |numerator: f64, denominator: f64, bucket_points: u64| {
-                                       cumulative_success_prob_float += (bucket_points as f64)
-                                               / total_points_tracked_float * numerator / denominator;
+                               let mut float_success_prob = |zero_zero_den_num: FourF32, bucket_points: u64| {
+                                       let zero_zero_total_points = FourF32::new(0.0f32, 0.0f32, total_points_tracked_float, bucket_points as f32);
+                                       let zero_zero_rden_rnum = zero_zero_den_num * zero_zero_total_points;
+                                       let res = zero_zero_rden_rnum.dump();
+                                       cumulative_success_prob_float += res.3 / res.2;
                                };
-                               calculate_probability!(nonlinear_success_probability, float_success_prob);
+                               calculate_probability!(bucket_nonlinear_success_probability, float_success_prob);
                        }
 
                        // Once we've added all 32*32/2 32-bit success points together, we may have up to 42