unsafe { alloc::slice::from_raw_parts(startptr, len) }*/
}
+/// Const version of `...: &[a; N] = &a[start..start + N].try_into().unwrap()`
+const fn const_subarr<'a, const N: usize, T>(a: &'a [T], start: usize) -> &'a [T; N] {
+ debug_assert!(N > 0);
+
+ let end = start + N;
+ assert!(start <= a.len());
+ assert!(end <= a.len());
+ let mut startptr = a.as_ptr();
+ startptr = unsafe { startptr.add(start) };
+ // transmute will fail to compile if the source and target sizes don't match, so we know the
+ // target type is just the length of a pointer. This leaves only a few possible encodings for
+ // a pointer to an array, basically just a pointer to the start or end. While its possible Rust
+ // could use a pointer to the end, this would be very surprising (and probably less efficient
+ // in most cases), so we just assume array references are always just pointers to the start.
+ unsafe { core::mem::transmute(startptr) }
+}
+
/// Const version of `dest[dest_start..dest_end].copy_from_slice(source)`
///
/// Once `const_mut_refs` is stable we can convert this to a function
macro_rules! define_add { ($name: ident, $len: expr) => {
/// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
/// bit, with the same semantics as the std [`u64::overflowing_add`] method.
- const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
- debug_assert!(a.len() == $len);
- debug_assert!(b.len() == $len);
+ const fn $name(a: &[u64; $len], b: &[u64; $len]) -> ([u64; $len], bool) {
let mut r = [0; $len];
let mut carry = false;
let mut i = $len - 1;
debug_assert!(a.len() == $len);
debug_assert!(b.len() == $len);
- let a0 = const_subslice(a, 0, $len / 2);
- let a1 = const_subslice(a, $len / 2, $len);
- let b0 = const_subslice(b, 0, $len / 2);
- let b1 = const_subslice(b, $len / 2, $len);
+ let a0: &[u64; $len / 2] = const_subarr(a, 0);
+ let a1: &[u64; $len / 2] = const_subarr(a, $len / 2);
+ let b0: &[u64; $len / 2] = const_subarr(b, 0);
+ let b1: &[u64; $len / 2] = const_subarr(b, $len / 2);
let z2 = $submul(a0, b0);
let z0 = $submul(a1, b1);
let (z1a_max, z1a_min, z1a_sign) =
- if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) };
+ if slice_greater_than(a1, a0) { (a1, a0, true) } else { (a0, a1, false) };
let (z1b_max, z1b_min, z1b_sign) =
- if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
+ if slice_greater_than(b1, b0) { (b1, b0, true) } else { (b0, b1, false) };
let z1a = $subsub(z1a_max, z1a_min);
debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min");
r.0
};
+ let z0_start: &[u64; $len / 2] = const_subarr(&z0, 0);
+ let z1_start: &[u64; $len / 2] = const_subarr(&z1, 0);
+ let z1_end: &[u64; $len / 2] = const_subarr(&z1, $len / 2);
+ let z2_end: &[u64; $len / 2] = const_subarr(&z2, $len / 2);
+
let l = const_subslice(&z0, $len / 2, $len);
- let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
- let (mut j, i_carry_a) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
+ let (k, j_carry) = $subadd(z0_start, z1_end);
+ let (mut j, i_carry_a) = $subadd(z1_start, z2_end);
let mut i_carry_b = false;
if j_carry {
i_carry_b = add_u64!(j, 1);
let i_carry_a = double!(v1);
let v2 = $subsqr(hi);
+ let v0_start: &[u64; $len / 2] = const_subarr(&v0, 0);
+ let v1_start: &[u64; $len / 2] = const_subarr(&v1, 0);
+ let v1_end: &[u64; $len / 2] = const_subarr(&v1, $len / 2);
+ let v2_end: &[u64; $len / 2] = const_subarr(&v2, $len / 2);
+
let l = const_subslice(&v0, $len / 2, $len);
- let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
- let (mut j, i_carry_b) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
+ let (k, j_carry) = $subadd(v0_start, v1_end);
+ let (mut j, i_carry_b) = $subadd(v1_start, v2_end);
let mut i = [0; $len / 2];
let i_source = const_subslice(&v2, 0, $len / 2);
debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
(true, true) => {
- let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+ let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
debug_assert!(!overflow);
(new_s, true)
}
(false, true) => {
- let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+ let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
debug_assert!(!overflow);
(new_s, false)
},
(true, false) => {
- let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+ let (new_s, overflow) = $add(&old_s, const_subarr(&new_sa, $len));
debug_assert!(!overflow);
(new_s, true)
},
type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
- type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
+ type add_double_ty = fn(&[u64; WORD_COUNT_4096 * 2], &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool);
type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool);
let (word_count, log_bits, mul, sqr, add_double, sub) =
if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
&sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
res
}
- fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
- debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
- debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
+ fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
- let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]);
+ let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 2);
+ let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 2);
+ let (add, overflow) = add_32(a_arr, b_arr);
let mut res = [0; WORD_COUNT_4096 * 2];
res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
(res, overflow)
&sqr_32(&a[WORD_COUNT_4096 / 2..]));
res
}
- fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
- debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
- debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
+ fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
- let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]);
+ let a_arr = const_subarr(a, WORD_COUNT_4096);
+ let b_arr = const_subarr(b, WORD_COUNT_4096);
+ let (add, overflow) = add_64(a_arr, b_arr);
let mut res = [0; WORD_COUNT_4096 * 2];
res[WORD_COUNT_4096..].copy_from_slice(&add);
(res, overflow)
}
macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
+ let a_arg = (&a_u64s[..]).try_into().unwrap();
+ let b_arg = (&b_u64s[..]).try_into().unwrap();
+
let res = $mul(&a_u64s, &b_u64s);
let mut res_bytes = Vec::with_capacity(input.len() / 2);
for i in res {
debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
- let (res, carry) = $add(&a_u64s, &b_u64s);
+ let (res, carry) = $add(a_arg, b_arg);
let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
if carry { res_bytes.push(1); } else { res_bytes.push(0); }
for i in res {
let mut p_extended = [0; $len * 2];
p_extended[$len..].copy_from_slice(&$PRIME);
+ let a_arg = (&a_u64s[..]).try_into().unwrap();
+ let b_arg = (&b_u64s[..]).try_into().unwrap();
+
let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
assert_eq!(&amodp_squared[..$len], &[0; $len]);
assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
assert_eq!(&abmodp[..$len], &[0; $len]);
assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
- let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
+ let (aplusb, aplusb_overflow) = $add(a_arg, b_arg);
let mut aplusb_extended = [0; $len * 2];
aplusb_extended[$len..].copy_from_slice(&aplusb);
if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
assert_eq!(&aplusbmodp[..$len], &[0; $len]);
assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
- let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
+ let (mut aminusb, aminusb_underflow) = $sub(a_arg, b_arg);
if aminusb_underflow {
let mut overflow;
(aminusb, overflow) = $add(&aminusb, &$PRIME);