From fc9d2e28d5494029ade2fc1d24fb8de8644d2b45 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Fri, 3 May 2024 16:30:28 +0000 Subject: [PATCH] Clean up + test add/sub/negate, fixing a debug assert in negate --- src/crypto/bigint.rs | 130 +++++++++++++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 29 deletions(-) diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs index 6c9bd6e..d20ec1f 100644 --- a/src/crypto/bigint.rs +++ b/src/crypto/bigint.rs @@ -146,8 +146,7 @@ macro_rules! negate { ($v: ident) => { { $v[i] ^= 0xffff_ffff_ffff_ffff; i += 1; } - let overflow = add_u64!($v, 1); - debug_assert!(!overflow); + add_u64!($v, 1); } } } /// Doubles in-place, returning an overflow flag, in which case one out-of-bounds high bit is @@ -158,15 +157,21 @@ macro_rules! double { ($a: ident) => { { { let _: &[u64] = &$a; } // Force type resolution let len = $a.len(); let mut carry = false; - let mut i = 0; - while i < len { - let mut next_carry = ($a[len - 1 - i] & (1 << 63)) != 0; - let (v, next_carry_2) = ($a[len - 1 - i] << 1).overflowing_add(carry as u64); - $a[len - 1 - i] = v; - debug_assert!(!next_carry || !next_carry_2); - next_carry |= next_carry_2; + let mut i = len - 1; + loop { + let next_carry = ($a[i] & (1 << 63)) != 0; + let (v, _next_carry_2) = ($a[i] << 1).overflowing_add(carry as u64); + if !next_carry { + debug_assert!(!_next_carry_2, "Adding one to 0x7ffff..*2 is only 0xffff.."); + } + // Note that we can ignore _next_carry_2 here as we never need it - it cannot be set if + // next_carry is not set and at max 0xffff..*2 + 1 is only 0x1ffff.. (i.e. we can not need + // a double-carry). + $a[i] = v; carry = next_carry; - i += 1; + + if i == 0 { break; } + i -= 1; } carry } } } @@ -179,15 +184,16 @@ macro_rules! define_add { ($name: ident, $len: expr) => { debug_assert!(b.len() == $len); let mut r = [0; $len]; let mut carry = false; - let mut i = 0; - while i < $len { - let pos = $len - 1 - i; - let (v, mut new_carry) = a[pos].overflowing_add(b[pos]); + let mut i = $len - 1; + loop { + let (v, mut new_carry) = a[i].overflowing_add(b[i]); let (v2, new_new_carry) = v.overflowing_add(carry as u64); new_carry |= new_new_carry; - r[pos] = v2; + r[i] = v2; carry = new_carry; - i += 1; + + if i == 0 { break; } + i -= 1; } (r, carry) } @@ -213,15 +219,16 @@ macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => { debug_assert!(b.len() == $len); let mut r = [0; $len]; let mut carry = false; - let mut i = 0; - while i < $len { - let pos = $len - 1 - i; - let (v, mut new_carry) = a[pos].overflowing_sub(b[pos]); + let mut i = $len - 1; + loop { + let (v, mut new_carry) = a[i].overflowing_sub(b[i]); let (v2, new_new_carry) = v.overflowing_sub(carry as u64); new_carry |= new_new_carry; - r[pos] = v2; + r[i] = v2; carry = new_carry; - i += 1; + + if i == 0 { break; } + i -= 1; } (r, carry) } @@ -1577,6 +1584,58 @@ pub fn fuzz_math(input: &[u8]) { mod tests { use super::*; + fn u64s_to_u128(v: [u64; 2]) -> u128 { + let mut r = 0; + r |= v[1] as u128; + r |= (v[0] as u128) << 64; + r + } + + fn u64s_to_i128(v: [u64; 2]) -> i128 { + let mut r = 0; + r |= v[1] as i128; + r |= (v[0] as i128) << 64; + r + } + + #[test] + fn test_negate() { + let mut zero = [0u64; 2]; + negate!(zero); + assert_eq!(zero, [0; 2]); + + let mut one = [0u64, 1u64]; + negate!(one); + assert_eq!(u64s_to_i128(one), -1); + + let mut minus_one: [u64; 2] = [u64::MAX, u64::MAX]; + negate!(minus_one); + assert_eq!(minus_one, [0, 1]); + } + + #[test] + fn test_double() { + let mut zero = [0u64; 2]; + assert!(!double!(zero)); + assert_eq!(zero, [0; 2]); + + let mut one = [0u64, 1u64]; + assert!(!double!(one)); + assert_eq!(one, [0, 2]); + + let mut u64_max = [0, u64::MAX]; + assert!(!double!(u64_max)); + assert_eq!(u64_max, [1, u64::MAX - 1]); + + let mut u64_carry_overflow = [0x7fff_ffff_ffff_ffffu64, 0x8000_0000_0000_0000]; + assert!(!double!(u64_carry_overflow)); + assert_eq!(u64_carry_overflow, [u64::MAX, 0]); + + let mut max = [u64::MAX; 4]; + assert!(double!(max)); + assert_eq!(max, [u64::MAX, u64::MAX, u64::MAX, u64::MAX - 1]); + } + #[test] fn mul_min_simple_tests() { let a = [1, 2]; @@ -1621,14 +1680,27 @@ mod tests { } #[test] - fn add_simple_tests() { - let a = [u64::MAX; 2]; - let b = [u64::MAX; 2]; - assert_eq!(add_2(&a, &b), ([18446744073709551615, 18446744073709551614], true)); + fn test_add_sub() { + fn test(a: [u64; 2], b: [u64; 2]) { + let a_int = u64s_to_u128(a); + let b_int = u64s_to_u128(b); - let a = [0x1bad_cafe_dead_beef, 2424]; - let b = [0x2bad_beef_dead_cafe, 4242]; - assert_eq!(add_2(&a, &b), ([5141855058045667821, 6666], false)); + let res = add_2(&a, &b); + assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int)); + + let res = sub_2(&a, &b); + assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_sub(b_int)); + } + + test([0; 2], [0; 2]); + test([0x1bad_cafe_dead_beef, 2424], [0x2bad_cafe_dead_cafe, 4242]); + test([u64::MAX; 2], [u64::MAX; 2]); + test([u64::MAX, 0x8000_0000_0000_0000], [0, 0x7fff_ffff_ffff_ffff]); + test([0, 0x7fff_ffff_ffff_ffff], [u64::MAX, 0x8000_0000_0000_0000]); + test([u64::MAX, 0], [0, u64::MAX]); + test([0, u64::MAX], [u64::MAX, 0]); + test([u64::MAX; 2], [0; 2]); + test([0; 2], [u64::MAX; 2]); } #[test] -- 2.39.5