]> git.bitcoin.ninja Git - dnssec-prover/commitdiff
Use a single const generic `add` method rather than macro-izing
authorMatt Corallo <git@bluematt.me>
Mon, 29 Jul 2024 21:36:28 +0000 (21:36 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 1 Aug 2024 03:55:38 +0000 (03:55 +0000)
src/crypto/bigint.rs

index 7a2ab150e7362db0d6fa4c4f1fc0a56bb5328c37..5ac131376e24cb001b0f91dfe501411b7742be43 100644 (file)
@@ -168,37 +168,22 @@ macro_rules! double { ($a: ident) => { {
        carry
 } } }
 
-macro_rules! define_add { ($name: ident, $len: expr) => {
-       /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
-       /// bit, with the same semantics as the std [`u64::overflowing_add`] method.
-       const fn $name(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) {
-               let mut r = [0; $len];
-               let mut carry = false;
-               let mut i = $len - 1;
-               loop {
-                       let (v, mut new_carry) = a[i].overflowing_add(b[i]);
-                       let (v2, new_new_carry) = v.overflowing_add(carry as u64);
-                       new_carry |= new_new_carry;
-                       r[i] = v2;
-                       carry = new_carry;
+const fn add<const N: usize>(a: &[u64; N], b: &[u64; N]) -> ([u64; N], bool) {
+       let mut r = [0; N];
+       let mut carry = false;
+       let mut i = N - 1;
+       loop {
+               let (v, mut new_carry) = a[i].overflowing_add(b[i]);
+               let (v2, new_new_carry) = v.overflowing_add(carry as u64);
+               new_carry |= new_new_carry;
+               r[i] = v2;
+               carry = new_carry;
 
-                       if i == 0 { break; }
-                       i -= 1;
-               }
-               (r, carry)
+               if i == 0 { break; }
+               i -= 1;
        }
-} }
-
-define_add!(add_2, 2);
-define_add!(add_3, 3);
-define_add!(add_4, 4);
-define_add!(add_6, 6);
-define_add!(add_8, 8);
-define_add!(add_12, 12);
-define_add!(add_16, 16);
-define_add!(add_32, 32);
-define_add!(add_64, 64);
-define_add!(add_128, 128);
+       (r, carry)
+}
 
 macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
        /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
@@ -352,7 +337,7 @@ const fn mul_3(a: &[u64; 3], b: &[u64; 3]) -> [u64; 6] {
        [r0, r1, r2, r3, r4, r5]
 }
 
-macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
+macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $sub: ident, $subsub: ident) => {
        /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
        const fn $name(a: &[u64; $len], b: &[u64; $len]) -> [u64; $len * 2] {
                // We could probably get a bit faster doing gradeschool multiplication for some smaller
@@ -378,14 +363,14 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
                let z1m_sign = z1a_sign == z1b_sign;
 
                let z1m = $submul(&z1a.0, &z1b.0);
-               let z1n = $add(&z0, &z2);
+               let z1n = add(&z0, &z2);
                let mut z1_carry = z1n.1;
                let z1 = if z1m_sign {
                        let r = $sub(&z1n.0, &z1m);
                        if r.1 { z1_carry ^= true; }
                        r.0
                } else {
-                       let r = $add(&z1n.0, &z1m);
+                       let r = add(&z1n.0, &z1m);
                        if r.1 { z1_carry = true; }
                        r.0
                };
@@ -396,8 +381,8 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
                let z2_end: &[u64; $len / 2] = const_subarr(&z2, $len / 2);
 
                let l = const_subslice(&z0, $len / 2, $len);
-               let (k, j_carry) = $subadd(z0_start, z1_end);
-               let (mut j, i_carry_a) = $subadd(z1_start, z2_end);
+               let (k, j_carry) = add(z0_start, z1_end);
+               let (mut j, i_carry_a) = add(z1_start, z2_end);
                let mut i_carry_b = false;
                if j_carry {
                        i_carry_b = add_u64!(j, 1);
@@ -420,12 +405,12 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
        }
 } }
 
-define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
-define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
-define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
-define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
-define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
-define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
+define_mul!(mul_4, 4, mul_2, sub_4, sub_2);
+define_mul!(mul_6, 6, mul_3, sub_6, sub_3);
+define_mul!(mul_8, 8, mul_4, sub_8, sub_4);
+define_mul!(mul_16, 16, mul_8, sub_16, sub_8);
+define_mul!(mul_32, 32, mul_16, sub_32, sub_16);
+define_mul!(mul_64, 64, mul_32, sub_64, sub_32);
 
 
 /// Squares a 128-bit integer, returning a new 256-bit integer.
@@ -443,7 +428,7 @@ const fn sqr_2(a: &[u64; 2]) -> [u64; 4] {
        add_mul_2_parts(z2, z1, z0, i_carry_a)
 }
 
-macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
+macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident) => {
        /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
        const fn $name(a: &[u64; $len]) -> [u64; $len * 2] {
                // Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that.
@@ -461,8 +446,8 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
                let v2_end: &[u64; $len / 2] = const_subarr(&v2, $len / 2);
 
                let l = const_subslice(&v0, $len / 2, $len);
-               let (k, j_carry) = $subadd(v0_start, v1_end);
-               let (mut j, i_carry_b) = $subadd(v1_start, v2_end);
+               let (k, j_carry) = add(v0_start, v1_end);
+               let (mut j, i_carry_b) = add(v1_start, v2_end);
 
                let mut i = [0; $len / 2];
                let i_source = const_subslice(&v2, 0, $len / 2);
@@ -490,12 +475,12 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
 // TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
 const fn sqr_3(a: &[u64; 3]) -> [u64; 6] { mul_3(a, a) }
 
-define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
-define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
-define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
-define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
-define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
-define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
+define_sqr!(sqr_4, 4, mul_2, sqr_2);
+define_sqr!(sqr_6, 6, mul_3, sqr_3);
+define_sqr!(sqr_8, 8, mul_4, sqr_4);
+define_sqr!(sqr_16, 16, mul_8, sqr_8);
+define_sqr!(sqr_32, 32, mul_16, sqr_16);
+define_sqr!(sqr_64, 64, mul_32, sqr_32);
 
 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
@@ -559,7 +544,7 @@ define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to
 #[cfg(debug_assertions)]
 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
 
-macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
+macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $sub_abs: ident, $mul: ident) => {
        /// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
        /// if one exists.
        const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
@@ -579,17 +564,17 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: iden
                        debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
                        let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
                                (true, true) => {
-                                       let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
+                                       let (new_s, overflow) = add(&old_s, const_subarr(&new_sa, $len));
                                        debug_assert!(!overflow);
                                        (new_s, true)
                                }
                                (false, true) => {
-                                       let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
+                                       let (new_s, overflow) = add(&old_s, const_subarr(&new_sa, $len));
                                        debug_assert!(!overflow);
                                        (new_s, false)
                                },
                                (true, false) => {
-                                       let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
+                                       let (new_s, overflow) = add(&old_s, const_subarr(&new_sa, $len));
                                        debug_assert!(!overflow);
                                        (new_s, true)
                                },
@@ -622,11 +607,11 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: iden
        }
 } }
 #[cfg(fuzzing)]
-define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
-define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
-define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
+define_mod_inv!(mod_inv_2, 2, div_rem_2, sub_abs_2, mul_2);
+define_mod_inv!(mod_inv_4, 4, div_rem_4, sub_abs_4, mul_4);
+define_mod_inv!(mod_inv_6, 6, div_rem_6, sub_abs_6, mul_6);
 #[cfg(fuzzing)]
-define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
+define_mod_inv!(mod_inv_8, 8, div_rem_8, sub_abs_8, mul_8);
 
 // ******************
 // * The public API *
@@ -766,9 +751,9 @@ impl U4096 {
                                        fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
-                                               let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 2);
-                                               let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 2);
-                                               let (add, overflow) = add_32(a_arr, b_arr);
+                                               let a_arr: &[u64; 32] = const_subarr(a, WORD_COUNT_4096 * 3 / 2);
+                                               let b_arr: &[u64; 32] = const_subarr(b, WORD_COUNT_4096 * 3 / 2);
+                                               let (add, overflow) = add(a_arr, b_arr);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
                                                res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
                                                (res, overflow)
@@ -805,9 +790,9 @@ impl U4096 {
                                        fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
-                                               let a_arr = const_subarr(a, WORD_COUNT_4096);
-                                               let b_arr = const_subarr(b, WORD_COUNT_4096);
-                                               let (add, overflow) = add_64(a_arr, b_arr);
+                                               let a_arr: &[u64; 64] = const_subarr(a, WORD_COUNT_4096);
+                                               let b_arr: &[u64; 64] = const_subarr(b, WORD_COUNT_4096);
+                                               let (add, overflow) = add(a_arr, b_arr);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
                                                res[WORD_COUNT_4096..].copy_from_slice(&add);
                                                (res, overflow)
@@ -825,7 +810,7 @@ impl U4096 {
                                        (32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
                                }
                        } else {
-                               (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
+                               (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add as add_double_ty, sub_64 as sub_ty)
                        };
 
                // r is always the even value with one bit set above the word count we're using.
@@ -1024,7 +1009,7 @@ const fn u256_mont_reduction_given_prime(mu: [u64; 8], prime: &[u64; 4], negativ
 
        // t_on_r = (mu + v*modulus) / R
        let t0 = mul_4(const_subarr(&v, 4), prime);
-       let (t1, t1_extra_bit) = add_8(&t0, &mu);
+       let (t1, t1_extra_bit) = add(&t0, &mu);
 
        // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
        // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
@@ -1151,7 +1136,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
                if underflow {
                        let overflow;
-                       (val, overflow) = add_4(&val, &M::PRIME.0);
+                       (val, overflow) = add(&val, &M::PRIME.0);
                        debug_assert_eq!(overflow, underflow);
                }
                Self(U256(val), PhantomData)
@@ -1159,7 +1144,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
 
        /// Adds `b` to `self` % `m`.
        pub(super) fn add(&self, b: &Self) -> Self {
-               let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
+               let (mut val, overflow) = add(&self.0.0, &b.0.0);
                if overflow || !slice_greater_than(&M::PRIME.0, &val) {
                        let underflow;
                        (val, underflow) = sub_4(&val, &M::PRIME.0);
@@ -1235,7 +1220,7 @@ const fn u384_mont_reduction_given_prime(mu: [u64; 12], prime: &[u64; 6], negati
 
        // t_on_r = (mu + v*modulus) / R
        let t0 = mul_6(const_subarr(&v, 6), prime);
-       let (t1, t1_extra_bit) = add_12(&t0, &mu);
+       let (t1, t1_extra_bit) = add(&t0, &mu);
 
        // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
        // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
@@ -1367,7 +1352,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
                if underflow {
                        let overflow;
-                       (val, overflow) = add_6(&val, &M::PRIME.0);
+                       (val, overflow) = add(&val, &M::PRIME.0);
                        debug_assert_eq!(overflow, underflow);
                }
                Self(U384(val), PhantomData)
@@ -1375,7 +1360,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
 
        /// Adds `b` to `self` % `m`.
        pub(super) fn add(&self, b: &Self) -> Self {
-               let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
+               let (mut val, overflow) = add(&self.0.0, &b.0.0);
                if overflow || !slice_greater_than(&M::PRIME.0, &val) {
                        let underflow;
                        (val, underflow) = sub_6(&val, &M::PRIME.0);
@@ -1439,7 +1424,7 @@ pub fn fuzz_math(input: &[u8]) {
                b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
        }
 
-       macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
+       macro_rules! test { ($mul: ident, $sqr: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
                let a_arg = (&a_u64s[..]).try_into().unwrap();
                let b_arg = (&b_u64s[..]).try_into().unwrap();
 
@@ -1453,7 +1438,7 @@ pub fn fuzz_math(input: &[u8]) {
                debug_assert_eq!($mul(a_arg, a_arg), $sqr(a_arg));
                debug_assert_eq!($mul(b_arg, b_arg), $sqr(b_arg));
 
-               let (res, carry) = $add(a_arg, b_arg);
+               let (res, carry) = add(a_arg, b_arg);
                let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
                if carry { res_bytes.push(1); } else { res_bytes.push(0); }
                for i in res {
@@ -1512,7 +1497,7 @@ pub fn fuzz_math(input: &[u8]) {
                }
        } }
 
-       macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
+       macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $sub: ident) => {
                // Test the U256/U384Mod wrapper, which operates in Montgomery representation
                let mut p_extended = [0; $len * 2];
                p_extended[$len..].copy_from_slice(&$PRIME);
@@ -1528,7 +1513,7 @@ pub fn fuzz_math(input: &[u8]) {
                assert_eq!(&abmodp[..$len], &[0; $len]);
                assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
 
-               let (aplusb, aplusb_overflow) = $add(a_arg, b_arg);
+               let (aplusb, aplusb_overflow) = add(a_arg, b_arg);
                let mut aplusb_extended = [0; $len * 2];
                aplusb_extended[$len..].copy_from_slice(&aplusb);
                if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
@@ -1539,9 +1524,9 @@ pub fn fuzz_math(input: &[u8]) {
                let (mut aminusb, aminusb_underflow) = $sub(a_arg, b_arg);
                if aminusb_underflow {
                        let mut overflow;
-                       (aminusb, overflow) = $add(&aminusb, &$PRIME);
+                       (aminusb, overflow) = add(&aminusb, &$PRIME);
                        if !overflow {
-                               (aminusb, overflow) = $add(&aminusb, &$PRIME);
+                               (aminusb, overflow) = add(&aminusb, &$PRIME);
                        }
                        assert!(overflow);
                }
@@ -1550,19 +1535,19 @@ pub fn fuzz_math(input: &[u8]) {
        } }
 
        if a_u64s.len() == 2 {
-               test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
+               test!(mul_2, sqr_2, sub_2, div_rem_2, mod_inv_2);
        } else if a_u64s.len() == 4 {
-               test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
+               test!(mul_4, sqr_4, sub_4, div_rem_4, mod_inv_4);
                let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
                let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
-               test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
+               test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, sub_4);
        } else if a_u64s.len() == 6 {
-               test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
+               test!(mul_6, sqr_6, sub_6, div_rem_6, mod_inv_6);
                let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
                let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
-               test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
+               test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, sub_6);
        } else if a_u64s.len() == 8 {
-               test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
+               test!(mul_8, sqr_8, sub_8, div_rem_8, mod_inv_8);
        } else if input.len() == 512*2 + 4 {
                let mut e_bytes = [0; 4];
                e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
@@ -1687,7 +1672,7 @@ mod tests {
                        let a_int = u64s_to_u128(a);
                        let b_int = u64s_to_u128(b);
 
-                       let res = add_2(&a, &b);
+                       let res = add(&a, &b);
                        assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int));
 
                        let res = sub_2(&a, &b);