From b2c22151ee3fd6e1894a57fad1349217545fc9ed Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sat, 16 Dec 2023 19:06:12 +0000 Subject: [PATCH] Try again with ints --- lightning/src/routing/scoring.rs | 55 +++++++++++++++++++------------- lightning/src/util/simd_f32.rs | 26 +++++++++++++++ 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index 8b35c2e98..d911c051e 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -1911,7 +1911,7 @@ mod bucketed_history { let mut cumulative_success_prob_times_billion = 0; let mut cumulative_success_prob_float = 0.0; - let mut cumulative_success_points = 0.0; + let mut cumulative_success_points = 0; macro_rules! calculate_probability { ($success_probability: ident, $accumulate_prob: ident, $payment_pos: ident, $BUCKET_START_POS: ident, $MATH_TY: ty @@ -1950,36 +1950,43 @@ mod bucketed_history { .map(|idx| BUCKET_START_POS[idx] + 1).unwrap_or(0); if payment_pos < int_min_bucket_start_pos { - let min_bucket_float = *min_bucket as f32; - let min_bucket_simd = FourF32::new( - min_bucket_float, min_bucket_float, min_bucket_float, min_bucket_float - ); let max_max_idx = 31 - min_idx; - for (idx, chunk) in max_liquidity_offset_history_buckets.chunks(4).enumerate() { - let (max_idx_a, max_idx_b, max_idx_c, max_idx_d) = - (idx * 4, idx * 4 + 1, idx * 4 + 2, idx * 4 + 3); + for (idx, chunk) in max_liquidity_offset_history_buckets.chunks(8).enumerate() { + let (max_idx_a, max_idx_b, max_idx_c, max_idx_d, max_idx_e, max_idx_f, max_idx_g, max_idx_h) = + (idx * 8, idx * 8 + 1, idx * 8 + 2, idx * 8 + 3, idx * 8 + 4, idx * 8 + 5, idx * 8 + 6, idx * 8 + 7); let max_bucket_a = chunk[0]; let mut max_bucket_b = chunk[1]; let mut max_bucket_c = chunk[2]; let mut max_bucket_d = chunk[3]; + let mut max_bucket_e = chunk[4]; + let mut max_bucket_f = chunk[5]; + let mut max_bucket_g = chunk[6]; + let mut max_bucket_h = chunk[7]; - let max_bucket_end_pos_a = $BUCKET_START_POS[31 - max_idx_a]; - if $payment_pos >= max_bucket_end_pos_a || max_idx_a > max_max_idx { + let max_bucket_end_pos_a = BUCKET_START_POS[31 - max_idx_a]; + if payment_pos >= max_bucket_end_pos_a || max_idx_a > max_max_idx { // Success probability 0, the payment amount may be above the max liquidity break; } - let max_bucket_end_pos_b = $BUCKET_START_POS[31 - max_idx_b]; - if max_idx_b > max_max_idx || $payment_pos >= max_bucket_end_pos_b { max_bucket_b = 0; } - let max_bucket_end_pos_c = $BUCKET_START_POS[31 - max_idx_c]; - if max_idx_c > max_max_idx || $payment_pos >= max_bucket_end_pos_c { max_bucket_c = 0; } - let max_bucket_end_pos_d = $BUCKET_START_POS[31 - max_idx_d]; - if max_idx_d > max_max_idx || $payment_pos >= max_bucket_end_pos_d { max_bucket_d = 0; } - - let buckets = FourF32::from_ints(max_bucket_a, max_bucket_b, max_bucket_c, max_bucket_d); - - let points = min_bucket_simd * buckets; - cumulative_success_points += points.consuming_sum(); + let max_bucket_end_pos_b = BUCKET_START_POS[31 - max_idx_b]; + if max_idx_b > max_max_idx || payment_pos >= max_bucket_end_pos_b { max_bucket_b = 0; } + let max_bucket_end_pos_c = BUCKET_START_POS[31 - max_idx_c]; + if max_idx_c > max_max_idx || payment_pos >= max_bucket_end_pos_c { max_bucket_c = 0; } + let max_bucket_end_pos_d = BUCKET_START_POS[31 - max_idx_d]; + if max_idx_d > max_max_idx || payment_pos >= max_bucket_end_pos_d { max_bucket_d = 0; } + let max_bucket_end_pos_e = BUCKET_START_POS[31 - max_idx_e]; + if max_idx_e > max_max_idx || payment_pos >= max_bucket_end_pos_e { max_bucket_e = 0; } + let max_bucket_end_pos_f = BUCKET_START_POS[31 - max_idx_f]; + if max_idx_f > max_max_idx || payment_pos >= max_bucket_end_pos_f { max_bucket_f = 0; } + let max_bucket_end_pos_g = BUCKET_START_POS[31 - max_idx_g]; + if max_idx_g > max_max_idx || payment_pos >= max_bucket_end_pos_g { max_bucket_g = 0; } + let max_bucket_end_pos_h = BUCKET_START_POS[31 - max_idx_h]; + if max_idx_h > max_max_idx || payment_pos >= max_bucket_end_pos_h { max_bucket_h = 0; } + + cumulative_success_points += crate::util::simd_f32::mul_sum_8xu16(*min_bucket, + max_bucket_a, max_bucket_b, max_bucket_c, max_bucket_d, + max_bucket_e, max_bucket_f, max_bucket_g, max_bucket_h); } } else { for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) { @@ -2027,8 +2034,12 @@ mod bucketed_history { ); } + // Once we've added all 32*32/2 32-bit success points together, we may have up to 42 + // bits. Thus, we still have > 20 bits left, which we multiply before dividing by + // total_valid_points_tracked. We finally normalize back to billions. + debug_assert!(cumulative_success_points < u64::max_value() / 1024 / 1024); cumulative_success_prob_times_billion += - (cumulative_success_points / total_points_tracked_float * (1024.0 * 1024.0 * 1024.0)) as u64; + cumulative_success_points * 1024 * 1024 / total_valid_points_tracked * 1024; cumulative_success_prob_times_billion += (cumulative_success_prob_float * 1024.0 * 1024.0 * 1024.0) as u64; diff --git a/lightning/src/util/simd_f32.rs b/lightning/src/util/simd_f32.rs index fbc3e7951..68db25af4 100644 --- a/lightning/src/util/simd_f32.rs +++ b/lightning/src/util/simd_f32.rs @@ -66,6 +66,8 @@ mod x86_sse { #[repr(align(16))] struct AlignedFloats([f32; 4]); + #[repr(align(32))] + struct AlignedInts([u64; 4]); #[derive(Clone, Copy)] pub(crate) struct FourF32(__m128); @@ -130,6 +132,30 @@ mod x86_sse { Self(unsafe { _mm_sub_ps(self.0, o.0) }) } } + + #[inline(always)] + pub(crate) fn mul_sum_8xu16(multiplicand: u16, a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> u64 { + unsafe { + let mul = _mm256_set1_epi32(multiplicand as i32); + let vals = _mm256_set_epi32(a as i32, b as i32, c as i32, d as i32, e as i32, f as i32, g as i32, h as i32); + + let lo = _mm256_mullo_epi32(mul, vals); + + let zeros = _mm256_setzero_si256(); + let res_a = _mm256_unpacklo_epi32(lo, zeros); + let res_b = _mm256_unpackhi_epi32(lo, zeros); + + let suma = _mm256_add_epi64(res_a, res_b); + let res_a = _mm256_unpacklo_epi64(suma, zeros); + let res_b = _mm256_unpackhi_epi64(suma, zeros); + + let sumb = _mm256_add_epi64(res_a, res_b); + + let mut res_bytes = AlignedInts([0; 4]); + _mm256_store_si256(&mut res_bytes.0[0] as *mut u64 as *mut __m256i, sumb); + res_bytes.0[0] + res_bytes.0[2] + } + } } #[cfg(target_feature = "sse")] pub(crate) use x86_sse::*; -- 2.39.5