Clean up + test add/sub/negate, fixing a debug assert in negate
authorMatt Corallo <git@bluematt.me>
Fri, 3 May 2024 16:30:28 +0000 (16:30 +0000)
committerMatt Corallo <git@bluematt.me>
Fri, 3 May 2024 18:48:43 +0000 (18:48 +0000)
src/crypto/bigint.rs

index 6c9bd6e9930ab2209ab652d05ff9b1b5889d6749..d20ec1fc39ff8ccf950683631442f73df765ea81 100644 (file)
@@ -146,8 +146,7 @@ macro_rules! negate { ($v: ident) => { {
                $v[i] ^= 0xffff_ffff_ffff_ffff;
                i += 1;
        }
-       let overflow = add_u64!($v, 1);
-       debug_assert!(!overflow);
+       add_u64!($v, 1);
 } } }
 
 /// Doubles in-place, returning an overflow flag, in which case one out-of-bounds high bit is
@@ -158,15 +157,21 @@ macro_rules! double { ($a: ident) => { {
        { let _: &[u64] = &$a; } // Force type resolution
        let len = $a.len();
        let mut carry = false;
-       let mut i = 0;
-       while i < len {
-               let mut next_carry = ($a[len - 1 - i] & (1 << 63)) != 0;
-               let (v, next_carry_2) = ($a[len - 1 - i] << 1).overflowing_add(carry as u64);
-               $a[len - 1 - i] = v;
-               debug_assert!(!next_carry || !next_carry_2);
-               next_carry |= next_carry_2;
+       let mut i = len - 1;
+       loop {
+               let next_carry = ($a[i] & (1 << 63)) != 0;
+               let (v, _next_carry_2) = ($a[i] << 1).overflowing_add(carry as u64);
+               if !next_carry {
+                       debug_assert!(!_next_carry_2, "Adding one to 0x7ffff..*2 is only 0xffff..");
+               }
+               // Note that we can ignore _next_carry_2 here as we never need it - it cannot be set if
+               // next_carry is not set and at max 0xffff..*2 + 1 is only 0x1ffff.. (i.e. we can not need
+               // a double-carry).
+               $a[i] = v;
                carry = next_carry;
-               i += 1;
+
+               if i == 0 { break; }
+               i -= 1;
        }
        carry
 } } }
@@ -179,15 +184,16 @@ macro_rules! define_add { ($name: ident, $len: expr) => {
                debug_assert!(b.len() == $len);
                let mut r = [0; $len];
                let mut carry = false;
-               let mut i = 0;
-               while i < $len {
-                       let pos = $len - 1 - i;
-                       let (v, mut new_carry) = a[pos].overflowing_add(b[pos]);
+               let mut i = $len - 1;
+               loop {
+                       let (v, mut new_carry) = a[i].overflowing_add(b[i]);
                        let (v2, new_new_carry) = v.overflowing_add(carry as u64);
                        new_carry |= new_new_carry;
-                       r[pos] = v2;
+                       r[i] = v2;
                        carry = new_carry;
-                       i += 1;
+
+                       if i == 0 { break; }
+                       i -= 1;
                }
                (r, carry)
        }
@@ -213,15 +219,16 @@ macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
                debug_assert!(b.len() == $len);
                let mut r = [0; $len];
                let mut carry = false;
-               let mut i = 0;
-               while i < $len {
-                       let pos = $len - 1 - i;
-                       let (v, mut new_carry) = a[pos].overflowing_sub(b[pos]);
+               let mut i = $len - 1;
+               loop {
+                       let (v, mut new_carry) = a[i].overflowing_sub(b[i]);
                        let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
                        new_carry |= new_new_carry;
-                       r[pos] = v2;
+                       r[i] = v2;
                        carry = new_carry;
-                       i += 1;
+
+                       if i == 0 { break; }
+                       i -= 1;
                }
                (r, carry)
        }
@@ -1577,6 +1584,58 @@ pub fn fuzz_math(input: &[u8]) {
 mod tests {
        use super::*;
 
+       fn u64s_to_u128(v: [u64; 2]) -> u128 {
+               let mut r = 0;
+               r |= v[1] as u128;
+               r |= (v[0] as u128) << 64;
+               r
+       }
+
+       fn u64s_to_i128(v: [u64; 2]) -> i128 {
+               let mut r = 0;
+               r |= v[1] as i128;
+               r |= (v[0] as i128) << 64;
+               r
+       }
+
+       #[test]
+       fn test_negate() {
+               let mut zero = [0u64; 2];
+               negate!(zero);
+               assert_eq!(zero, [0; 2]);
+
+               let mut one = [0u64, 1u64];
+               negate!(one);
+               assert_eq!(u64s_to_i128(one), -1);
+
+               let mut minus_one: [u64; 2] = [u64::MAX, u64::MAX];
+               negate!(minus_one);
+               assert_eq!(minus_one, [0, 1]);
+       }
+
+       #[test]
+       fn test_double() {
+               let mut zero = [0u64; 2];
+               assert!(!double!(zero));
+               assert_eq!(zero, [0; 2]);
+
+               let mut one = [0u64, 1u64];
+               assert!(!double!(one));
+               assert_eq!(one, [0, 2]);
+
+               let mut u64_max = [0, u64::MAX];
+               assert!(!double!(u64_max));
+               assert_eq!(u64_max, [1, u64::MAX - 1]);
+
+               let mut u64_carry_overflow = [0x7fff_ffff_ffff_ffffu64, 0x8000_0000_0000_0000];
+               assert!(!double!(u64_carry_overflow));
+               assert_eq!(u64_carry_overflow, [u64::MAX, 0]);
+
+               let mut max = [u64::MAX; 4];
+               assert!(double!(max));
+               assert_eq!(max, [u64::MAX, u64::MAX, u64::MAX, u64::MAX - 1]);
+       }
+
        #[test]
        fn mul_min_simple_tests() {
                let a = [1, 2];
@@ -1621,14 +1680,27 @@ mod tests {
        }
 
        #[test]
-       fn add_simple_tests() {
-               let a = [u64::MAX; 2];
-               let b = [u64::MAX; 2];
-               assert_eq!(add_2(&a, &b), ([18446744073709551615, 18446744073709551614], true));
+       fn test_add_sub() {
+               fn test(a: [u64; 2], b: [u64; 2]) {
+                       let a_int = u64s_to_u128(a);
+                       let b_int = u64s_to_u128(b);
 
-               let a = [0x1bad_cafe_dead_beef, 2424];
-               let b = [0x2bad_beef_dead_cafe, 4242];
-               assert_eq!(add_2(&a, &b), ([5141855058045667821, 6666], false));
+                       let res = add_2(&a, &b);
+                       assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int));
+
+                       let res = sub_2(&a, &b);
+                       assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_sub(b_int));
+               }
+
+               test([0; 2], [0; 2]);
+               test([0x1bad_cafe_dead_beef, 2424], [0x2bad_cafe_dead_cafe, 4242]);
+               test([u64::MAX; 2], [u64::MAX; 2]);
+               test([u64::MAX, 0x8000_0000_0000_0000], [0, 0x7fff_ffff_ffff_ffff]);
+               test([0, 0x7fff_ffff_ffff_ffff], [u64::MAX, 0x8000_0000_0000_0000]);
+               test([u64::MAX, 0], [0, u64::MAX]);
+               test([0, u64::MAX], [u64::MAX, 0]);
+               test([u64::MAX; 2], [0; 2]);
+               test([0; 2], [u64::MAX; 2]);
        }
 
        #[test]