///
/// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int
/// types to do multiplication (potentially) natively.
-const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
- debug_assert!(a.len() == 2);
- debug_assert!(b.len() == 2);
-
+const fn mul_2(a: &[u64; 2], b: &[u64; 2]) -> [u64; 4] {
// Gradeschool multiplication is way faster here.
let (a0, a1) = (a[0] as u128, a[1] as u128);
let (b0, b1) = (b[0] as u128, b[1] as u128);
[i, j, k, l]
}
-const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
- debug_assert!(a.len() == 3);
- debug_assert!(b.len() == 3);
-
+const fn mul_3(a: &[u64; 3], b: &[u64; 3]) -> [u64; 6] {
let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
/// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
- const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
+ const fn $name(a: &[u64; $len], b: &[u64; $len]) -> [u64; $len * 2] {
// We could probably get a bit faster doing gradeschool multiplication for some smaller
// sizes, but it's easier to just have one variable-length multiplication, so we do
// Karatsuba always here.
- debug_assert!(a.len() == $len);
- debug_assert!(b.len() == $len);
-
let a0: &[u64; $len / 2] = const_subarr(a, 0);
let a1: &[u64; $len / 2] = const_subarr(a, $len / 2);
let b0: &[u64; $len / 2] = const_subarr(b, 0);
///
/// This is the base case for our squaring, taking advantage of Rust's native 128-bit int
/// types to do multiplication (potentially) natively.
-const fn sqr_2(a: &[u64]) -> [u64; 4] {
- debug_assert!(a.len() == 2);
-
+const fn sqr_2(a: &[u64; 2]) -> [u64; 4] {
let (a0, a1) = (a[0] as u128, a[1] as u128);
let z2 = a0 * a0;
let mut z1 = a0 * a1;
macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
/// Squares a $len-64-bit integer, returning a new $len*2-64-bit integer.
- const fn $name(a: &[u64]) -> [u64; $len * 2] {
+ const fn $name(a: &[u64; $len]) -> [u64; $len * 2] {
// Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that.
- debug_assert!(a.len() == $len);
-
- let hi = const_subslice(a, 0, $len / 2);
- let lo = const_subslice(a, $len / 2, $len);
+ let hi: &[u64; $len / 2] = const_subarr(a, 0);
+ let lo: &[u64; $len / 2] = const_subarr(a, $len / 2);
let v0 = $subsqr(lo);
let mut v1 = $submul(hi, lo);
} }
// TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
// Squares `a` by multiplying it with itself via `mul_3`; there is no dedicated squaring routine
// for this width (see the TODO above).
-const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
+const fn sqr_3(a: &[u64; 3]) -> [u64; 6] { mul_3(a, a) }
define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
// we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math
// here which debug_assert's the required bits are 0s and then uses faster math primitives.
- type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
- type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
+ type mul_ty = fn(&[u64; WORD_COUNT_4096], &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2];
+ type sqr_ty = fn(&[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2];
type add_double_ty = fn(&[u64; WORD_COUNT_4096 * 2], &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool);
type sub_ty = fn(&[u64; WORD_COUNT_4096], &[u64; WORD_COUNT_4096]) -> ([u64; WORD_COUNT_4096], bool);
let (word_count, log_bits, mul, sqr, add_double, sub) =
if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] {
// Multiplication helper for operands whose leading WORD_COUNT_4096 * 3 / 4 words are zero
// (checked by the debug assertions below). Only the trailing quarter of each operand is passed
// to `mul_16`, and the product is copied into the trailing WORD_COUNT_4096 / 2 words of an
// otherwise zero-filled double-width result.
- fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
- debug_assert_eq!(a.len(), WORD_COUNT_4096);
- debug_assert_eq!(b.len(), WORD_COUNT_4096);
+ fn mul_16_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
let mut res = [0; WORD_COUNT_4096 * 2];
- res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
- &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]));
// `const_subarr` stands in for the removed `&a[WORD_COUNT_4096 * 3 / 4..]` slices; presumably
// its output length is inferred from `mul_16`'s parameter type -- confirm against its definition.
+ let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4);
+ let b_arr = const_subarr(b, WORD_COUNT_4096 * 3 / 4);
+ res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..]
+ .copy_from_slice(&mul_16(a_arr, b_arr));
res
}
// Squaring helper for an operand whose leading WORD_COUNT_4096 * 3 / 4 words are zero (checked
// by the debug assertion below). Only the trailing quarter is passed to `sqr_16`, and the
// result is copied into the trailing WORD_COUNT_4096 / 2 words of an otherwise zero-filled
// double-width result.
- fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
- debug_assert_eq!(a.len(), WORD_COUNT_4096);
+ fn sqr_16_subarr(a: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
let mut res = [0; WORD_COUNT_4096 * 2];
- res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
- &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
// `const_subarr` stands in for the removed `&a[WORD_COUNT_4096 * 3 / 4..]` slice; presumably
// its output length is inferred from `sqr_16`'s parameter type -- confirm against its definition.
+ let a_arr = const_subarr(a, WORD_COUNT_4096 * 3 / 4);
+ res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..]
+ .copy_from_slice(&sqr_16(a_arr));
res
}
fn add_32_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
}
(16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty)
} else {
// Multiplication helper for operands whose leading WORD_COUNT_4096 / 2 words are zero (checked
// by the debug assertions below). Only the trailing half of each operand is passed to `mul_32`,
// and the product is copied into the trailing WORD_COUNT_4096 words of an otherwise zero-filled
// double-width result.
- fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
- debug_assert_eq!(a.len(), WORD_COUNT_4096);
- debug_assert_eq!(b.len(), WORD_COUNT_4096);
+ fn mul_32_subarr(a: &[u64; WORD_COUNT_4096], b: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
let mut res = [0; WORD_COUNT_4096 * 2];
- res[WORD_COUNT_4096..].copy_from_slice(
- &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]));
// `const_subarr` stands in for the removed `&a[WORD_COUNT_4096 / 2..]` slices; presumably its
// output length is inferred from `mul_32`'s parameter type -- confirm against its definition.
+ let a_arr = const_subarr(a, WORD_COUNT_4096 / 2);
+ let b_arr = const_subarr(b, WORD_COUNT_4096 / 2);
+ res[WORD_COUNT_4096..].copy_from_slice(&mul_32(a_arr, b_arr));
res
}
// Squaring helper for an operand whose leading WORD_COUNT_4096 / 2 words are zero (checked by
// the debug assertion below). Only the trailing half is passed to `sqr_32`, and the result is
// copied into the trailing WORD_COUNT_4096 words of an otherwise zero-filled double-width result.
//
// The operand length is now fixed by the `&[u64; WORD_COUNT_4096]` parameter type, so the old
// runtime `a.len()` assertion is removed here, matching the sibling `*_subarr` helpers.
- fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
+ fn sqr_32_subarr(a: &[u64; WORD_COUNT_4096]) -> [u64; WORD_COUNT_4096 * 2] {
- debug_assert_eq!(a.len(), WORD_COUNT_4096);
debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
+ let a_arr = const_subarr(a, WORD_COUNT_4096 / 2);
let mut res = [0; WORD_COUNT_4096 * 2];
- res[WORD_COUNT_4096..].copy_from_slice(
- &sqr_32(&a[WORD_COUNT_4096 / 2..]));
+ res[WORD_COUNT_4096..].copy_from_slice(&sqr_32(a_arr));
res
}
fn add_64_subarr(a: &[u64; WORD_COUNT_4096 * 2], b: &[u64; WORD_COUNT_4096 * 2]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
// t_on_r = (mu + v*modulus) / R
- let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
+ let t0 = mul(const_subarr(&v, WORD_COUNT_4096), &m.0);
let (t1, t1_extra_bit) = add_double(&t0, &mu);
// Note that dividing t1 by R is simply a matter of shifting right by word_count 64-bit words
// if t >= N { t - N } else { t }
// mu % R is just the bottom four 64-bit words of mu
- let mu_mod_r = const_subslice(&mu, 4, 8);
+ let mu_mod_r: &[u64; 4] = const_subarr(&mu, 4);
// v = ((mu % R) * negative_modulus_inverse) % R
- let mut v = mul_4(&mu_mod_r, negative_prime_inv_mod_r);
+ let mut v = mul_4(mu_mod_r, negative_prime_inv_mod_r);
const ZEROS: &[u64; 4] = &[0; 4];
copy_from_slice!(v, 0, 4, ZEROS); // mod R
// t_on_r = (mu + v*modulus) / R
- let t0 = mul_4(const_subslice(&v, 4, 8), prime);
+ let t0 = mul_4(const_subarr(&v, 4), prime);
let (t1, t1_extra_bit) = add_8(&t0, &mu);
// Note that dividing t1 by R is simply a matter of shifting right by four 64-bit words.
// if t >= N { t - N } else { t }
// mu % R is just the bottom six 64-bit words of mu
- let mu_mod_r = const_subslice(&mu, 6, 12);
+ let mu_mod_r: &[u64; 6] = const_subarr(&mu, 6);
// v = ((mu % R) * negative_modulus_inverse) % R
- let mut v = mul_6(&mu_mod_r, negative_prime_inv_mod_r);
+ let mut v = mul_6(mu_mod_r, negative_prime_inv_mod_r);
const ZEROS: &[u64; 6] = &[0; 6];
copy_from_slice!(v, 0, 6, ZEROS); // mod R
// t_on_r = (mu + v*modulus) / R
- let t0 = mul_6(const_subslice(&v, 6, 12), prime);
+ let t0 = mul_6(const_subarr(&v, 6), prime);
let (t1, t1_extra_bit) = add_12(&t0, &mu);
// Note that dividing t1 by R is simply a matter of shifting right by six 64-bit words.
let a_arg = (&a_u64s[..]).try_into().unwrap();
let b_arg = (&b_u64s[..]).try_into().unwrap();
- let res = $mul(&a_u64s, &b_u64s);
+ let res = $mul(a_arg, b_arg);
let mut res_bytes = Vec::with_capacity(input.len() / 2);
for i in res {
res_bytes.extend_from_slice(&i.to_be_bytes());
}
assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone());
- debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
- debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
+ debug_assert_eq!($mul(a_arg, a_arg), $sqr(a_arg));
+ debug_assert_eq!($mul(b_arg, b_arg), $sqr(b_arg));
let (res, carry) = $add(a_arg, b_arg);
let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
let a_arg = (&a_u64s[..]).try_into().unwrap();
let b_arg = (&b_u64s[..]).try_into().unwrap();
- let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
+ let amodp_squared = $div_rem_double(&$mul(a_arg, a_arg), &p_extended).unwrap().1;
assert_eq!(&amodp_squared[..$len], &[0; $len]);
assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
- let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
+ let abmodp = $div_rem_double(&$mul(a_arg, b_arg), &p_extended).unwrap().1;
assert_eq!(&abmodp[..$len], &[0; $len]);
assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);