git.bitcoin.ninja Git - rust-lightning/commitdiff
Try again with ints
author Matt Corallo <git@bluematt.me>
Sat, 16 Dec 2023 19:06:12 +0000 (19:06 +0000)
committer Matt Corallo <git@bluematt.me>
Wed, 17 Jul 2024 14:38:41 +0000 (14:38 +0000)
lightning/src/routing/scoring.rs
lightning/src/util/simd_f32.rs

index 8b35c2e98d326345a577bc6e503f1cb5c676ec35..d911c051efb61d7240758ed44efd85512a70dc51 100644 (file)
@@ -1911,7 +1911,7 @@ mod bucketed_history {
 
                        let mut cumulative_success_prob_times_billion = 0;
                        let mut cumulative_success_prob_float = 0.0;
-                       let mut cumulative_success_points = 0.0;
+                       let mut cumulative_success_points = 0;
                        macro_rules! calculate_probability {
                                ($success_probability: ident, $accumulate_prob: ident,
                                 $payment_pos: ident, $BUCKET_START_POS: ident, $MATH_TY: ty
@@ -1950,36 +1950,43 @@ mod bucketed_history {
                                                        .map(|idx| BUCKET_START_POS[idx] + 1).unwrap_or(0);
 
                                                if payment_pos < int_min_bucket_start_pos {
-                                                       let min_bucket_float = *min_bucket as f32;
-                                                       let min_bucket_simd = FourF32::new(
-                                                               min_bucket_float, min_bucket_float, min_bucket_float, min_bucket_float
-                                                       );
                                                        let max_max_idx = 31 - min_idx;
-                                                       for (idx, chunk) in max_liquidity_offset_history_buckets.chunks(4).enumerate() {
-                                                               let (max_idx_a, max_idx_b, max_idx_c, max_idx_d) =
-                                                                       (idx * 4, idx * 4 + 1, idx * 4 + 2, idx * 4 + 3);
+                                                       for (idx, chunk) in max_liquidity_offset_history_buckets.chunks(8).enumerate() {
+                                                               let (max_idx_a, max_idx_b, max_idx_c, max_idx_d, max_idx_e, max_idx_f, max_idx_g, max_idx_h) =
+                                                                       (idx * 8, idx * 8 + 1, idx * 8 + 2, idx * 8 + 3, idx * 8 + 4, idx * 8 + 5, idx * 8 + 6, idx * 8 + 7);
 
                                                                let max_bucket_a = chunk[0];
                                                                let mut max_bucket_b = chunk[1];
                                                                let mut max_bucket_c = chunk[2];
                                                                let mut max_bucket_d = chunk[3];
+                                                               let mut max_bucket_e = chunk[4];
+                                                               let mut max_bucket_f = chunk[5];
+                                                               let mut max_bucket_g = chunk[6];
+                                                               let mut max_bucket_h = chunk[7];
 
-                                                               let max_bucket_end_pos_a = $BUCKET_START_POS[31 - max_idx_a];
-                                                               if $payment_pos >= max_bucket_end_pos_a || max_idx_a > max_max_idx {
+                                                               let max_bucket_end_pos_a = BUCKET_START_POS[31 - max_idx_a];
+                                                               if payment_pos >= max_bucket_end_pos_a || max_idx_a > max_max_idx {
                                                                        // Success probability 0, the payment amount may be above the max liquidity
                                                                        break;
                                                                }
-                                                               let max_bucket_end_pos_b = $BUCKET_START_POS[31 - max_idx_b];
-                                                               if max_idx_b > max_max_idx || $payment_pos >= max_bucket_end_pos_b { max_bucket_b = 0; }
-                                                               let max_bucket_end_pos_c = $BUCKET_START_POS[31 - max_idx_c];
-                                                               if max_idx_c > max_max_idx || $payment_pos >= max_bucket_end_pos_c { max_bucket_c = 0; }
-                                                               let max_bucket_end_pos_d = $BUCKET_START_POS[31 - max_idx_d];
-                                                               if max_idx_d > max_max_idx || $payment_pos >= max_bucket_end_pos_d { max_bucket_d = 0; }
-
-                                                               let buckets = FourF32::from_ints(max_bucket_a, max_bucket_b, max_bucket_c, max_bucket_d);
-
-                                                               let points = min_bucket_simd * buckets;
-                                                               cumulative_success_points += points.consuming_sum();
+                                                               let max_bucket_end_pos_b = BUCKET_START_POS[31 - max_idx_b];
+                                                               if max_idx_b > max_max_idx || payment_pos >= max_bucket_end_pos_b { max_bucket_b = 0; }
+                                                               let max_bucket_end_pos_c = BUCKET_START_POS[31 - max_idx_c];
+                                                               if max_idx_c > max_max_idx || payment_pos >= max_bucket_end_pos_c { max_bucket_c = 0; }
+                                                               let max_bucket_end_pos_d = BUCKET_START_POS[31 - max_idx_d];
+                                                               if max_idx_d > max_max_idx || payment_pos >= max_bucket_end_pos_d { max_bucket_d = 0; }
+                                                               let max_bucket_end_pos_e = BUCKET_START_POS[31 - max_idx_e];
+                                                               if max_idx_e > max_max_idx || payment_pos >= max_bucket_end_pos_e { max_bucket_e = 0; }
+                                                               let max_bucket_end_pos_f = BUCKET_START_POS[31 - max_idx_f];
+                                                               if max_idx_f > max_max_idx || payment_pos >= max_bucket_end_pos_f { max_bucket_f = 0; }
+                                                               let max_bucket_end_pos_g = BUCKET_START_POS[31 - max_idx_g];
+                                                               if max_idx_g > max_max_idx || payment_pos >= max_bucket_end_pos_g { max_bucket_g = 0; }
+                                                               let max_bucket_end_pos_h = BUCKET_START_POS[31 - max_idx_h];
+                                                               if max_idx_h > max_max_idx || payment_pos >= max_bucket_end_pos_h { max_bucket_h = 0; }
+
+                                                               cumulative_success_points += crate::util::simd_f32::mul_sum_8xu16(*min_bucket,
+                                                                       max_bucket_a, max_bucket_b, max_bucket_c, max_bucket_d,
+                                                                       max_bucket_e, max_bucket_f, max_bucket_g, max_bucket_h);
                                                        }
                                                } else {
                                                        for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
@@ -2027,8 +2034,12 @@ mod bucketed_history {
                                );
                        }
 
+                       // Once we've added all 32*32/2 32-bit success points together, we may have up to 42
+                       // bits. Thus, we still have > 20 bits left, which we multiply before dividing by
+                       // total_valid_points_tracked. We finally normalize back to billions.
+                       debug_assert!(cumulative_success_points < u64::max_value() / 1024 / 1024);
                        cumulative_success_prob_times_billion +=
-                               (cumulative_success_points / total_points_tracked_float * (1024.0 * 1024.0 * 1024.0)) as u64;
+                               cumulative_success_points * 1024 * 1024 / total_valid_points_tracked * 1024;
 
                        cumulative_success_prob_times_billion +=
                                (cumulative_success_prob_float * 1024.0 * 1024.0 * 1024.0) as u64;
index fbc3e7951788887deff88b8b0dbfac660b3194b1..68db25af40194e159296406335f8fabebf7b09b6 100644 (file)
@@ -66,6 +66,8 @@ mod x86_sse {
 
        #[repr(align(16))]
        struct AlignedFloats([f32; 4]);
+       #[repr(align(32))] // 32-byte alignment so this buffer can be the target of the aligned 256-bit store in `mul_sum_8xu16`
+       struct AlignedInts([u64; 4]); // four u64 values = 256 bits, i.e. one full 256-bit vector
 
        #[derive(Clone, Copy)]
        pub(crate) struct FourF32(__m128);
@@ -130,6 +132,30 @@ mod x86_sse {
                        Self(unsafe { _mm_sub_ps(self.0, o.0) })
                }
        }
+       /// Multiplies each of the eight `u16` inputs (`a`..`h`) by `multiplicand` and returns the sum of the eight products as a `u64`.
+       #[inline(always)]
+       pub(crate) fn mul_sum_8xu16(multiplicand: u16, a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> u64 {
+               unsafe { // NOTE(review): these are 256-bit `_mm256_*` intrinsics, but the visible re-export of this module is gated only on `sse` -- confirm the build guarantees the required AVX/AVX2 support.
+                       let mul = _mm256_set1_epi32(multiplicand as i32); // the multiplicand broadcast into all eight 32-bit lanes
+                       let vals = _mm256_set_epi32(a as i32, b as i32, c as i32, d as i32, e as i32, f as i32, g as i32, h as i32); // lane order is irrelevant since every lane is summed
+                       // Keep the low 32 bits of each lane-wise product; a u16 * u16 product always fits in 32 bits.
+                       let lo = _mm256_mullo_epi32(mul, vals);
+                       // Interleaving the 32-bit products with zeros zero-extends each of them into a 64-bit lane.
+                       let zeros = _mm256_setzero_si256();
+                       let res_a = _mm256_unpacklo_epi32(lo, zeros);
+                       let res_b = _mm256_unpackhi_epi32(lo, zeros);
+                       // First reduction step: eight widened products become four 64-bit partial sums.
+                       let suma = _mm256_add_epi64(res_a, res_b);
+                       let res_a = _mm256_unpacklo_epi64(suma, zeros);
+                       let res_b = _mm256_unpackhi_epi64(suma, zeros);
+                       // Second reduction step: the totals end up in 64-bit lanes 0 and 2, lanes 1 and 3 are zero.
+                       let sumb = _mm256_add_epi64(res_a, res_b);
+                       // Spill the vector to a 32-byte-aligned buffer and add the two remaining partial sums in scalar code.
+                       let mut res_bytes = AlignedInts([0; 4]);
+                       _mm256_store_si256(&mut res_bytes.0[0] as *mut u64 as *mut __m256i, sumb);
+                       res_bytes.0[0] + res_bytes.0[2]
+               }
+       }
 }
 #[cfg(target_feature = "sse")]
 pub(crate) use x86_sse::*;