]> git.bitcoin.ninja Git - dnssec-prover/commitdiff
Make addition take array references rather than slices
authorMatt Corallo <git@bluematt.me>
Mon, 29 Jul 2024 20:01:15 +0000 (20:01 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 1 Aug 2024 03:55:38 +0000 (03:55 +0000)
This seems to reduce binary size marginally by avoiding slice
bounds checking.

src/crypto/bigint.rs

index 781d54c5ebfd5ff9a71db288c09f22688ca2e36e..b7c960853989d27bf9891e66342974bb1b9638fd 100644 (file)
@@ -48,6 +48,23 @@ const fn const_subslice<'a, T>(a: &'a [T], start: usize, end: usize) -> &'a [T]
        unsafe { alloc::slice::from_raw_parts(startptr, len) }*/
 }
 
+/// Const version of `...: &[a; N] = &a[start..start + N].try_into().unwrap()`
+const fn const_subarr<'a, const N: usize, T>(a: &'a [T], start: usize) -> &'a [T; N] {
+       debug_assert!(N > 0);
+
+       let end = start + N;
+       assert!(start <= a.len());
+       assert!(end <= a.len());
+       let mut startptr = a.as_ptr();
+       startptr = unsafe { startptr.add(start) };
+       // transmute will fail to compile if the source and target sizes don't match, so we know the
+       // target type is just the length of a pointer. This leaves only a few possible encodings for
+       // a pointer to an array, basically just a pointer to the start or end. While its possible Rust
+       // could use a pointer to the end, this would be very surprising (and probably less efficient
+       // in most cases), so we just assume array references are always just pointers to the start.
+       unsafe { core::mem::transmute(startptr) }
+}
+
 /// Const version of `dest[dest_start..dest_end].copy_from_slice(source)`
 ///
 /// Once `const_mut_refs` is stable we can convert this to a function
@@ -154,9 +171,7 @@ macro_rules! double { ($a: ident) => { {
 macro_rules! define_add { ($name: ident, $len: expr) => {
        /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
        /// bit, with the same semantics as the std [`u64::overflowing_add`] method.
-       const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
-               debug_assert!(a.len() == $len);
-               debug_assert!(b.len() == $len);
+       const fn $name(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) {
                let mut r = [0; $len];
                let mut carry = false;
                let mut i = $len - 1;
@@ -354,18 +369,18 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
                debug_assert!(a.len() == $len);
                debug_assert!(b.len() == $len);
 
-               let a0 = const_subslice(a, 0, $len / 2);
-               let a1 = const_subslice(a, $len / 2, $len);
-               let b0 = const_subslice(b, 0, $len / 2);
-               let b1 = const_subslice(b, $len / 2, $len);
+               let a0: &[u64; $len / 2] = const_subarr(a, 0);
+               let a1: &[u64; $len / 2] = const_subarr(a, $len / 2);
+               let b0: &[u64; $len / 2] = const_subarr(b, 0);
+               let b1: &[u64; $len / 2] = const_subarr(b, $len / 2);
 
                let z2 = $submul(a0, b0);
                let z0 = $submul(a1, b1);
 
                let (z1a_max, z1a_min, z1a_sign) =
-                       if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) };
+                       if slice_greater_than(a1, a0) { (a1, a0, true) } else { (a0, a1, false) };
                let (z1b_max, z1b_min, z1b_sign) =
-                       if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
+                       if slice_greater_than(b1, b0) { (b1, b0, true) } else { (b0, b1, false) };
 
                let z1a = $subsub(z1a_max, z1a_min);
                debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min");
@@ -386,9 +401,14 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
                        r.0
                };
 
+               let z0_start: &[u64; $len / 2] = const_subarr(&z0, 0);
+               let z1_start: &[u64; $len / 2] = const_subarr(&z1, 0);
+               let z1_end: &[u64; $len / 2] = const_subarr(&z1, $len / 2);
+               let z2_end: &[u64; $len / 2] = const_subarr(&z2, $len / 2);
+
                let l = const_subslice(&z0, $len / 2, $len);
-               let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
-               let (mut j, i_carry_a) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
+               let (k, j_carry) = $subadd(z0_start, z1_end);
+               let (mut j, i_carry_a) = $subadd(z1_start, z2_end);
                let mut i_carry_b = false;
                if j_carry {
                        i_carry_b = add_u64!(j, 1);
@@ -450,9 +470,14 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
                let i_carry_a  = double!(v1);
                let v2 = $subsqr(hi);
 
+               let v0_start: &[u64; $len / 2] = const_subarr(&v0, 0);
+               let v1_start: &[u64; $len / 2] = const_subarr(&v1, 0);
+               let v1_end: &[u64; $len / 2] = const_subarr(&v1, $len / 2);
+               let v2_end: &[u64; $len / 2] = const_subarr(&v2, $len / 2);
+
                let l = const_subslice(&v0, $len / 2, $len);
-               let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
-               let (mut j, i_carry_b) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
+               let (k, j_carry) = $subadd(v0_start, v1_end);
+               let (mut j, i_carry_b) = $subadd(v1_start, v2_end);
 
                let mut i = [0; $len / 2];
                let i_source = const_subslice(&v2, 0, $len / 2);
@@ -569,17 +594,17 @@ macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: iden
                        debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
                        let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
                                (true, true) => {
-                                       let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+                                       let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
                                        debug_assert!(!overflow);
                                        (new_s, true)
                                }
                                (false, true) => {
-                                       let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+                                       let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
                                        debug_assert!(!overflow);
                                        (new_s, false)
                                },
                                (true, false) => {
-                                       let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+                                       let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
                                        debug_assert!(!overflow);
                                        (new_s, true)
                                },
@@ -730,7 +755,7 @@ impl U4096 {
 
                type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
                type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
-               type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
+               type add_double_ty = fn(&[u64; WORD_COUNT_4096 * 2], &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool);
                type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool);
                let (word_count, log_bits, mul, sqr, add_double, sub) =
                        if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
@@ -753,12 +778,12 @@ impl U4096 {
                                                        &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
                                                res
                                        }
-                                       fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
-                                               debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
-                                               debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
+                                       fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
-                                               let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]);
+                                               let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 2);
+                                               let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 2);
+                                               let (add, overflow) = add_32(a_arr, b_arr);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
                                                res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
                                                (res, overflow)
@@ -793,12 +818,12 @@ impl U4096 {
                                                        &sqr_32(&a[WORD_COUNT_4096 / 2..]));
                                                res
                                        }
-                                       fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
-                                               debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
-                                               debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
+                                       fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
                                                debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
                                                debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
-                                               let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]);
+                                               let a_arr = const_subarr(a, WORD_COUNT_4096);
+                                               let b_arr = const_subarr(b, WORD_COUNT_4096);
+                                               let (add, overflow) = add_64(a_arr, b_arr);
                                                let mut res = [0; WORD_COUNT_4096 * 2];
                                                res[WORD_COUNT_4096..].copy_from_slice(&add);
                                                (res, overflow)
@@ -1430,6 +1455,9 @@ pub fn fuzz_math(input: &[u8]) {
        }
 
        macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
+               let a_arg = (&a_u64s[..]).try_into().unwrap();
+               let b_arg = (&b_u64s[..]).try_into().unwrap();
+
                let res = $mul(&a_u64s, &b_u64s);
                let mut res_bytes = Vec::with_capacity(input.len() / 2);
                for i in res {
@@ -1440,7 +1468,7 @@ pub fn fuzz_math(input: &[u8]) {
                debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
                debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
 
-               let (res, carry) = $add(&a_u64s, &b_u64s);
+               let (res, carry) = $add(a_arg, b_arg);
                let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
                if carry { res_bytes.push(1); } else { res_bytes.push(0); }
                for i in res {
@@ -1504,6 +1532,9 @@ pub fn fuzz_math(input: &[u8]) {
                let mut p_extended = [0; $len * 2];
                p_extended[$len..].copy_from_slice(&$PRIME);
 
+               let a_arg = (&a_u64s[..]).try_into().unwrap();
+               let b_arg = (&b_u64s[..]).try_into().unwrap();
+
                let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
                assert_eq!(&amodp_squared[..$len], &[0; $len]);
                assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
@@ -1512,7 +1543,7 @@ pub fn fuzz_math(input: &[u8]) {
                assert_eq!(&abmodp[..$len], &[0; $len]);
                assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
 
-               let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
+               let (aplusb, aplusb_overflow) = $add(a_arg, b_arg);
                let mut aplusb_extended = [0; $len * 2];
                aplusb_extended[$len..].copy_from_slice(&aplusb);
                if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
@@ -1520,7 +1551,7 @@ pub fn fuzz_math(input: &[u8]) {
                assert_eq!(&aplusbmodp[..$len], &[0; $len]);
                assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
 
-               let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
+               let (mut aminusb, aminusb_underflow) = $sub(a_arg, b_arg);
                if aminusb_underflow {
                        let mut overflow;
                        (aminusb, overflow) = $add(&aminusb, &$PRIME);