From: Matt Corallo Date: Fri, 3 May 2024 16:41:54 +0000 (+0000) Subject: Clean up carry/debug assertions in multiplies/squaring X-Git-Tag: v0.5.4~14 X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=96c42f1531b76cc643cae79da0f8ebcf6dbbdf4b;p=dnssec-prover Clean up carry/debug assertions in multiplies/squaring --- diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs index d20ec1f..d07bb56 100644 --- a/src/crypto/bigint.rs +++ b/src/crypto/bigint.rs @@ -270,9 +270,14 @@ const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] { let z2 = a0 * b0; let z1i = a0 * b1; let z1j = b0 * a1; - let (z1, i_carry) = z1i.overflowing_add(z1j); + let (z1, i_carry_a) = z1i.overflowing_add(z1j); let z0 = a1 * b1; + add_mul_2_parts(z2, z1, z0, i_carry_a) +} + +/// Adds the gradeschool multiplication intermediate parts to a final 256-bit result +const fn add_mul_2_parts(z2: u128, z1: u128, z0: u128, i_carry_a: bool) -> [u64; 4] { let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64; let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64; let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64; @@ -281,20 +286,16 @@ const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] { let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64; let l = z0b; + let (k, j_carry) = z0a.overflowing_add(z1b); - let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b); - let new_i_carry; - (j, new_i_carry) = j.overflowing_add(j_carry as u64); - debug_assert!(!second_i_carry || !new_i_carry); - second_i_carry |= new_i_carry; + let (mut j, i_carry_b) = z1a.overflowing_add(z2b); + let i_carry_c; + (j, i_carry_c) = j.overflowing_add(j_carry as u64); - let mut i = z2a; - let mut spurious_overflow; - (i, spurious_overflow) = i.overflowing_add(i_carry as u64); - debug_assert!(!spurious_overflow); - (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64); - debug_assert!(!spurious_overflow); + let i_carry = i_carry_a as u64 + i_carry_b as u64 + i_carry_c as u64; + let (i, must_not_overflow) = z2a.overflowing_add(i_carry); + debug_assert!(!must_not_overflow, "Two 2*64 bit numbers, multiplied, will not use more than 4*64 bits"); [i, j, k, l] } @@ -364,7 +365,7 @@ const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] { let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64; let (r0, must_not_overflow) = r0a.overflowing_add(r0_c); - debug_assert!(!must_not_overflow); + debug_assert!(!must_not_overflow, "Two 3*64 bit numbers, multiplied, will not use more than 6*64 bits"); [r0, r1, r2, r3, r4, r5] } @@ -392,9 +393,9 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) }; let z1a = $subsub(z1a_max, z1a_min); - debug_assert!(!z1a.1); + debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min"); let z1b = $subsub(z1b_max, z1b_min); - debug_assert!(!z1b.1); + debug_assert!(!z1b.1, "z1b_max was selected to be greater than z1b_min"); let z1m_sign = z1a_sign == z1b_sign; let z1m = $submul(&z1a.0, &z1b.0); @@ -412,22 +413,18 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident let l = const_subslice(&z0, $len / 2, $len); let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len)); - let (mut j, mut i_carry) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len)); + let (mut j, i_carry_a) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len)); + let mut i_carry_b = false; if j_carry { - let new_i_carry = add_u64!(j, 1); - debug_assert!(!i_carry || !new_i_carry); - i_carry |= new_i_carry; + i_carry_b = add_u64!(j, 1); } let mut i = [0; $len / 2]; let i_source = const_subslice(&z2, 0, $len / 2); copy_from_slice!(i, 0, $len / 2, i_source); - if i_carry { - let spurious_carry = add_u64!(i, 1); - debug_assert!(!spurious_carry); - } - if z1_carry { - let spurious_carry = add_u64!(i, 1); - debug_assert!(!spurious_carry); + let i_carry = i_carry_a as u64 + i_carry_b as u64 + z1_carry as u64; + if i_carry != 0 { + let must_not_overflow = add_u64!(i, i_carry); + debug_assert!(!must_not_overflow, "Two N*64 bit numbers, multiplied, will not use more than 2*N*64 bits"); } let mut res = [0; $len * 2]; @@ -457,39 +454,17 @@ const fn sqr_2(a: &[u64]) -> [u64; 4] { let (a0, a1) = (a[0] as u128, a[1] as u128); let z2 = a0 * a0; let mut z1 = a0 * a1; - let i_carry = z1 & (1u128 << 127) != 0; + let i_carry_a = z1 & (1u128 << 127) != 0; z1 <<= 1; let z0 = a1 * a1; - let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64; - let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64; - let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64; - let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64; - let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64; - let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64; - - let l = z0b; - let (k, j_carry) = z0a.overflowing_add(z1b); - let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b); - - let new_i_carry; - (j, new_i_carry) = j.overflowing_add(j_carry as u64); - debug_assert!(!second_i_carry || !new_i_carry); - second_i_carry |= new_i_carry; - - let mut i = z2a; - let mut spurious_overflow; - (i, spurious_overflow) = i.overflowing_add(i_carry as u64); - debug_assert!(!spurious_overflow); - (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64); - debug_assert!(!spurious_overflow); - - [i, j, k, l] + add_mul_2_parts(z2, z1, z0, i_carry_a) } macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => { /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer. const fn $name(a: &[u64]) -> [u64; $len * 2] { + // Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that. debug_assert!(a.len() == $len); let hi = const_subslice(a, 0, $len / 2); @@ -497,29 +472,25 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id let v0 = $subsqr(lo); let mut v1 = $submul(hi, lo); - let i_carry = double!(v1); + let i_carry_a = double!(v1); let v2 = $subsqr(hi); let l = const_subslice(&v0, $len / 2, $len); let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len)); - let (mut j, mut i_carry_2) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len)); + let (mut j, i_carry_b) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len)); let mut i = [0; $len / 2]; let i_source = const_subslice(&v2, 0, $len / 2); copy_from_slice!(i, 0, $len / 2, i_source); + let mut i_carry_c = false; if j_carry { - let new_i_carry = add_u64!(j, 1); - debug_assert!(!i_carry_2 || !new_i_carry); - i_carry_2 |= new_i_carry; - } - if i_carry { - let spurious_carry = add_u64!(i, 1); - debug_assert!(!spurious_carry); + i_carry_c = add_u64!(j, 1); } - if i_carry_2 { - let spurious_carry = add_u64!(i, 1); - debug_assert!(!spurious_carry); + let i_carry = i_carry_a as u64 + i_carry_b as u64 + i_carry_c as u64; + if i_carry != 0 { + let must_not_overflow = add_u64!(i, i_carry); + debug_assert!(!must_not_overflow, "Two N*64 bit numbers, multiplied, will not use more than 2*N*64 bits"); } let mut res = [0; $len * 2];