]> git.bitcoin.ninja Git - rust-lightning/commitdiff
avoid float conversion in inner loop
authorMatt Corallo <git@bluematt.me>
Sat, 16 Dec 2023 17:42:35 +0000 (17:42 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 17 Jul 2024 14:38:41 +0000 (14:38 +0000)
lightning/src/routing/scoring.rs

index c745bb42e98eab3221aa0f3ae0e4bb22ca34f634..0274aed21ebb11320ea130ace211c1f8d0629c60 100644 (file)
@@ -1125,22 +1125,21 @@ fn nonlinear_success_probability(
 
 #[inline(always)]
 fn bucket_nonlinear_success_probability(
-       amount_msat: u16, min_liquidity_msat: u16, max_liquidity_msat: u16, capacity_msat: u16,
+       amount_msat: f32, min_liquidity_msat: f32, max_liquidity_msat: f32, capacity_msat: f32,
        min_zero_implies_no_successes: bool,
 ) -> FourF32 {
-       let min_max_amt_max_msat = FourF32::from_ints(
+       let min_max_amt_max_msat = FourF32::new(
                min_liquidity_msat,
                max_liquidity_msat,
                amount_msat,
                max_liquidity_msat,
        );
 
-       let capacity = capacity_msat as f32;
-       let cap_cap_cap_cap = FourF32::new(capacity, capacity, capacity, capacity);
+       let cap_cap_cap_cap = FourF32::new(capacity_msat, capacity_msat, capacity_msat, capacity_msat);
 
        nonlinear_success_probability_finish(
                min_max_amt_max_msat, cap_cap_cap_cap,
-               min_zero_implies_no_successes && min_liquidity_msat == 0,
+               min_zero_implies_no_successes && min_liquidity_msat == 0.0,
        )
 }
 
@@ -1581,7 +1580,7 @@ mod bucketed_history {
 
        // By default u16s may not be cache-aligned, but we'd rather not have to read a third cache
        // line just to access it
-       #[repr(align(128))]
+       #[repr(align(64))]
        struct BucketStartPos([u16; 32]);
        impl BucketStartPos {
                const fn new() -> Self {
@@ -1598,6 +1597,26 @@ mod bucketed_history {
        }
        const BUCKET_START_POS: BucketStartPos = BucketStartPos::new();
 
+       // By default f32s may not be cache-aligned, but we'd rather not have to read a third cache
+       // line just to access it
+       #[repr(align(128))]
+       struct FloatBucketStartPos([f32; 32]);
+       impl FloatBucketStartPos {
+               const fn new() -> Self {
+                       Self([
+                               0.0, 1.0, 3.0, 7.0, 15.0, 31.0, 63.0, 127.0, 255.0, 511.0, 1023.0, 2047.0, 3071.0,
+                               4095.0, 6143.0, 8191.0, 10239.0, 12287.0, 13311.0, 14335.0, 15359.0, 15871.0, 16127.0,
+                               16255.0, 16319.0, 16351.0, 16367.0, 16375.0, 16379.0, 16381.0, 16382.0, 16383.0,
+                       ])
+               }
+       }
+       impl core::ops::Index<usize> for FloatBucketStartPos {
+               type Output = f32;
+               #[inline(always)]
+               fn index(&self, index: usize) -> &f32 { &self.0[index] }
+       }
+       const FLOAT_BUCKET_START_POS: FloatBucketStartPos = FloatBucketStartPos::new();
+
        const LEGACY_TO_BUCKET_RANGE: [(u8, u8); 8] = [
                (0, 12), (12, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 20), (20, 32)
        ];
@@ -1626,6 +1645,7 @@ mod bucketed_history {
                for (bucket, width) in BUCKET_WIDTH_IN_16384S.iter().enumerate() {
                        if bucket != 0 {
                                assert_eq!(BUCKET_START_POS[bucket - 1] + 1, min_size_iter);
+                               assert_eq!(FLOAT_BUCKET_START_POS[bucket - 1] + 1.0, min_size_iter as f32);
                        }
                        for i in 0..*width {
                                assert_eq!(pos_to_bucket(min_size_iter + i) as usize, bucket);
@@ -1640,6 +1660,7 @@ mod bucketed_history {
                        }
                }
                assert_eq!(BUCKET_START_POS[31], POSITION_TICKS - 1);
+               assert_eq!(FLOAT_BUCKET_START_POS[31], POSITION_TICKS as f32 - 1.0);
                assert_eq!(min_size_iter, POSITION_TICKS);
        }
 
@@ -1886,12 +1907,15 @@ mod bucketed_history {
                                return None;
                        }
                        let total_points_tracked_float = total_valid_points_tracked as f32;
+                       let payment_pos_float = payment_pos as f32;
 
                        let mut cumulative_success_prob_times_billion = 0;
                        let mut cumulative_success_prob_float = 0.0;
                        let mut cumulative_success_points = 0;
                        macro_rules! calculate_probability {
-                               ($success_probability: ident, $accumulate_prob: ident) => { {
+                               ($success_probability: ident, $accumulate_prob: ident,
+                                $payment_pos: ident, $BUCKET_START_POS: ident, $MATH_TY: ty
+                               ) => { {
                                        // Special-case the 0th min bucket - it generally means we failed a payment, so only
                                        // consider the highest (i.e. largest-offset-from-max-capacity) max bucket for all
                                        // points against the 0th min bucket. This avoids the case where we fail to route
@@ -1908,10 +1932,10 @@ mod bucketed_history {
                                                        }
                                                        total_max_points += *max_bucket as u64;
                                                }
-                                               let max_bucket_end_pos = BUCKET_START_POS[31 - highest_max_bucket_with_points];
-                                               if payment_pos < max_bucket_end_pos {
-                                                       let success_probability = $success_probability(payment_pos, 0,
-                                                               max_bucket_end_pos , POSITION_TICKS - 1, true);
+                                               let max_bucket_end_pos = $BUCKET_START_POS[31 - highest_max_bucket_with_points];
+                                               if $payment_pos < max_bucket_end_pos {
+                                                       let success_probability = $success_probability($payment_pos, 0 as $MATH_TY,
+                                                               max_bucket_end_pos, (POSITION_TICKS - 1) as $MATH_TY, true);
                                                        let bucket_points =
                                                                (min_liquidity_offset_history_buckets[0] as u64) * total_max_points;
                                                        $accumulate_prob(success_probability, bucket_points);
@@ -1919,8 +1943,13 @@ mod bucketed_history {
                                        }
 
                                        for (min_idx, min_bucket) in min_liquidity_offset_history_buckets.iter().enumerate().skip(1) {
-                                               let min_bucket_start_pos = min_idx.checked_sub(1).map(|idx| BUCKET_START_POS[idx] + 1).unwrap_or(0);
-                                               if payment_pos < min_bucket_start_pos {
+                                               const ONE: $MATH_TY = 1 as $MATH_TY;
+                                               let min_bucket_start_pos = min_idx.checked_sub(1)
+                                                       .map(|idx| $BUCKET_START_POS[idx] + ONE).unwrap_or(0 as $MATH_TY);
+                                               let int_min_bucket_start_pos = min_idx.checked_sub(1)
+                                                       .map(|idx| BUCKET_START_POS[idx] + 1).unwrap_or(0);
+
+                                               if payment_pos < int_min_bucket_start_pos {
                                                        for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
                                                                let max_bucket_end_pos = BUCKET_START_POS[31 - max_idx];
                                                                if payment_pos >= max_bucket_end_pos {
@@ -1931,14 +1960,14 @@ mod bucketed_history {
                                                        }
                                                } else {
                                                        for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
-                                                               let max_bucket_end_pos = BUCKET_START_POS[31 - max_idx];
-                                                               if payment_pos >= max_bucket_end_pos {
+                                                               let max_bucket_end_pos = $BUCKET_START_POS[31 - max_idx];
+                                                               if $payment_pos >= max_bucket_end_pos {
                                                                        // Success probability 0, the payment amount may be above the max liquidity
                                                                        break;
                                                                }
-                                                               let success_probability = $success_probability(payment_pos,
+                                                               let success_probability = $success_probability($payment_pos,
                                                                        min_bucket_start_pos, max_bucket_end_pos,
-                                                                       POSITION_TICKS - 1, true);
+                                                                       (POSITION_TICKS - 1) as $MATH_TY, true);
                                                                // Note that this multiply can only barely not overflow - two 16 bit ints plus
                                                                // 30 bits is 62 bits.
                                                                let bucket_points = (*min_bucket as u32) * (*max_bucket as u32);
@@ -1959,7 +1988,9 @@ mod bucketed_history {
                                                * numerator / denominator;
 0.0
                                };
-                               calculate_probability!(success_prob_u64s, int_success_prob);
+                               calculate_probability!(
+                                       success_prob_u64s, int_success_prob, payment_pos, BUCKET_START_POS, u16
+                               );
                        } else {
                                let mut float_success_prob = |zero_zero_den_num: FourF32, bucket_points: u64| {
                                        let zero_zero_total_points = FourF32::new(0.0f32, 0.0f32, total_points_tracked_float, bucket_points as f32);
@@ -1967,7 +1998,10 @@ mod bucketed_history {
                                        let res = zero_zero_rden_rnum.dump();
                                        cumulative_success_prob_float += res.3 / res.2;
                                };
-                               calculate_probability!(bucket_nonlinear_success_probability, float_success_prob);
+                               calculate_probability!(
+                                       bucket_nonlinear_success_probability, float_success_prob,
+                                       payment_pos_float, FLOAT_BUCKET_START_POS, f32
+                               );
                        }
 
                        // Once we've added all 32*32/2 32-bit success points together, we may have up to 42