Clean up carry/debug assertions in multiplies/squaring
[dnssec-prover] / src / crypto / bigint.rs
index d20ec1fc39ff8ccf950683631442f73df765ea81..d07bb5622789c1ad45085516af5f1d9cfd8f1914 100644 (file)
@@ -270,9 +270,14 @@ const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
        let z2 = a0 * b0;
        let z1i = a0 * b1;
        let z1j = b0 * a1;
-       let (z1, i_carry) = z1i.overflowing_add(z1j);
+       let (z1, i_carry_a) = z1i.overflowing_add(z1j);
        let z0 = a1 * b1;
 
+       add_mul_2_parts(z2, z1, z0, i_carry_a)
+}
+
+/// Adds the gradeschool multiplication intermediate parts to a final 256-bit result
+const fn add_mul_2_parts(z2: u128, z1: u128, z0: u128, i_carry_a: bool) -> [u64; 4] {
        let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
        let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
        let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
@@ -281,20 +286,16 @@ const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
        let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
 
        let l = z0b;
+
        let (k, j_carry) = z0a.overflowing_add(z1b);
-       let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
 
-       let new_i_carry;
-       (j, new_i_carry) = j.overflowing_add(j_carry as u64);
-       debug_assert!(!second_i_carry || !new_i_carry);
-       second_i_carry |= new_i_carry;
+       let (mut j, i_carry_b) = z1a.overflowing_add(z2b);
+       let i_carry_c;
+       (j, i_carry_c) = j.overflowing_add(j_carry as u64);
 
-       let mut i = z2a;
-       let mut spurious_overflow;
-       (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
-       debug_assert!(!spurious_overflow);
-       (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
-       debug_assert!(!spurious_overflow);
+       let i_carry = i_carry_a as u64 + i_carry_b as u64 + i_carry_c as u64;
+       let (i, must_not_overflow) = z2a.overflowing_add(i_carry);
+       debug_assert!(!must_not_overflow, "Two 2*64 bit numbers, multiplied, will not use more than 4*64 bits");
 
        [i, j, k, l]
 }
@@ -364,7 +365,7 @@ const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
        let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64;
 
        let (r0, must_not_overflow) = r0a.overflowing_add(r0_c);
-       debug_assert!(!must_not_overflow);
+       debug_assert!(!must_not_overflow, "Two 3*64 bit numbers, multiplied, will not use more than 6*64 bits");
 
        [r0, r1, r2, r3, r4, r5]
 }
@@ -392,9 +393,9 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
                        if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
 
                let z1a = $subsub(z1a_max, z1a_min);
-               debug_assert!(!z1a.1);
+               debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min");
                let z1b = $subsub(z1b_max, z1b_min);
-               debug_assert!(!z1b.1);
+               debug_assert!(!z1b.1, "z1b_max was selected to be greater than z1b_min");
                let z1m_sign = z1a_sign == z1b_sign;
 
                let z1m = $submul(&z1a.0, &z1b.0);
@@ -412,22 +413,18 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
 
                let l = const_subslice(&z0, $len / 2, $len);
                let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
-               let (mut j, mut i_carry) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
+               let (mut j, i_carry_a) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
+               let mut i_carry_b = false;
                if j_carry {
-                       let new_i_carry = add_u64!(j, 1);
-                       debug_assert!(!i_carry || !new_i_carry);
-                       i_carry |= new_i_carry;
+                       i_carry_b = add_u64!(j, 1);
                }
                let mut i = [0; $len / 2];
                let i_source = const_subslice(&z2, 0, $len / 2);
                copy_from_slice!(i, 0, $len / 2, i_source);
-               if i_carry {
-                       let spurious_carry = add_u64!(i, 1);
-                       debug_assert!(!spurious_carry);
-               }
-               if z1_carry {
-                       let spurious_carry = add_u64!(i, 1);
-                       debug_assert!(!spurious_carry);
+               let i_carry = i_carry_a as u64 + i_carry_b as u64 + z1_carry as u64;
+               if i_carry != 0 {
+                       let must_not_overflow = add_u64!(i, i_carry);
+                       debug_assert!(!must_not_overflow, "Two N*64 bit numbers, multiplied, will not use more than 2*N*64 bits");
                }
 
                let mut res = [0; $len * 2];
@@ -457,39 +454,17 @@ const fn sqr_2(a: &[u64]) -> [u64; 4] {
        let (a0, a1) = (a[0] as u128, a[1] as u128);
        let z2 = a0 * a0;
        let mut z1 = a0 * a1;
-       let i_carry = z1 & (1u128 << 127) != 0;
+       let i_carry_a = z1 & (1u128 << 127) != 0;
        z1 <<= 1;
        let z0 = a1 * a1;
 
-       let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
-       let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
-       let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
-       let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
-       let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
-       let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
-
-       let l = z0b;
-       let (k, j_carry) = z0a.overflowing_add(z1b);
-       let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
-
-       let new_i_carry;
-       (j, new_i_carry) = j.overflowing_add(j_carry as u64);
-       debug_assert!(!second_i_carry || !new_i_carry);
-       second_i_carry |= new_i_carry;
-
-       let mut i = z2a;
-       let mut spurious_overflow;
-       (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
-       debug_assert!(!spurious_overflow);
-       (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
-       debug_assert!(!spurious_overflow);
-
-       [i, j, k, l]
+       add_mul_2_parts(z2, z1, z0, i_carry_a)
 }
 
 macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
        /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
        const fn $name(a: &[u64]) -> [u64; $len * 2] {
+               // Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that.
                debug_assert!(a.len() == $len);
 
                let hi = const_subslice(a, 0, $len / 2);
@@ -497,29 +472,25 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
 
                let v0 = $subsqr(lo);
                let mut v1 = $submul(hi, lo);
-               let i_carry  = double!(v1);
+               let i_carry_a  = double!(v1);
                let v2 = $subsqr(hi);
 
                let l = const_subslice(&v0, $len / 2, $len);
                let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
-               let (mut j, mut i_carry_2) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
+               let (mut j, i_carry_b) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
 
                let mut i = [0; $len / 2];
                let i_source = const_subslice(&v2, 0, $len / 2);
                copy_from_slice!(i, 0, $len / 2, i_source);
 
+               let mut i_carry_c = false;
                if j_carry {
-                       let new_i_carry = add_u64!(j, 1);
-                       debug_assert!(!i_carry_2 || !new_i_carry);
-                       i_carry_2 |= new_i_carry;
-               }
-               if i_carry {
-                       let spurious_carry = add_u64!(i, 1);
-                       debug_assert!(!spurious_carry);
+                       i_carry_c = add_u64!(j, 1);
                }
-               if i_carry_2 {
-                       let spurious_carry = add_u64!(i, 1);
-                       debug_assert!(!spurious_carry);
+               let i_carry = i_carry_a as u64 + i_carry_b as u64 + i_carry_c as u64;
+               if i_carry != 0 {
+                       let must_not_overflow = add_u64!(i, i_carry);
+                       debug_assert!(!must_not_overflow, "Two N*64 bit numbers, multiplied, will not use more than 2*N*64 bits");
                }
 
                let mut res = [0; $len * 2];