From 7322a874b826a755bd5de306848996cfd2984544 Mon Sep 17 00:00:00 2001
From: Matt Corallo
Date: Mon, 29 Jul 2024 21:36:28 +0000
Subject: [PATCH] Use a single const generic `add` method rather than
 macro-izing

---
 src/crypto/bigint.rs | 151 +++++++++++++++++++------------------
 1 file changed, 68 insertions(+), 83 deletions(-)

diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs
index 7a2ab15..5ac1313 100644
--- a/src/crypto/bigint.rs
+++ b/src/crypto/bigint.rs
@@ -168,37 +168,22 @@ macro_rules! double { ($a: ident) => { {
 	carry
 } } }
 
-macro_rules! define_add { ($name: ident, $len: expr) => {
-	/// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
-	/// bit, with the same semantics as the std [`u64::overflowing_add`] method.
-	const fn $name(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) {
-		let mut r = [0; $len];
-		let mut carry = false;
-		let mut i = $len - 1;
-		loop {
-			let (v, mut new_carry) = a[i].overflowing_add(b[i]);
-			let (v2, new_new_carry) = v.overflowing_add(carry as u64);
-			new_carry |= new_new_carry;
-			r[i] = v2;
-			carry = new_carry;
+const fn add<const N: usize>(a: &[u64; N], b: &[u64; N]) -> ([u64; N], bool) {
+	let mut r = [0; N];
+	let mut carry = false;
+	let mut i = N - 1;
+	loop {
+		let (v, mut new_carry) = a[i].overflowing_add(b[i]);
+		let (v2, new_new_carry) = v.overflowing_add(carry as u64);
+		new_carry |= new_new_carry;
+		r[i] = v2;
+		carry = new_carry;
 
-			if i == 0 { break; }
-			i -= 1;
-		}
-		(r, carry)
+		if i == 0 { break; }
+		i -= 1;
 	}
-} }
-
-define_add!(add_2, 2);
-define_add!(add_3, 3);
-define_add!(add_4, 4);
-define_add!(add_6, 6);
-define_add!(add_8, 8);
-define_add!(add_12, 12);
-define_add!(add_16, 16);
-define_add!(add_32, 32);
-define_add!(add_64, 64);
-define_add!(add_128, 128);
+	(r, carry)
+}
 
 macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
 	/// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
@@ -352,7 +337,7 @@ const fn mul_3(a: &[u64; 3], b: &[u64; 3]) -> [u64; 6] {
 	[r0, r1, r2, r3, r4, r5]
 }
 
-macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
+macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $sub: ident, $subsub: ident) => {
 	/// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
 	const fn $name(a: &[u64; $len], b: &[u64; $len]) -> [u64; $len * 2] {
 		// We could probably get a bit faster doing gradeschool multiplication for some smaller
@@ -378,14 +363,14 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
 		let z1m_sign = z1a_sign == z1b_sign;
 
 		let z1m = $submul(&z1a.0, &z1b.0);
-		let z1n = $add(&z0, &z2);
+		let z1n = add(&z0, &z2);
 		let mut z1_carry = z1n.1;
 		let z1 = if z1m_sign {
 			let r = $sub(&z1n.0, &z1m);
 			if r.1 { z1_carry ^= true; }
 			r.0
 		} else {
-			let r = $add(&z1n.0, &z1m);
+			let r = add(&z1n.0, &z1m);
 			if r.1 { z1_carry = true; }
 			r.0
 		};
@@ -396,8 +381,8 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
 		let z2_end: &[u64; $len / 2] = const_subarr(&z2, $len / 2);
 		let l = const_subslice(&z0, $len / 2, $len);
 
-		let (k, j_carry) = $subadd(z0_start, z1_end);
-		let (mut j, i_carry_a) = $subadd(z1_start, z2_end);
+		let (k, j_carry) = add(z0_start, z1_end);
+		let (mut j, i_carry_a) = add(z1_start, z2_end);
 		let mut i_carry_b = false;
 		if j_carry {
 			i_carry_b = add_u64!(j, 1);
@@ -420,12 +405,12 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
 	}
 } }
 
-define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
-define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
-define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
-define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
-define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
-define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
+define_mul!(mul_4, 4, mul_2, sub_4, sub_2);
+define_mul!(mul_6, 6, mul_3, sub_6, sub_3);
+define_mul!(mul_8, 8, mul_4, sub_8, sub_4);
+define_mul!(mul_16, 16, mul_8, sub_16, sub_8);
+define_mul!(mul_32, 32, mul_16, sub_32, sub_16);
+define_mul!(mul_64, 64, mul_32, sub_64, sub_32);
 
 /// Squares a 128-bit integer, returning a new 256-bit integer.
 const fn sqr_2(a: &[u64; 2]) -> [u64; 4] {
@@ -443,7 +428,7 @@ const fn sqr_2(a: &[u64; 2]) -> [u64; 4] {
 	add_mul_2_parts(z2, z1, z0, i_carry_a)
 }
 
-macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
+macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident) => {
 	/// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
 	const fn $name(a: &[u64; $len]) -> [u64; $len * 2] {
 		// Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that.
@@ -461,8 +446,8 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
 		let v2_end: &[u64; $len / 2] = const_subarr(&v2, $len / 2);
 		let l = const_subslice(&v0, $len / 2, $len);
 
-		let (k, j_carry) = $subadd(v0_start, v1_end);
-		let (mut j, i_carry_b) = $subadd(v1_start, v2_end);
+		let (k, j_carry) = add(v0_start, v1_end);
+		let (mut j, i_carry_b) = add(v1_start, v2_end);
 
 		let mut i = [0; $len / 2];
 		let i_source = const_subslice(&v2, 0, $len / 2);
@@ -490,12 +475,12 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
 // TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
 const fn sqr_3(a: &[u64; 3]) -> [u64; 6] { mul_3(a, a) }
 
-define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
-define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
-define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
-define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
-define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
-define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
+define_sqr!(sqr_4, 4, mul_2, sqr_2);
+define_sqr!(sqr_6, 6, mul_3, sqr_3);
+define_sqr!(sqr_8, 8, mul_4, sqr_4);
+define_sqr!(sqr_16, 16, mul_8, sqr_8);
+define_sqr!(sqr_32, 32, mul_16, sqr_16);
+define_sqr!(sqr_64, 64, mul_32, sqr_32);
 
 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
@@ -559,7 +544,7 @@ define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to
 #[cfg(debug_assertions)]
 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
 
-macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
+macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $sub_abs: ident, $mul: ident) => {
 	/// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
 	/// if one exists.
 	const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
@@ -579,17 +564,17 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: iden
 			debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
 			let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
 				(true, true) => {
-					let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
+					let (new_s, overflow) = add(&old_s, const_subarr(&new_sa, $len));
 					debug_assert!(!overflow);
 					(new_s, true)
 				}
 				(false, true) => {
-					let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
+					let (new_s, overflow) = add(&old_s, const_subarr(&new_sa, $len));
 					debug_assert!(!overflow);
 					(new_s, false)
 				},
 				(true, false) => {
-					let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
+					let (new_s, overflow) = add(&old_s, const_subarr(&new_sa, $len));
 					debug_assert!(!overflow);
 					(new_s, true)
 				},
@@ -622,11 +607,11 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: iden
 	}
 } }
 #[cfg(fuzzing)]
-define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
-define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
-define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
+define_mod_inv!(mod_inv_2, 2, div_rem_2, sub_abs_2, mul_2);
+define_mod_inv!(mod_inv_4, 4, div_rem_4, sub_abs_4, mul_4);
+define_mod_inv!(mod_inv_6, 6, div_rem_6, sub_abs_6, mul_6);
 #[cfg(fuzzing)]
-define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
+define_mod_inv!(mod_inv_8, 8, div_rem_8, sub_abs_8, mul_8);
 
 // ******************
 // * The public API *
@@ -766,9 +751,9 @@ impl U4096 {
 	fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
 		debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
 		debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
-		let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 2);
-		let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 2);
-		let (add, overflow) = add_32(a_arr, b_arr);
+		let a_arr: &[u64; 32] = const_subarr(a, WORD_COUNT_4096 * 3 / 2);
+		let b_arr: &[u64; 32] = const_subarr(b, WORD_COUNT_4096 * 3 / 2);
+		let (add, overflow) = add(a_arr, b_arr);
 		let mut res = [0; WORD_COUNT_4096 * 2];
 		res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
 		(res, overflow)
@@ -805,9 +790,9 @@ impl U4096 {
 	fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
 		debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
 		debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
-		let a_arr = const_subarr(a, WORD_COUNT_4096);
-		let b_arr = const_subarr(b, WORD_COUNT_4096);
-		let (add, overflow) = add_64(a_arr, b_arr);
+		let a_arr: &[u64; 64] = const_subarr(a, WORD_COUNT_4096);
+		let b_arr: &[u64; 64] = const_subarr(b, WORD_COUNT_4096);
+		let (add, overflow) = add(a_arr, b_arr);
 		let mut res = [0; WORD_COUNT_4096 * 2];
 		res[WORD_COUNT_4096..].copy_from_slice(&add);
 		(res, overflow)
@@ -825,7 +810,7 @@ impl U4096 {
 				(32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
 			}
 		} else {
-			(64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
+			(64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add as add_double_ty, sub_64 as sub_ty)
 		};
 
 		// r is always the even value with one bit set above the word count we're using.
@@ -1024,7 +1009,7 @@ const fn u256_mont_reduction_given_prime(mu: [u64; 8], prime: &[u64; 4], negativ
 
 	// t_on_r = (mu + v*modulus) / R
 	let t0 = mul_4(const_subarr(&v, 4), prime);
-	let (t1, t1_extra_bit) = add_8(&t0, &mu);
+	let (t1, t1_extra_bit) = add(&t0, &mu);
 
 	// Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
 	// We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
@@ -1151,7 +1136,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
 		let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
 		if underflow {
 			let overflow;
-			(val, overflow) = add_4(&val, &M::PRIME.0);
+			(val, overflow) = add(&val, &M::PRIME.0);
 			debug_assert_eq!(overflow, underflow);
 		}
 		Self(U256(val), PhantomData)
@@ -1159,7 +1144,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
 
 	/// Adds `b` to `self` % `m`.
 	pub(super) fn add(&self, b: &Self) -> Self {
-		let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
+		let (mut val, overflow) = add(&self.0.0, &b.0.0);
 		if overflow || !slice_greater_than(&M::PRIME.0, &val) {
 			let underflow;
 			(val, underflow) = sub_4(&val, &M::PRIME.0);
@@ -1235,7 +1220,7 @@ const fn u384_mont_reduction_given_prime(mu: [u64; 12], prime: &[u64; 6], negati
 
 	// t_on_r = (mu + v*modulus) / R
 	let t0 = mul_6(const_subarr(&v, 6), prime);
-	let (t1, t1_extra_bit) = add_12(&t0, &mu);
+	let (t1, t1_extra_bit) = add(&t0, &mu);
 
 	// Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
 	// We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
@@ -1367,7 +1352,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
 		let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
 		if underflow {
 			let overflow;
-			(val, overflow) = add_6(&val, &M::PRIME.0);
+			(val, overflow) = add(&val, &M::PRIME.0);
 			debug_assert_eq!(overflow, underflow);
 		}
 		Self(U384(val), PhantomData)
@@ -1375,7 +1360,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
 
 	/// Adds `b` to `self` % `m`.
 	pub(super) fn add(&self, b: &Self) -> Self {
-		let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
+		let (mut val, overflow) = add(&self.0.0, &b.0.0);
 		if overflow || !slice_greater_than(&M::PRIME.0, &val) {
 			let underflow;
 			(val, underflow) = sub_6(&val, &M::PRIME.0);
@@ -1439,7 +1424,7 @@ pub fn fuzz_math(input: &[u8]) {
 		b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
 	}
 
-	macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
+	macro_rules! test { ($mul: ident, $sqr: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
 		let a_arg = (&a_u64s[..]).try_into().unwrap();
 		let b_arg = (&b_u64s[..]).try_into().unwrap();
 
@@ -1453,7 +1438,7 @@ pub fn fuzz_math(input: &[u8]) {
 		debug_assert_eq!($mul(a_arg, a_arg), $sqr(a_arg));
 		debug_assert_eq!($mul(b_arg, b_arg), $sqr(b_arg));
 
-		let (res, carry) = $add(a_arg, b_arg);
+		let (res, carry) = add(a_arg, b_arg);
 		let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
 		if carry { res_bytes.push(1); } else { res_bytes.push(0); }
 		for i in res {
@@ -1512,7 +1497,7 @@ pub fn fuzz_math(input: &[u8]) {
 		}
 	} }
 
-	macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
+	macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $sub: ident) => {
 		// Test the U256/U384Mod wrapper, which operates in Montgomery representation
 		let mut p_extended = [0; $len * 2];
 		p_extended[$len..].copy_from_slice(&$PRIME);
@@ -1528,7 +1513,7 @@ pub fn fuzz_math(input: &[u8]) {
 		assert_eq!(&abmodp[..$len], &[0; $len]);
 		assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
 
-		let (aplusb, aplusb_overflow) = $add(a_arg, b_arg);
+		let (aplusb, aplusb_overflow) = add(a_arg, b_arg);
 		let mut aplusb_extended = [0; $len * 2];
 		aplusb_extended[$len..].copy_from_slice(&aplusb);
 		if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
@@ -1539,9 +1524,9 @@ pub fn fuzz_math(input: &[u8]) {
 		let (mut aminusb, aminusb_underflow) = $sub(a_arg, b_arg);
 		if aminusb_underflow {
 			let mut overflow;
-			(aminusb, overflow) = $add(&aminusb, &$PRIME);
+			(aminusb, overflow) = add(&aminusb, &$PRIME);
 			if !overflow {
-				(aminusb, overflow) = $add(&aminusb, &$PRIME);
+				(aminusb, overflow) = add(&aminusb, &$PRIME);
 			}
 			assert!(overflow);
 		}
@@ -1550,19 +1535,19 @@ pub fn fuzz_math(input: &[u8]) {
 	} }
 
 	if a_u64s.len() == 2 {
-		test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
+		test!(mul_2, sqr_2, sub_2, div_rem_2, mod_inv_2);
 	} else if a_u64s.len() == 4 {
-		test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
+		test!(mul_4, sqr_4, sub_4, div_rem_4, mod_inv_4);
 		let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
 		let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
-		test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
+		test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, sub_4);
 	} else if a_u64s.len() == 6 {
-		test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
+		test!(mul_6, sqr_6, sub_6, div_rem_6, mod_inv_6);
 		let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
 		let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
-		test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
+		test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, sub_6);
 	} else if a_u64s.len() == 8 {
-		test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
+		test!(mul_8, sqr_8, sub_8, div_rem_8, mod_inv_8);
 	} else if input.len() == 512*2 + 4 {
 		let mut e_bytes = [0; 4];
 		e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
@@ -1687,7 +1672,7 @@ mod tests {
 		let a_int = u64s_to_u128(a);
 		let b_int = u64s_to_u128(b);
 
-		let res = add_2(&a, &b);
+		let res = add(&a, &b);
 		assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int));
 
 		let res = sub_2(&a, &b);
-- 
2.39.5
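
Note: the sketch below is an illustrative appendix, not part of the commit. The change collapses the macro-stamped add_2/add_4/.../add_128 family into a single const-generic function, with the limb count N inferred at each call site. The function body mirrors the patched `add`; the `main` and its test values are hypothetical, included only to show the inference. Where nothing pins N down, the patch pins it at the use site instead: the new `let a_arr: &[u64; 32]` / `&[u64; 64]` annotations in the subarr helpers, and the `add as add_double_ty` cast, which resolves N from the fn-pointer target type.

// Self-contained sketch of the const-generic pattern adopted above. Limbs
// are big-endian (index 0 is the most significant word), so the carry loop
// starts at index N - 1 and walks upward in significance.
const fn add<const N: usize>(a: &[u64; N], b: &[u64; N]) -> ([u64; N], bool) {
	let mut r = [0; N];
	let mut carry = false;
	let mut i = N - 1;
	loop {
		// Add the two limbs, then fold in the carry from the less-significant
		// limb, tracking overflow with the same semantics as
		// u64::overflowing_add.
		let (v, mut new_carry) = a[i].overflowing_add(b[i]);
		let (v2, new_new_carry) = v.overflowing_add(carry as u64);
		new_carry |= new_new_carry;
		r[i] = v2;
		carry = new_carry;

		if i == 0 { break; }
		i -= 1;
	}
	(r, carry)
}

fn main() {
	// N = 2 is inferred from the arguments, standing in for the old add_2:
	// the low limb wraps and carries into the high limb.
	let (sum, overflow) = add(&[0, u64::MAX], &[0, 1]);
	assert_eq!(sum, [1, 0]);
	assert!(!overflow);

	// N = 4 is inferred, standing in for the old add_4: the carry out of the
	// most-significant limb becomes the returned bool.
	let (sum, overflow) = add(&[u64::MAX; 4], &[u64::MAX; 4]);
	assert_eq!(sum, [u64::MAX, u64::MAX, u64::MAX, u64::MAX - 1]);
	assert!(overflow);
}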