SIMD 4x interleaved loop w/ floats 2023-12-scoring-simd
author Matt Corallo <git@bluematt.me>
Thu, 14 Dec 2023 06:15:24 +0000 (06:15 +0000)
committer Matt Corallo <git@bluematt.me>
Fri, 15 Dec 2023 18:11:01 +0000 (18:11 +0000)
lightning/src/routing/scoring.rs

index c5c875f4acbf0b512b26730328473200199ba735..d95b4a01d4719c2a7cce3ebf3a3b761b3f0ef9f2 100644 (file)
@@ -1932,38 +1932,48 @@ mod bucketed_history {
                                zero_min_bucket!(nonlinear_success_probability_f, cumulative_float_success_prob);
                        }
 
+                       let total_points_float = total_valid_points_tracked as f32;
+                       let total_points_simd = FourF32::new(
+                               total_points_float, total_points_float, total_points_float, total_points_float,
+                       );
+
+                       macro_rules! main_liq_loop { ($prob: ident, $accum: ident) => { {
                        for (min_idx, min_bucket) in min_liquidity_offset_history_buckets.iter().enumerate().skip(1) {
                                let min_bucket_start_pos = BUCKET_START_POS[min_idx];
                                let max_max_idx = 31 - min_idx;
+                               let min_bucket_float = *min_bucket as f32;
+                               let min_bucket_simd = FourF32::new(
+                                       min_bucket_float, min_bucket_float, min_bucket_float, min_bucket_float
+                               );
                                if payment_pos < min_bucket_start_pos {
-                                       for (idx, chunk) in max_liquidity_offset_history_buckets.chunks(2).enumerate().take(16 - min_idx/2) {
-                                               let max_idx_a = idx * 2;
-                                               let max_idx_b = idx * 2 + 1;
+                                       for (idx, chunk) in max_liquidity_offset_history_buckets.chunks(4).enumerate() {
+                                               let (max_idx_a, max_idx_b, max_idx_c, max_idx_d) =
+                                                       (idx * 4, idx * 4 + 1, idx * 4 + 2, idx * 4 + 3);
 
                                                let max_bucket_a = chunk[0];
                                                let mut max_bucket_b = chunk[1];
+                                               let mut max_bucket_c = chunk[2];
+                                               let mut max_bucket_d = chunk[3];
 
                                                let max_bucket_end_pos_a = BUCKET_START_POS[32 - max_idx_a] - 1;
-                                               if payment_pos >= max_bucket_end_pos_a {
+                                               if payment_pos >= max_bucket_end_pos_a || max_idx_a > max_max_idx {
                                                        // Success probability 0, the payment amount may be above the max liquidity
                                                        break;
                                                }
                                                let max_bucket_end_pos_b = BUCKET_START_POS[32 - max_idx_b] - 1;
-                                               if max_idx_b > max_max_idx || payment_pos >= max_bucket_end_pos_b { max_bucket_b = 0 }
+                                               if max_idx_b > max_max_idx || payment_pos >= max_bucket_end_pos_b { max_bucket_b = 0; }
+                                               let max_bucket_end_pos_c = BUCKET_START_POS[32 - max_idx_c] - 1;
+                                               if max_idx_c > max_max_idx || payment_pos >= max_bucket_end_pos_c { max_bucket_c = 0; }
+                                               let max_bucket_end_pos_d = BUCKET_START_POS[32 - max_idx_d] - 1;
+                                               if max_idx_d > max_max_idx || payment_pos >= max_bucket_end_pos_d { max_bucket_d = 0; }
 
-                                               // Note that this multiply can only barely not overflow - two 16 bit ints plus
-                                               // 30 bits is 62 bits.
-                                               let bucket_prob_times_billion_a = ((*min_bucket as u32) * (max_bucket_a as u32)) as u64
-                                                       * 1024 * 1024 * 1024 / total_valid_points_tracked;
-                                               let bucket_prob_times_billion_b = ((*min_bucket as u32) * (max_bucket_b as u32)) as u64
-                                                       * 1024 * 1024 * 1024 / total_valid_points_tracked;
-                                               debug_assert!(bucket_prob_times_billion_a < u32::max_value() as u64);
-                                               debug_assert!(bucket_prob_times_billion_b < u32::max_value() as u64);
-                                               cumulative_success_prob_times_billion += bucket_prob_times_billion_a;
-                                               cumulative_success_prob_times_billion += bucket_prob_times_billion_b;
+                                               let buckets = FourF32::from_ints(max_bucket_a, max_bucket_b, max_bucket_c, max_bucket_d);
+
+                                               let min_times_max = min_bucket_simd * buckets;
+                                               let ratio = min_times_max / total_points_simd;
+                                               cumulative_float_success_prob += ratio.consuming_sum();
                                        }
                                } else {
-                       macro_rules! main_liq_loop { ($prob: ident, $accum: ident) => { {
                                        for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
                                                let max_bucket_end_pos = BUCKET_START_POS[32 - max_idx] - 1;
                                                if payment_pos >= max_bucket_end_pos {
@@ -1978,14 +1988,14 @@ mod bucketed_history {
                                                        max_bucket_end_pos, POSITION_TICKS - 1, true,
                                                        bucket_points, total_valid_points_tracked);
                                        }
+                               }
+                       }
                        } } }
                        if params.linear_success_probability {
                                main_liq_loop!(linear_success_probability_times_value_times_billion, cumulative_success_prob_times_billion);
                        } else {
                                main_liq_loop!(nonlinear_success_probability_f, cumulative_float_success_prob);
                        }
-                               }
-                       }
 
                        const BILLIONISH: f32 = 1024.0 * 1024.0 * 1024.0;
                        cumulative_success_prob_times_billion += (cumulative_float_success_prob * BILLIONISH) as u64;