]> git.bitcoin.ninja Git - dnssec-prover/commitdiff
Make multiplication take array references rather than slices
authorMatt Corallo <git@bluematt.me>
Mon, 29 Jul 2024 20:24:20 +0000 (20:24 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 1 Aug 2024 03:55:38 +0000 (03:55 +0000)
This seems to reduce binary size marginally by avoiding slice
bounds checking.

src/crypto/bigint.rs

index 296681efa46096d2ef2c40cc158bfd0140f0f244..7a2ab150e7362db0d6fa4c4f1fc0a56bb5328c37 100644 (file)
@@ -248,10 +248,7 @@ define_sub!(sub_128, sub_abs_128, 128);
 ///
 /// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int
 /// types to do multiplication (potentially) natively.
-const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
-       debug_assert!(a.len() == 2);
-       debug_assert!(b.len() == 2);
-
+const fn mul_2(a: &[u64; 2], b: &[u64; 2]) -> [u64; 4] {
        // Gradeschool multiplication is way faster here.
        let (a0, a1) = (a[0] as u128, a[1] as u128);
        let (b0, b1) = (b[0] as u128, b[1] as u128);
@@ -288,10 +285,7 @@ const fn add_mul_2_parts(z2: u128, z1: u128, z0: u128, i_carry_a: bool) -> [u64;
        [i, j, k, l]
 }
 
-const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
-       debug_assert!(a.len() == 3);
-       debug_assert!(b.len() == 3);
-
+const fn mul_3(a: &[u64; 3], b: &[u64; 3]) -> [u64; 6] {
        let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
        let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
 
@@ -360,13 +354,10 @@ const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
 
 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
        /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
-       const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
+       const fn $name(a: &[u64; $len], b: &[u64; $len]) -> [u64; $len * 2] {
                // We could probably get a bit faster doing gradeschool multiplication for some smaller
                // sizes, but its easier to just have one variable-length multiplication, so we do
                // Karatsuba always here.
-               debug_assert!(a.len() == $len);
-               debug_assert!(b.len() == $len);
-
                let a0: &[u64; $len / 2] = const_subarr(a, 0);
                let a1: &[u64; $len / 2] = const_subarr(a, $len / 2);
                let b0: &[u64; $len / 2] = const_subarr(b, 0);
@@ -441,9 +432,7 @@ define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
 ///
 /// This is the base case for our squaring, taking advantage of Rust's native 128-bit int
 /// types to do multiplication (potentially) natively.
-const fn sqr_2(a: &[u64]) -> [u64; 4] {
-       debug_assert!(a.len() == 2);
-
+const fn sqr_2(a: &[u64; 2]) -> [u64; 4] {
        let (a0, a1) = (a[0] as u128, a[1] as u128);
        let z2 = a0 * a0;
        let mut z1 = a0 * a1;
@@ -456,12 +445,10 @@ const fn sqr_2(a: &[u64]) -> [u64; 4] {
 
 macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
        /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
-       const fn $name(a: &[u64]) -> [u64; $len * 2] {
+       const fn $name(a: &[u64; $len]) -> [u64; $len * 2] {
                // Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that.
-               debug_assert!(a.len() == $len);
-
-               let hi = const_subslice(a, 0, $len / 2);
-               let lo = const_subslice(a, $len / 2, $len);
+               let hi: &[u64; $len / 2] = const_subarr(a, 0);
+               let lo: &[u64; $len / 2] = const_subarr(a, $len / 2);
 
                let v0 = $subsqr(lo);
                let mut v1 = $submul(hi, lo);
@@ -501,7 +488,7 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
 } }
 
 // TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
-const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
+const fn sqr_3(a: &[u64; 3]) -> [u64; 6] { mul_3(a, a) }
 
 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
 define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
@@ -751,29 +738,29 @@ impl U4096 {
                // we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math
                // here which debug_assert's the required bits are 0s and then uses faster math primitives.
 
-               type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
-               type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
+               type mul_ty = fn(&[u64; WORD_COUNT_4096], &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2];
+               type sqr_ty = fn(&[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2];
                type add_double_ty = fn(&[u64; WORD_COUNT_4096 * 2], &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool);
                type sub_ty = fn(&[u64; WORD_COUNT_4096], &[u64; WORD_COUNT_4096]) -> ([u64; WORD_COUNT_4096], bool);
                let (word_count, log_bits, mul, sqr, add_double, sub) =
                        if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
                                if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] {
-                                       fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
-                                               debug_assert_eq!(a.len(), WORD_COUNT_4096);
-                                               debug_assert_eq!(b.len(), WORD_COUNT_4096);
+                                       fn mul_16_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
-                                               res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
-                                                       &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]));
+                                               let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4);
+                                               let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 4);
+                                               res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..]
+                                                       .copy_from_slice(&mul_16(a_arr, b_arr));
                                                res
                                        }
-                                       fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
-                                               debug_assert_eq!(a.len(), WORD_COUNT_4096);
+                                       fn sqr_16_subarr(a: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
-                                               res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
-                                                       &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
+                                               let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4);
+                                               res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..]
+                                                       .copy_from_slice(&sqr_16(a_arr));
                                                res
                                        }
                                        fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
@@ -798,22 +785,21 @@ impl U4096 {
                                        }
                                        (16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty)
                                } else {
-                                       fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
-                                               debug_assert_eq!(a.len(), WORD_COUNT_4096);
-                                               debug_assert_eq!(b.len(), WORD_COUNT_4096);
+                                       fn mul_32_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
-                                               res[WORD_COUNT_4096..].copy_from_slice(
-                                                       &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]));
+                                               let a_arr = const_subarr(a, WORD_COUNT_4096 / 2);
+                                               let b_arr = const_subarr(b, WORD_COUNT_4096 / 2);
+                                               res[WORD_COUNT_4096..].copy_from_slice(&mul_32(a_arr, b_arr));
                                                res
                                        }
-                                       fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
+                                       fn sqr_32_subarr(a: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
                                                debug_assert_eq!(a.len(), WORD_COUNT_4096);
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
+                                               let a_arr = const_subarr(a, WORD_COUNT_4096 / 2);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
-                                               res[WORD_COUNT_4096..].copy_from_slice(
-                                                       &sqr_32(&a[WORD_COUNT_4096 / 2..]));
+                                               res[WORD_COUNT_4096..].copy_from_slice(&sqr_32(a_arr));
                                                res
                                        }
                                        fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
@@ -886,7 +872,7 @@ impl U4096 {
                        v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
 
                        // t_on_r = (mu + v*modulus) / R
-                       let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
+                       let t0 = mul(const_subarr(&v, WORD_COUNT_4096), &m.0);
                        let (t1, t1_extra_bit) = add_double(&t0, &mu);
 
                        // Note that dividing t1 by R is simply a matter of shifting right by word_count bytes
@@ -1030,14 +1016,14 @@ const fn u256_mont_reduction_given_prime(mu: [u64; 8], prime: &[u64; 4], negativ
        // if t >= N { t - N } else { t }
 
        // mu % R is just the bottom 4 bytes of mu
-       let mu_mod_r = const_subslice(&mu, 4, 8);
+       let mu_mod_r: &[u64; 4] = const_subarr(&mu, 4);
        // v = ((mu % R) * negative_modulus_inverse) % R
-       let mut v = mul_4(&mu_mod_r, negative_prime_inv_mod_r);
+       let mut v = mul_4(mu_mod_r, negative_prime_inv_mod_r);
        const ZEROS: &[u64; 4] = &[0; 4];
        copy_from_slice!(v, 0, 4, ZEROS); // mod R
 
        // t_on_r = (mu + v*modulus) / R
-       let t0 = mul_4(const_subslice(&v, 4, 8), prime);
+       let t0 = mul_4(const_subarr(&v, 4), prime);
        let (t1, t1_extra_bit) = add_8(&t0, &mu);
 
        // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
@@ -1241,14 +1227,14 @@ const fn u384_mont_reduction_given_prime(mu: [u64; 12], prime: &[u64; 6], negati
        // if t >= N { t - N } else { t }
 
        // mu % R is just the bottom 4 bytes of mu
-       let mu_mod_r = const_subslice(&mu, 6, 12);
+       let mu_mod_r: &[u64; 6] = const_subarr(&mu, 6);
        // v = ((mu % R) * negative_modulus_inverse) % R
-       let mut v = mul_6(&mu_mod_r, negative_prime_inv_mod_r);
+       let mut v = mul_6(mu_mod_r, negative_prime_inv_mod_r);
        const ZEROS: &[u64; 6] = &[0; 6];
        copy_from_slice!(v, 0, 6, ZEROS); // mod R
 
        // t_on_r = (mu + v*modulus) / R
-       let t0 = mul_6(const_subslice(&v, 6, 12), prime);
+       let t0 = mul_6(const_subarr(&v, 6), prime);
        let (t1, t1_extra_bit) = add_12(&t0, &mu);
 
        // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
@@ -1457,15 +1443,15 @@ pub fn fuzz_math(input: &[u8]) {
                let a_arg = (&a_u64s[..]).try_into().unwrap();
                let b_arg = (&b_u64s[..]).try_into().unwrap();
 
-               let res = $mul(&a_u64s, &b_u64s);
+               let res = $mul(a_arg, b_arg);
                let mut res_bytes = Vec::with_capacity(input.len() / 2);
                for i in res {
                        res_bytes.extend_from_slice(&i.to_be_bytes());
                }
                assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone());
 
-               debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
-               debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
+               debug_assert_eq!($mul(a_arg, a_arg), $sqr(a_arg));
+               debug_assert_eq!($mul(b_arg, b_arg), $sqr(b_arg));
 
                let (res, carry) = $add(a_arg, b_arg);
                let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
@@ -1534,11 +1520,11 @@ pub fn fuzz_math(input: &[u8]) {
                let a_arg = (&a_u64s[..]).try_into().unwrap();
                let b_arg = (&b_u64s[..]).try_into().unwrap();
 
-               let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
+               let amodp_squared = $div_rem_double(&$mul(a_arg, a_arg), &p_extended).unwrap().1;
                assert_eq!(&amodp_squared[..$len], &[0; $len]);
                assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
 
-               let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
+               let abmodp = $div_rem_double(&$mul(a_arg, b_arg), &p_extended).unwrap().1;
                assert_eq!(&abmodp[..$len], &[0; $len]);
                assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);