From: Matt Corallo Date: Mon, 29 Jul 2024 20:24:20 +0000 (+0000) Subject: Make multiplication take array references rather than slices X-Git-Tag: v0.6.4~5 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=eb5497c0abe22a8acac8e2c306fcfb4d40faf406;p=dnssec-prover Make multiplication take array references rather than slices This seems to reduce binary size marginally by avoiding slice bounds checking. --- diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs index 296681e..7a2ab15 100644 --- a/src/crypto/bigint.rs +++ b/src/crypto/bigint.rs @@ -248,10 +248,7 @@ define_sub!(sub_128, sub_abs_128, 128); /// /// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int /// types to do multiplication (potentially) natively. -const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] { - debug_assert!(a.len() == 2); - debug_assert!(b.len() == 2); - +const fn mul_2(a: &[u64; 2], b: &[u64; 2]) -> [u64; 4] { // Gradeschool multiplication is way faster here. let (a0, a1) = (a[0] as u128, a[1] as u128); let (b0, b1) = (b[0] as u128, b[1] as u128); @@ -288,10 +285,7 @@ const fn add_mul_2_parts(z2: u128, z1: u128, z0: u128, i_carry_a: bool) -> [u64; [i, j, k, l] } -const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] { - debug_assert!(a.len() == 3); - debug_assert!(b.len() == 3); - +const fn mul_3(a: &[u64; 3], b: &[u64; 3]) -> [u64; 6] { let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128); let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128); @@ -360,13 +354,10 @@ const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] { macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => { /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer. - const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] { + const fn $name(a: &[u64; $len], b: &[u64; $len]) -> [u64; $len * 2] { // We could probably get a bit faster doing gradeschool multiplication for some smaller // sizes, but its easier to just have one variable-length multiplication, so we do // Karatsuba always here. - debug_assert!(a.len() == $len); - debug_assert!(b.len() == $len); - let a0: &[u64; $len / 2] = const_subarr(a, 0); let a1: &[u64; $len / 2] = const_subarr(a, $len / 2); let b0: &[u64; $len / 2] = const_subarr(b, 0); @@ -441,9 +432,7 @@ define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32); /// /// This is the base case for our squaring, taking advantage of Rust's native 128-bit int /// types to do multiplication (potentially) natively. -const fn sqr_2(a: &[u64]) -> [u64; 4] { - debug_assert!(a.len() == 2); - +const fn sqr_2(a: &[u64; 2]) -> [u64; 4] { let (a0, a1) = (a[0] as u128, a[1] as u128); let z2 = a0 * a0; let mut z1 = a0 * a1; @@ -456,12 +445,10 @@ const fn sqr_2(a: &[u64]) -> [u64; 4] { macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => { /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer. - const fn $name(a: &[u64]) -> [u64; $len * 2] { + const fn $name(a: &[u64; $len]) -> [u64; $len * 2] { // Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that. - debug_assert!(a.len() == $len); - - let hi = const_subslice(a, 0, $len / 2); - let lo = const_subslice(a, $len / 2, $len); + let hi: &[u64; $len / 2] = const_subarr(a, 0); + let lo: &[u64; $len / 2] = const_subarr(a, $len / 2); let v0 = $subsqr(lo); let mut v1 = $submul(hi, lo); @@ -501,7 +488,7 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id } } // TODO: Write an optimized sqr_3 (though secp384r1 is barely used) -const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) } +const fn sqr_3(a: &[u64; 3]) -> [u64; 6] { mul_3(a, a) } define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2); define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3); @@ -751,29 +738,29 @@ impl U4096 { // we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math // here which debug_assert's the required bits are 0s and then uses faster math primitives. - type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2]; - type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2]; + type mul_ty = fn(&[u64; WORD_COUNT_4096], &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2]; + type sqr_ty = fn(&[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2]; type add_double_ty = fn(&[u64; WORD_COUNT_4096 * 2], &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool); type sub_ty = fn(&[u64; WORD_COUNT_4096], &[u64; WORD_COUNT_4096]) -> ([u64; WORD_COUNT_4096], bool); let (word_count, log_bits, mul, sqr, add_double, sub) = if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] { if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] { - fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] { - debug_assert_eq!(a.len(), WORD_COUNT_4096); - debug_assert_eq!(b.len(), WORD_COUNT_4096); + fn mul_16_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] { debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]); debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]); let mut res = [0; WORD_COUNT_4096 * 2]; - res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice( - &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..])); + let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4); + let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 4); + res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..] + .copy_from_slice(&mul_16(a_arr, b_arr)); res } - fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] { - debug_assert_eq!(a.len(), WORD_COUNT_4096); + fn sqr_16_subarr(a: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] { debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]); let mut res = [0; WORD_COUNT_4096 * 2]; - res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice( - &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..])); + let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4); + res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..] + .copy_from_slice(&sqr_16(a_arr)); res } fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) { @@ -798,22 +785,21 @@ impl U4096 { } (16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty) } else { - fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] { - debug_assert_eq!(a.len(), WORD_COUNT_4096); - debug_assert_eq!(b.len(), WORD_COUNT_4096); + fn mul_32_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] { debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]); debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]); let mut res = [0; WORD_COUNT_4096 * 2]; - res[WORD_COUNT_4096..].copy_from_slice( - &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..])); + let a_arr = const_subarr(a, WORD_COUNT_4096 / 2); + let b_arr = const_subarr(b, WORD_COUNT_4096 / 2); + res[WORD_COUNT_4096..].copy_from_slice(&mul_32(a_arr, b_arr)); res } - fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] { + fn sqr_32_subarr(a: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] { debug_assert_eq!(a.len(), WORD_COUNT_4096); debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]); + let a_arr = const_subarr(a, WORD_COUNT_4096 / 2); let mut res = [0; WORD_COUNT_4096 * 2]; - res[WORD_COUNT_4096..].copy_from_slice( - &sqr_32(&a[WORD_COUNT_4096 / 2..])); + res[WORD_COUNT_4096..].copy_from_slice(&sqr_32(a_arr)); res } fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) { @@ -886,7 +872,7 @@ impl U4096 { v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R // t_on_r = (mu + v*modulus) / R - let t0 = mul(&v[WORD_COUNT_4096..], &m.0); + let t0 = mul(const_subarr(&v, WORD_COUNT_4096), &m.0); let (t1, t1_extra_bit) = add_double(&t0, &mu); // Note that dividing t1 by R is simply a matter of shifting right by word_count bytes @@ -1030,14 +1016,14 @@ const fn u256_mont_reduction_given_prime(mu: [u64; 8], prime: &[u64; 4], negativ // if t >= N { t - N } else { t } // mu % R is just the bottom 4 bytes of mu - let mu_mod_r = const_subslice(&mu, 4, 8); + let mu_mod_r: &[u64; 4] = const_subarr(&mu, 4); // v = ((mu % R) * negative_modulus_inverse) % R - let mut v = mul_4(&mu_mod_r, negative_prime_inv_mod_r); + let mut v = mul_4(mu_mod_r, negative_prime_inv_mod_r); const ZEROS: &[u64; 4] = &[0; 4]; copy_from_slice!(v, 0, 4, ZEROS); // mod R // t_on_r = (mu + v*modulus) / R - let t0 = mul_4(const_subslice(&v, 4, 8), prime); + let t0 = mul_4(const_subarr(&v, 4), prime); let (t1, t1_extra_bit) = add_8(&t0, &mu); // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes. @@ -1241,14 +1227,14 @@ const fn u384_mont_reduction_given_prime(mu: [u64; 12], prime: &[u64; 6], negati // if t >= N { t - N } else { t } // mu % R is just the bottom 4 bytes of mu - let mu_mod_r = const_subslice(&mu, 6, 12); + let mu_mod_r: &[u64; 6] = const_subarr(&mu, 6); // v = ((mu % R) * negative_modulus_inverse) % R - let mut v = mul_6(&mu_mod_r, negative_prime_inv_mod_r); + let mut v = mul_6(mu_mod_r, negative_prime_inv_mod_r); const ZEROS: &[u64; 6] = &[0; 6]; copy_from_slice!(v, 0, 6, ZEROS); // mod R // t_on_r = (mu + v*modulus) / R - let t0 = mul_6(const_subslice(&v, 6, 12), prime); + let t0 = mul_6(const_subarr(&v, 6), prime); let (t1, t1_extra_bit) = add_12(&t0, &mu); // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes. @@ -1457,15 +1443,15 @@ pub fn fuzz_math(input: &[u8]) { let a_arg = (&a_u64s[..]).try_into().unwrap(); let b_arg = (&b_u64s[..]).try_into().unwrap(); - let res = $mul(&a_u64s, &b_u64s); + let res = $mul(a_arg, b_arg); let mut res_bytes = Vec::with_capacity(input.len() / 2); for i in res { res_bytes.extend_from_slice(&i.to_be_bytes()); } assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone()); - debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s)); - debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s)); + debug_assert_eq!($mul(a_arg, a_arg), $sqr(a_arg)); + debug_assert_eq!($mul(b_arg, b_arg), $sqr(b_arg)); let (res, carry) = $add(a_arg, b_arg); let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1); @@ -1534,11 +1520,11 @@ pub fn fuzz_math(input: &[u8]) { let a_arg = (&a_u64s[..]).try_into().unwrap(); let b_arg = (&b_u64s[..]).try_into().unwrap(); - let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1; + let amodp_squared = $div_rem_double(&$mul(a_arg, a_arg), &p_extended).unwrap().1; assert_eq!(&amodp_squared[..$len], &[0; $len]); assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]); - let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1; + let abmodp = $div_rem_double(&$mul(a_arg, b_arg), &p_extended).unwrap().1; assert_eq!(&abmodp[..$len], &[0; $len]); assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);