]> git.bitcoin.ninja Git - dnssec-prover/commitdiff
Use a single const-generic `sub` method rather than macro-izing
authorMatt Corallo <git@bluematt.me>
Tue, 30 Jul 2024 13:55:14 +0000 (13:55 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 1 Aug 2024 03:55:38 +0000 (03:55 +0000)
src/crypto/bigint.rs

index 5ac131376e24cb001b0f91dfe501411b7742be43..bac56517ec8c9802d163a41b1989adea1d9cb71b 100644 (file)
@@ -185,49 +185,35 @@ const fn add<const N: usize>(a: &[u64; N], b: &[u64; N]) -> ([u64; N], bool) {
        (r, carry)
 }
 
-macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
-       /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
-       /// $len-64-bit integer and an overflow bit, with the same semantics as the std
-       /// [`u64::overflowing_sub`] method.
-       const fn $name(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) {
-               let mut r = [0; $len];
-               let mut carry = false;
-               let mut i = $len - 1;
-               loop {
-                       let (v, mut new_carry) = a[i].overflowing_sub(b[i]);
-                       let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
-                       new_carry |= new_new_carry;
-                       r[i] = v2;
-                       carry = new_carry;
-
-                       if i == 0 { break; }
-                       i -= 1;
-               }
-               (r, carry)
-       }
+/// Subtracts the `b` N-64-bit integer from the `a` N-64-bit integer, returning a new
+/// N-64-bit integer and an overflow bit, with the same semantics as the std
+/// [`u64::overflowing_sub`] method.
+const fn sub<const N: usize>(a: &[u64; N], b: &[u64; N]) -> ([u64; N], bool) {
+       let mut r = [0; N];
+       let mut carry = false;
+       let mut i = N - 1;
+       loop {
+               let (v, mut new_carry) = a[i].overflowing_sub(b[i]);
+               let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
+               new_carry |= new_new_carry;
+               r[i] = v2;
+               carry = new_carry;
 
-       /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
-       /// $len-64-bit integer representing the absolute value of the result, as well as a sign bit.
-       #[allow(unused)]
-       const fn $name_abs(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) {
-               let (mut res, neg) = $name(a, b);
-               if neg {
-                       negate!(res);
-               }
-               (res, neg)
+               if i == 0 { break; }
+               i -= 1;
        }
-} }
+       (r, carry)
+}
 
-define_sub!(sub_2, sub_abs_2, 2);
-define_sub!(sub_3, sub_abs_3, 3);
-define_sub!(sub_4, sub_abs_4, 4);
-define_sub!(sub_6, sub_abs_6, 6);
-define_sub!(sub_8, sub_abs_8, 8);
-define_sub!(sub_12, sub_abs_12, 12);
-define_sub!(sub_16, sub_abs_16, 16);
-define_sub!(sub_32, sub_abs_32, 32);
-define_sub!(sub_64, sub_abs_64, 64);
-define_sub!(sub_128, sub_abs_128, 128);
+/// Subtracts the `b` N-64-bit integer from the `a` N-64-bit integer, returning a new
+/// N-64-bit integer representing the absolute value of the result, as well as a sign bit.
+const fn sub_abs<const N: usize>(a: &[u64; N], b: &[u64; N]) -> ([u64; N], bool) {
+       let (mut res, neg) = sub(a, b);
+       if neg {
+               negate!(res);
+       }
+       (res, neg)
+}
 
 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
 ///
@@ -337,7 +323,7 @@ const fn mul_3(a: &[u64; 3], b: &[u64; 3]) -> [u64; 6] {
        [r0, r1, r2, r3, r4, r5]
 }
 
-macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $sub: ident, $subsub: ident) => {
+macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident) => {
        /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
        const fn $name(a: &[u64; $len], b: &[u64; $len]) -> [u64; $len * 2] {
                // We could probably get a bit faster doing gradeschool multiplication for some smaller
@@ -356,9 +342,9 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $sub: ident
                let (z1b_max, z1b_min, z1b_sign) =
                        if slice_greater_than(b1, b0) { (b1, b0, true) } else { (b0, b1, false) };
 
-               let z1a = $subsub(z1a_max, z1a_min);
+               let z1a = sub(z1a_max, z1a_min);
                debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min");
-               let z1b = $subsub(z1b_max, z1b_min);
+               let z1b = sub(z1b_max, z1b_min);
                debug_assert!(!z1b.1, "z1b_max was selected to be greater than z1b_min");
                let z1m_sign = z1a_sign == z1b_sign;
 
@@ -366,7 +352,7 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $sub: ident
                let z1n = add(&z0, &z2);
                let mut z1_carry = z1n.1;
                let z1 = if z1m_sign {
-                       let r = $sub(&z1n.0, &z1m);
+                       let r = sub(&z1n.0, &z1m);
                        if r.1 { z1_carry ^= true; }
                        r.0
                } else {
@@ -405,12 +391,12 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $sub: ident
        }
 } }
 
-define_mul!(mul_4, 4, mul_2, sub_4, sub_2);
-define_mul!(mul_6, 6, mul_3, sub_6, sub_3);
-define_mul!(mul_8, 8, mul_4, sub_8, sub_4);
-define_mul!(mul_16, 16, mul_8, sub_16, sub_8);
-define_mul!(mul_32, 32, mul_16, sub_32, sub_16);
-define_mul!(mul_64, 64, mul_32, sub_64, sub_32);
+define_mul!(mul_4, 4, mul_2);
+define_mul!(mul_6, 6, mul_3);
+define_mul!(mul_8, 8, mul_4);
+define_mul!(mul_16, 16, mul_8);
+define_mul!(mul_32, 32, mul_16);
+define_mul!(mul_64, 64, mul_32);
 
 
 /// Squares a 128-bit integer, returning a new 256-bit integer.
@@ -485,7 +471,7 @@ define_sqr!(sqr_64, 64, mul_32, sqr_32);
 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
 
-macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init: expr, $pre_push: ident $(, $const_opt: tt)?) => {
+macro_rules! define_div_rem { ($name: ident, $len: expr, $heap_init: expr, $pre_push: ident $(, $const_opt: tt)?) => {
        /// Divides two $len-64-bit integers, `a` by `b`, returning the quotient and remainder
        ///
        /// Fails iff `b` is zero.
@@ -515,7 +501,7 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init
                        let overflow = double!(quot);
                        debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
                        if slice_greater_than(&rem, &b_pow) {
-                               let (r, underflow) = $sub(&rem, &b_pow);
+                               let (r, underflow) = sub(&rem, &b_pow);
                                debug_assert!(!underflow, "rem was just checked to be > b_pow, so sub cannot underflow");
                                rem = r;
                                quot[$len - 1] |= 1;
@@ -533,18 +519,18 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init
 } }
 
 #[cfg(fuzzing)]
-define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
-define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
-define_div_rem!(div_rem_6, 6, sub_6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
+define_div_rem!(div_rem_2, 2, [[0; 2]; 2 * 64], dummy_pre_push, const);
+define_div_rem!(div_rem_4, 4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
+define_div_rem!(div_rem_6, 6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
 #[cfg(debug_assertions)]
-define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
+define_div_rem!(div_rem_8, 8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
 #[cfg(debug_assertions)]
-define_div_rem!(div_rem_12, 12, sub_12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
-define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
+define_div_rem!(div_rem_12, 12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
+define_div_rem!(div_rem_64, 64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
 #[cfg(debug_assertions)]
-define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
+define_div_rem!(div_rem_128, 128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
 
-macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $sub_abs: ident, $mul: ident) => {
+macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $mul: ident) => {
        /// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
        /// if one exists.
        const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
@@ -578,7 +564,7 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $sub_abs:
                                        debug_assert!(!overflow);
                                        (new_s, true)
                                },
-                               (false, false) => $sub_abs(&old_s, const_subarr(&new_sa, $len)),
+                               (false, false) => sub_abs(&old_s, const_subarr(&new_sa, $len)),
                        };
 
                        old_r = r;
@@ -596,7 +582,7 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $sub_abs:
                } else {
                        debug_assert!(slice_greater_than(m, &old_s));
                        if old_s_neg {
-                               let (modinv, underflow) = $sub_abs(m, &old_s);
+                               let (modinv, underflow) = sub_abs(m, &old_s);
                                debug_assert!(!underflow);
                                debug_assert!(slice_greater_than(m, &modinv));
                                Ok(modinv)
@@ -607,11 +593,11 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $sub_abs:
        }
 } }
 #[cfg(fuzzing)]
-define_mod_inv!(mod_inv_2, 2, div_rem_2, sub_abs_2, mul_2);
-define_mod_inv!(mod_inv_4, 4, div_rem_4, sub_abs_4, mul_4);
-define_mod_inv!(mod_inv_6, 6, div_rem_6, sub_abs_6, mul_6);
+define_mod_inv!(mod_inv_2, 2, div_rem_2, mul_2);
+define_mod_inv!(mod_inv_4, 4, div_rem_4, mul_4);
+define_mod_inv!(mod_inv_6, 6, div_rem_6, mul_6);
 #[cfg(fuzzing)]
-define_mod_inv!(mod_inv_8, 8, div_rem_8, sub_abs_8, mul_8);
+define_mod_inv!(mod_inv_8, 8, div_rem_8, mul_8);
 
 // ******************
 // * The public API *
@@ -761,9 +747,9 @@ impl U4096 {
                                        fn sub_16_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> ([u64; WORD_COUNT_4096], bool) {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
-                                               let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4);
-                                               let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 4);
-                                               let (sub, underflow) = sub_16(a_arr, b_arr);
+                                               let a_arr: &[u64; 16] = const_subarr(a, WORD_COUNT_4096 * 3 / 4);
+                                               let b_arr: &[u64; 16] = const_subarr(b, WORD_COUNT_4096 * 3 / 4);
+                                               let (sub, underflow) = sub(a_arr, b_arr);
                                                let mut res = [0; WORD_COUNT_4096];
                                                res[WORD_COUNT_4096 * 3 / 4..].copy_from_slice(&sub);
                                                (res, underflow)
@@ -800,9 +786,9 @@ impl U4096 {
                                        fn sub_32_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> ([u64; WORD_COUNT_4096], bool) {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
-                                               let a_arr = const_subarr(a, WORD_COUNT_4096 / 2);
-                                               let b_arr = const_subarr(b, WORD_COUNT_4096 / 2);
-                                               let (sub, underflow) = sub_32(a_arr, b_arr);
+                                               let a_arr: &[u64; 32] = const_subarr(a, WORD_COUNT_4096 / 2);
+                                               let b_arr: &[u64; 32] = const_subarr(b, WORD_COUNT_4096 / 2);
+                                               let (sub, underflow) = sub(a_arr, b_arr);
                                                let mut res = [0; WORD_COUNT_4096];
                                                res[WORD_COUNT_4096 / 2..].copy_from_slice(&sub);
                                                (res, underflow)
@@ -810,7 +796,7 @@ impl U4096 {
                                        (32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
                                }
                        } else {
-                               (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add as add_double_ty, sub_64 as sub_ty)
+                               (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add as add_double_ty, sub as sub_ty)
                        };
 
                // r is always the even value with one bit set above the word count we're using.
@@ -891,7 +877,7 @@ impl U4096 {
                let (_, mut r_mod_m) = debug_unwrap!(div_rem_64(&r_minus_one, &m.0));
                let r_mod_m_overflow = add_u64!(r_mod_m, 1);
                if r_mod_m_overflow || r_mod_m >= m.0 {
-                       (r_mod_m, _) = sub_64(&r_mod_m, &m.0);
+                       (r_mod_m, _) = crate::crypto::bigint::sub(&r_mod_m, &m.0);
                }
 
                let mut r2_mod_m: [u64; 64] = r_mod_m;
@@ -901,7 +887,7 @@ impl U4096 {
                for _ in 0..DOUBLES {
                        let overflow = double!(r2_mod_m);
                        if overflow || r2_mod_m > m.0 {
-                               (r2_mod_m, _) = sub_64(&r2_mod_m, &m.0);
+                               (r2_mod_m, _) = crate::crypto::bigint::sub(&r2_mod_m, &m.0);
                        }
                }
                for _ in 0..log_bits - LOG2_DOUBLES {
@@ -1021,7 +1007,7 @@ const fn u256_mont_reduction_given_prime(mu: [u64; 8], prime: &[u64; 4], negativ
        // modulus.
        if t1_extra_bit || slice_greater_than(t1_on_r, prime) {
                let underflow;
-               (res, underflow) = sub_4(t1_on_r, prime);
+               (res, underflow) = sub(t1_on_r, prime);
                debug_assert!(t1_extra_bit == underflow,
                        "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
        } else {
@@ -1044,7 +1030,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                        // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
                        // should be able to do it at compile time alone.
                        let r_minus_one = [0xffff_ffff_ffff_ffff; 4];
-                       let (mut r_mod_prime, _) = sub_4(&r_minus_one, &M::PRIME.0);
+                       let (mut r_mod_prime, _) = sub(&r_minus_one, &M::PRIME.0);
                        let r_mod_prime_overflow = add_u64!(r_mod_prime, 1);
                        assert!(!r_mod_prime_overflow);
                        let r_squared = sqr_4(&r_mod_prime);
@@ -1078,7 +1064,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                debug_assert!(M::PRIME.0 != [0; 4]);
                debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
                while v >= M::PRIME {
-                       let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
+                       let (new_v, spurious_underflow) = sub(&v.0, &M::PRIME.0);
                        debug_assert!(!spurious_underflow, "v was > M::PRIME.0");
                        v = U256(new_v);
                }
@@ -1102,7 +1088,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                let overflow = double!(res);
                if overflow || !slice_greater_than(&M::PRIME.0, &res) {
                        let underflow;
-                       (res, underflow) = sub_4(&res, &M::PRIME.0);
+                       (res, underflow) = sub(&res, &M::PRIME.0);
                        debug_assert_eq!(overflow, underflow);
                }
                Self(U256(res), PhantomData)
@@ -1133,7 +1119,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
 
        /// Subtracts `b` from `self` % `m`.
        pub(super) fn sub(&self, b: &Self) -> Self {
-               let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
+               let (mut val, underflow) = sub(&self.0.0, &b.0.0);
                if underflow {
                        let overflow;
                        (val, overflow) = add(&val, &M::PRIME.0);
@@ -1147,7 +1133,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                let (mut val, overflow) = add(&self.0.0, &b.0.0);
                if overflow || !slice_greater_than(&M::PRIME.0, &val) {
                        let underflow;
-                       (val, underflow) = sub_4(&val, &M::PRIME.0);
+                       (val, underflow) = sub(&val, &M::PRIME.0);
                        debug_assert_eq!(overflow, underflow);
                }
                Self(U256(val), PhantomData)
@@ -1232,7 +1218,7 @@ const fn u384_mont_reduction_given_prime(mu: [u64; 12], prime: &[u64; 6], negati
        // modulus.
        if t1_extra_bit || slice_greater_than(t1_on_r, prime) {
                let underflow;
-               (res, underflow) = sub_6(t1_on_r, prime);
+               (res, underflow) = sub(t1_on_r, prime);
                debug_assert!(t1_extra_bit == underflow);
        } else {
                copy_from_slice!(res, 0, 6, t1_on_r);
@@ -1253,7 +1239,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                        // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
                        // should be able to do it at compile time alone.
                        let r_minus_one = [0xffff_ffff_ffff_ffff; 6];
-                       let (mut r_mod_prime, _) = sub_6(&r_minus_one, &M::PRIME.0);
+                       let (mut r_mod_prime, _) = sub(&r_minus_one, &M::PRIME.0);
                        let r_mod_prime_overflow = add_u64!(r_mod_prime, 1);
                        assert!(!r_mod_prime_overflow);
                        let r_squared = sqr_6(&r_mod_prime);
@@ -1294,7 +1280,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                debug_assert!(M::PRIME.0 != [0; 6]);
                debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
                while v >= M::PRIME {
-                       let (new_v, spurious_underflow) = sub_6(&v.0, &M::PRIME.0);
+                       let (new_v, spurious_underflow) = sub(&v.0, &M::PRIME.0);
                        debug_assert!(!spurious_underflow);
                        v = U384(new_v);
                }
@@ -1318,7 +1304,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                let overflow = double!(res);
                if overflow || !slice_greater_than(&M::PRIME.0, &res) {
                        let underflow;
-                       (res, underflow) = sub_6(&res, &M::PRIME.0);
+                       (res, underflow) = sub(&res, &M::PRIME.0);
                        debug_assert_eq!(overflow, underflow);
                }
                Self(U384(res), PhantomData)
@@ -1349,7 +1335,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
 
        /// Subtracts `b` from `self` % `m`.
        pub(super) fn sub(&self, b: &Self) -> Self {
-               let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
+               let (mut val, underflow) = sub(&self.0.0, &b.0.0);
                if underflow {
                        let overflow;
                        (val, overflow) = add(&val, &M::PRIME.0);
@@ -1363,7 +1349,7 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                let (mut val, overflow) = add(&self.0.0, &b.0.0);
                if overflow || !slice_greater_than(&M::PRIME.0, &val) {
                        let underflow;
-                       (val, underflow) = sub_6(&val, &M::PRIME.0);
+                       (val, underflow) = sub(&val, &M::PRIME.0);
                        debug_assert_eq!(overflow, underflow);
                }
                Self(U384(val), PhantomData)
@@ -1424,7 +1410,7 @@ pub fn fuzz_math(input: &[u8]) {
                b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
        }
 
-       macro_rules! test { ($mul: ident, $sqr: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
+       macro_rules! test { ($mul: ident, $sqr: ident, $div_rem: ident, $mod_inv: ident) => {
                let a_arg = (&a_u64s[..]).try_into().unwrap();
                let b_arg = (&b_u64s[..]).try_into().unwrap();
 
@@ -1497,7 +1483,7 @@ pub fn fuzz_math(input: &[u8]) {
                }
        } }
 
-       macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $sub: ident) => {
+       macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident) => {
                // Test the U256/U384Mod wrapper, which operates in Montgomery representation
                let mut p_extended = [0; $len * 2];
                p_extended[$len..].copy_from_slice(&$PRIME);
@@ -1521,7 +1507,7 @@ pub fn fuzz_math(input: &[u8]) {
                assert_eq!(&aplusbmodp[..$len], &[0; $len]);
                assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
 
-               let (mut aminusb, aminusb_underflow) = $sub(a_arg, b_arg);
+               let (mut aminusb, aminusb_underflow) = sub(a_arg, b_arg);
                if aminusb_underflow {
                        let mut overflow;
                        (aminusb, overflow) = add(&aminusb, &$PRIME);
@@ -1535,19 +1521,19 @@ pub fn fuzz_math(input: &[u8]) {
        } }
 
        if a_u64s.len() == 2 {
-               test!(mul_2, sqr_2, sub_2, div_rem_2, mod_inv_2);
+               test!(mul_2, sqr_2, div_rem_2, mod_inv_2);
        } else if a_u64s.len() == 4 {
-               test!(mul_4, sqr_4, sub_4, div_rem_4, mod_inv_4);
+               test!(mul_4, sqr_4, div_rem_4, mod_inv_4);
                let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
                let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
-               test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, sub_4);
+               test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4);
        } else if a_u64s.len() == 6 {
-               test!(mul_6, sqr_6, sub_6, div_rem_6, mod_inv_6);
+               test!(mul_6, sqr_6, div_rem_6, mod_inv_6);
                let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
                let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
-               test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, sub_6);
+               test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6);
        } else if a_u64s.len() == 8 {
-               test!(mul_8, sqr_8, sub_8, div_rem_8, mod_inv_8);
+               test!(mul_8, sqr_8, div_rem_8, mod_inv_8);
        } else if input.len() == 512*2 + 4 {
                let mut e_bytes = [0; 4];
                e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
@@ -1675,7 +1661,7 @@ mod tests {
                        let res = add(&a, &b);
                        assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int));
 
-                       let res = sub_2(&a, &b);
+                       let res = sub(&a, &b);
                        assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_sub(b_int));
                }