From aee1e36043996618d944c50fa132934d5d212fc0 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 29 Jul 2024 20:01:15 +0000 Subject: [PATCH] Make addition take array references rather than slices This seems to reduce binary size marginally by avoiding slice bounds checking. --- src/crypto/bigint.rs | 87 ++++++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 28 deletions(-) diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs index 781d54c..b7c9608 100644 --- a/src/crypto/bigint.rs +++ b/src/crypto/bigint.rs @@ -48,6 +48,23 @@ const fn const_subslice<'a, T>(a: &'a [T], start: usize, end: usize) -> &'a [T] unsafe { alloc::slice::from_raw_parts(startptr, len) }*/ } +/// Const version of `...: &[a; N] = &a[start..start + N].try_into().unwrap()` +const fn const_subarr<'a, const N: usize, T>(a: &'a [T], start: usize) -> &'a [T; N] { + debug_assert!(N > 0); + + let end = start + N; + assert!(start <= a.len()); + assert!(end <= a.len()); + let mut startptr = a.as_ptr(); + startptr = unsafe { startptr.add(start) }; + // transmute will fail to compile if the source and target sizes don't match, so we know the + // target type is just the length of a pointer. This leaves only a few possible encodings for + // a pointer to an array, basically just a pointer to the start or end. While its possible Rust + // could use a pointer to the end, this would be very surprising (and probably less efficient + // in most cases), so we just assume array references are always just pointers to the start. + unsafe { core::mem::transmute(startptr) } +} + /// Const version of `dest[dest_start..dest_end].copy_from_slice(source)` /// /// Once `const_mut_refs` is stable we can convert this to a function @@ -154,9 +171,7 @@ macro_rules! double { ($a: ident) => { { macro_rules! define_add { ($name: ident, $len: expr) => { /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow /// bit, with the same semantics as the std [`u64::overflowing_add`] method. - const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) { - debug_assert!(a.len() == $len); - debug_assert!(b.len() == $len); + const fn $name(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) { let mut r = [0; $len]; let mut carry = false; let mut i = $len - 1; @@ -354,18 +369,18 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident debug_assert!(a.len() == $len); debug_assert!(b.len() == $len); - let a0 = const_subslice(a, 0, $len / 2); - let a1 = const_subslice(a, $len / 2, $len); - let b0 = const_subslice(b, 0, $len / 2); - let b1 = const_subslice(b, $len / 2, $len); + let a0: &[u64; $len / 2] = const_subarr(a, 0); + let a1: &[u64; $len / 2] = const_subarr(a, $len / 2); + let b0: &[u64; $len / 2] = const_subarr(b, 0); + let b1: &[u64; $len / 2] = const_subarr(b, $len / 2); let z2 = $submul(a0, b0); let z0 = $submul(a1, b1); let (z1a_max, z1a_min, z1a_sign) = - if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) }; + if slice_greater_than(a1, a0) { (a1, a0, true) } else { (a0, a1, false) }; let (z1b_max, z1b_min, z1b_sign) = - if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) }; + if slice_greater_than(b1, b0) { (b1, b0, true) } else { (b0, b1, false) }; let z1a = $subsub(z1a_max, z1a_min); debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min"); @@ -386,9 +401,14 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident r.0 }; + let z0_start: &[u64; $len / 2] = const_subarr(&z0, 0); + let z1_start: &[u64; $len / 2] = const_subarr(&z1, 0); + let z1_end: &[u64; $len / 2] = const_subarr(&z1, $len / 2); + let z2_end: &[u64; $len / 2] = const_subarr(&z2, $len / 2); + let l = const_subslice(&z0, $len / 2, $len); - let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len)); - let (mut j, i_carry_a) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len)); + let (k, j_carry) = $subadd(z0_start, z1_end); + let (mut j, i_carry_a) = $subadd(z1_start, z2_end); let mut i_carry_b = false; if j_carry { i_carry_b = add_u64!(j, 1); @@ -450,9 +470,14 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id let i_carry_a = double!(v1); let v2 = $subsqr(hi); + let v0_start: &[u64; $len / 2] = const_subarr(&v0, 0); + let v1_start: &[u64; $len / 2] = const_subarr(&v1, 0); + let v1_end: &[u64; $len / 2] = const_subarr(&v1, $len / 2); + let v2_end: &[u64; $len / 2] = const_subarr(&v2, $len / 2); + let l = const_subslice(&v0, $len / 2, $len); - let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len)); - let (mut j, i_carry_b) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len)); + let (k, j_carry) = $subadd(v0_start, v1_end); + let (mut j, i_carry_b) = $subadd(v1_start, v2_end); let mut i = [0; $len / 2]; let i_source = const_subslice(&v2, 0, $len / 2); @@ -569,17 +594,17 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: iden debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed"); let (new_s, new_s_neg) = match (old_s_neg, s_neg) { (true, true) => { - let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len())); + let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len)); debug_assert!(!overflow); (new_s, true) } (false, true) => { - let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len())); + let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len)); debug_assert!(!overflow); (new_s, false) }, (true, false) => { - let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len())); + let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len)); debug_assert!(!overflow); (new_s, true) }, @@ -730,7 +755,7 @@ impl U4096 { type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2]; type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2]; - type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool); + type add_double_ty = fn(&[u64; WORD_COUNT_4096 * 2], &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool); type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool); let (word_count, log_bits, mul, sqr, add_double, sub) = if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] { @@ -753,12 +778,12 @@ impl U4096 { &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..])); res } - fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) { - debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2); - debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2); + fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) { debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]); debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]); - let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]); + let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 2); + let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 2); + let (add, overflow) = add_32(a_arr, b_arr); let mut res = [0; WORD_COUNT_4096 * 2]; res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add); (res, overflow) @@ -793,12 +818,12 @@ impl U4096 { &sqr_32(&a[WORD_COUNT_4096 / 2..])); res } - fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) { - debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2); - debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2); + fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) { debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]); debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]); - let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]); + let a_arr = const_subarr(a, WORD_COUNT_4096); + let b_arr = const_subarr(b, WORD_COUNT_4096); + let (add, overflow) = add_64(a_arr, b_arr); let mut res = [0; WORD_COUNT_4096 * 2]; res[WORD_COUNT_4096..].copy_from_slice(&add); (res, overflow) @@ -1430,6 +1455,9 @@ pub fn fuzz_math(input: &[u8]) { } macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => { + let a_arg = (&a_u64s[..]).try_into().unwrap(); + let b_arg = (&b_u64s[..]).try_into().unwrap(); + let res = $mul(&a_u64s, &b_u64s); let mut res_bytes = Vec::with_capacity(input.len() / 2); for i in res { @@ -1440,7 +1468,7 @@ pub fn fuzz_math(input: &[u8]) { debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s)); debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s)); - let (res, carry) = $add(&a_u64s, &b_u64s); + let (res, carry) = $add(a_arg, b_arg); let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1); if carry { res_bytes.push(1); } else { res_bytes.push(0); } for i in res { @@ -1504,6 +1532,9 @@ pub fn fuzz_math(input: &[u8]) { let mut p_extended = [0; $len * 2]; p_extended[$len..].copy_from_slice(&$PRIME); + let a_arg = (&a_u64s[..]).try_into().unwrap(); + let b_arg = (&b_u64s[..]).try_into().unwrap(); + let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1; assert_eq!(&amodp_squared[..$len], &[0; $len]); assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]); @@ -1512,7 +1543,7 @@ pub fn fuzz_math(input: &[u8]) { assert_eq!(&abmodp[..$len], &[0; $len]); assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]); - let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s); + let (aplusb, aplusb_overflow) = $add(a_arg, b_arg); let mut aplusb_extended = [0; $len * 2]; aplusb_extended[$len..].copy_from_slice(&aplusb); if aplusb_overflow { aplusb_extended[$len - 1] = 1; } @@ -1520,7 +1551,7 @@ pub fn fuzz_math(input: &[u8]) { assert_eq!(&aplusbmodp[..$len], &[0; $len]); assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]); - let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s); + let (mut aminusb, aminusb_underflow) = $sub(a_arg, b_arg); if aminusb_underflow { let mut overflow; (aminusb, overflow) = $add(&aminusb, &$PRIME); -- 2.39.5