(prob) f use new simd impl wrappers
author Matt Corallo <git@bluematt.me>
Fri, 15 Dec 2023 02:20:34 +0000 (02:20 +0000)
committer Matt Corallo <git@bluematt.me>
Fri, 15 Dec 2023 02:25:25 +0000 (02:25 +0000)
lightning/src/routing/scoring.rs

index f8edd54331ad6fa5bc9228eca68f7a91f9eabd29..defffaf1ae164b113505a01b0c853d6fe4934a28 100644 (file)
@@ -62,6 +62,7 @@ use crate::routing::router::{Path, CandidateRouteHop};
 use crate::routing::log_approx;
 use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer};
 use crate::util::logger::Logger;
+use crate::util::simd_f32::FourF32;
 
 use crate::prelude::*;
 use core::{cmp, fmt};
@@ -1095,81 +1096,42 @@ fn linear_success_probability(
        (numerator, denominator)
 }
 
-#[repr(align(16))]
-struct AlignedFloats([f32; 4]);
-
 #[inline(always)]
-#[cfg(target_feature = "sse")]
-unsafe fn do_nonlinear_success_probability(
-       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
+fn nonlinear_success_probability_f(
+       amount_msat: u16, min_liquidity_msat: u16, max_liquidity_msat: u16, capacity_msat: u16,
        value_numerator: u64, value_denominator: u64, times_16_on_21: bool,
 ) -> f32 {
-       #[cfg(target_arch = "x86")]
-       use std::arch::x86::*;
-       #[cfg(target_arch = "x86_64")]
-       use std::arch::x86_64::*;
-
-       let min_max_amt_max_msat = _mm_set_ps(
-               min_liquidity_msat as f32,
-               max_liquidity_msat as f32,
-               amount_msat as f32,
-               max_liquidity_msat as f32,
+       let min_max_amt_max_msat = FourF32::from_ints(
+               min_liquidity_msat,
+               max_liquidity_msat,
+               amount_msat,
+               max_liquidity_msat,
        );
 
        let capacity = capacity_msat as f32;
-       let cap_cap_cap_cap = _mm_set_ps(capacity, capacity, capacity, capacity);
+       let cap_cap_cap_cap = FourF32::new(capacity, capacity, capacity, capacity);
 
-       let min_max_amt_max = _mm_div_ps(min_max_amt_max_msat, cap_cap_cap_cap);
+       let min_max_amt_max = min_max_amt_max_msat / cap_cap_cap_cap;
 
-       let mhalf_mhalf_mhalf_mhalf = _mm_set_ps(-0.5f32, -0.5f32, -0.5f32, -0.5f32);
-       let min_max_amt_max_offset = _mm_add_ps(min_max_amt_max, mhalf_mhalf_mhalf_mhalf);
+       let mhalf_mhalf_mhalf_mhalf = FourF32::new(-0.5f32, -0.5f32, -0.5f32, -0.5f32);
+       let min_max_amt_max_offset = min_max_amt_max + mhalf_mhalf_mhalf_mhalf;
 
-       let min_max_amt_max_sq = _mm_mul_ps(min_max_amt_max_offset, min_max_amt_max_offset);
-       let min_max_amt_max_pow = _mm_mul_ps(min_max_amt_max_sq, min_max_amt_max_offset);
+       let min_max_amt_max_sq = min_max_amt_max_offset * min_max_amt_max_offset;
+       let min_max_amt_max_pow = min_max_amt_max_sq * min_max_amt_max_offset;
 
-       let zero_zero_zero_zero = _mm_setzero_ps();
-       let zero_zero_maxmin_maxamt = _mm_hsub_ps(min_max_amt_max_pow, zero_zero_zero_zero);
+       let zero_zero_maxmin_maxamt = min_max_amt_max_pow.hsub();
 
        let mut zero_zero_den_num = zero_zero_maxmin_maxamt;
        if times_16_on_21 {
-               let zero_zero_twentyone_sixteen = _mm_set_ps(0.0f32, 0.0f32, 21.0f32, 16.0f32);
-               zero_zero_den_num = _mm_mul_ps(zero_zero_den_num, zero_zero_twentyone_sixteen);
-       }
-
-       let zero_zero_vden_vnum = _mm_set_ps(0.0f32, 0.0f32, value_denominator as f32, value_numerator as f32);
-       let zero_zero_rden_rnum = _mm_mul_ps(zero_zero_den_num, zero_zero_vden_vnum);
-
-       let mut res = AlignedFloats([0.0; 4]);
-       _mm_store_ps(&mut res.0[0], zero_zero_rden_rnum);
-       res.0[0] / res.0[1]
-}
-
-#[inline(always)]
-#[cfg(not(target_feature = "sse"))]
-unsafe fn do_nonlinear_success_probability(
-       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
-       value_numerator: u64, value_denominator: u64, times_16_on_21: bool,
-) -> f32 {
-       let (num, mut den) = rust_nonlinear_success_probability(
-               amount_msat, min_liquidity_msat, max_liquidity_msat, capacity_msat
-       );
-       let value = (value_numerator as f32) / (value_denominator as f32);
-       if times_16_on_21 {
-               den = den * 21 / 16;
+               let zero_zero_twentyone_sixteen = FourF32::new(0.0f32, 0.0f32, 21.0f32, 16.0f32);
+               zero_zero_den_num = zero_zero_den_num * zero_zero_twentyone_sixteen;
        }
-       value * num / den
-}
 
+       let zero_zero_vden_vnum = FourF32::new(0.0f32, 0.0f32, value_denominator as f32, value_numerator as f32);
+       let zero_zero_rden_rnum = zero_zero_den_num * zero_zero_vden_vnum;
 
-#[inline(always)]
-fn nonlinear_success_probability_f(
-       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
-       value_numerator: u64, value_denominator: u64, times_16_on_21: bool,
-) -> f32 {
-       unsafe { do_nonlinear_success_probability(
-               amount_msat, min_liquidity_msat, max_liquidity_msat, capacity_msat,
-               value_numerator, value_denominator, times_16_on_21,
-       ) }
+       let res = zero_zero_rden_rnum.dump();
+       res.3 / res.2
 }
 
 
@@ -1249,7 +1211,7 @@ fn success_probability(
 /// (recently) seen an HTLC successfully complete over this channel.
 #[inline(always)]
 fn success_probability_times_value_times_billion(
-       amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64, capacity_msat: u64,
+       amount_msat: u16, min_liquidity_msat: u16, max_liquidity_msat: u16, capacity_msat: u16,
        params: &ProbabilisticScoringFeeParameters, min_zero_implies_no_successes: bool,
        value_numerator: u64, value_denominator: u64,
 ) -> u64 {
@@ -1259,7 +1221,7 @@ fn success_probability_times_value_times_billion(
 
        if params.linear_success_probability {
                let (numerator, denominator) = linear_success_probability(
-                       amount_msat, min_liquidity_msat, max_liquidity_msat, min_zero_implies_no_successes
+                       amount_msat as u64, min_liquidity_msat as u64, max_liquidity_msat as u64, min_zero_implies_no_successes
                );
                const BILLIONISH: u64 = 1024 * 1024 * 1024;
                return (value_numerator * BILLIONISH / value_denominator) * numerator / denominator;
@@ -1962,8 +1924,8 @@ mod bucketed_history {
                                if payment_pos < max_bucket_end_pos {
                                        let bucket_points = (min_liquidity_offset_history_buckets[0] as u64) * total_max_points;
                                        cumulative_success_prob_times_billion += success_probability_times_value_times_billion(
-                                               payment_pos as u64, 0, max_bucket_end_pos as u64,
-                                               POSITION_TICKS as u64 - 1, params, true,
+                                               payment_pos, 0, max_bucket_end_pos,
+                                               POSITION_TICKS - 1, params, true,
                                                bucket_points, total_valid_points_tracked
                                        );
                                }
@@ -1996,8 +1958,8 @@ mod bucketed_history {
                                                // 30 bits is 62 bits.
                                                let bucket_points = ((*min_bucket as u32) * (*max_bucket as u32)) as u64;
                                                cumulative_success_prob_times_billion += success_probability_times_value_times_billion(
-                                                       payment_pos as u64, min_bucket_start_pos as u64,
-                                                       max_bucket_end_pos as u64, POSITION_TICKS as u64 - 1, params, true,
+                                                       payment_pos, min_bucket_start_pos,
+                                                       max_bucket_end_pos, POSITION_TICKS - 1, params, true,
                                                        bucket_points, total_valid_points_tracked);
                                        }
                                }