//! Simple variable-time big integer implementation
use alloc::vec::Vec;
+use core::marker::PhantomData;
const WORD_COUNT_4096: usize = 4096 / 64;
+const WORD_COUNT_256: usize = 256 / 64;
+const WORD_COUNT_384: usize = 384 / 64;
// RFC 5702 indicates RSA keys can be up to 4096 bits
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct U4096([u64; WORD_COUNT_4096]);
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U256([u64; WORD_COUNT_256]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U384([u64; WORD_COUNT_384]);
+
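+/// A simple fixed-width unsigned big integer, stored as 64-bit limbs in most-significant-first order.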
+pub(super) trait Int: Clone + Ord + Sized {
+ const ZERO: Self;
+ const BYTES: usize;
+ fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
+ fn limbs(&self) -> &[u64];
+}
+impl Int for U256 {
+ const ZERO: U256 = U256([0; 4]);
+ const BYTES: usize = 32;
+ fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+ fn limbs(&self) -> &[u64] { &self.0 }
+}
+impl Int for U384 {
+ const ZERO: U384 = U384([0; 6]);
+ const BYTES: usize = 48;
+ fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+ fn limbs(&self) -> &[u64] { &self.0 }
+}
+
+/// Defines a *PRIME* modulus, together with the precomputed constants needed for Montgomery arithmetic modulo it.
+pub(super) trait PrimeModulus<I: Int> {
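+ /// The prime modulus itself.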
+ const PRIME: I;
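+ /// R^2 mod the prime, where R is the Montgomery radix 2^(64 * limb count).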
+ const R_SQUARED_MOD_PRIME: I;
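+ /// The negative inverse of the prime mod R, i.e. -prime^-1 mod R.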
+ const NEGATIVE_PRIME_INV_MOD_R: I;
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense because values carry a Montgomery R factor
+pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense because values carry a Montgomery R factor
+pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
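+
+// The `U256Mod`/`U384Mod` wrappers hold a value in Montgomery form with respect to the modulus
+// `M`. The typical flow (as a sketch): convert in with `from_u256`/`from_u384`, do arithmetic
+// with `mul`/`square`/`add`/`sub`/`double`, then convert back out with `into_u256`/`into_u384`.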
+
macro_rules! debug_unwrap { ($v: expr) => { {
let v = $v;
debug_assert!(v.is_ok());
} }
define_add!(add_2, 2);
+define_add!(add_3, 3);
define_add!(add_4, 4);
+define_add!(add_6, 6);
define_add!(add_8, 8);
+define_add!(add_12, 12);
define_add!(add_16, 16);
define_add!(add_32, 32);
define_add!(add_64, 64);
define_add!(add_128, 128);
-macro_rules! define_sub { ($name: ident, $len: expr) => {
+macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
/// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
/// $len-64-bit integer and an overflow bit, with the same semantics as the std
/// [`u64::overflowing_sub`] method.
}
(r, carry)
}
+
+ /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
+ /// $len-64-bit integer representing the absolute value of the result, as well as a sign bit.
+ #[allow(unused)]
+ const fn $name_abs(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
+ let (mut res, neg) = $name(a, b);
+ if neg {
+ negate!(res);
+ }
+ (res, neg)
+ }
} }
-define_sub!(sub_2, 2);
-define_sub!(sub_4, 4);
-define_sub!(sub_8, 8);
-define_sub!(sub_16, 16);
-define_sub!(sub_32, 32);
-define_sub!(sub_64, 64);
-#[cfg(debug_assertions)]
-define_sub!(sub_128, 128);
+define_sub!(sub_2, sub_abs_2, 2);
+define_sub!(sub_3, sub_abs_3, 3);
+define_sub!(sub_4, sub_abs_4, 4);
+define_sub!(sub_6, sub_abs_6, 6);
+define_sub!(sub_8, sub_abs_8, 8);
+define_sub!(sub_12, sub_abs_12, 12);
+define_sub!(sub_16, sub_abs_16, 16);
+define_sub!(sub_32, sub_abs_32, 32);
+define_sub!(sub_64, sub_abs_64, 64);
+define_sub!(sub_128, sub_abs_128, 128);
/// Multiplies two 128-bit integers together, returning a new 256-bit integer.
///
[i, j, k, l]
}
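+/// Multiplies two 192-bit integers together, returning a new 384-bit integer.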
+const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
+ debug_assert!(a.len() == 3);
+ debug_assert!(b.len() == 3);
+
+ let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
+ let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
+
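+ // 64x64 -> 128-bit schoolbook partial products. Limbs are most-significant-first, so the
+ // high half of each m{k} accumulates into output limb r{k} and the low half into r{k+1}.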
+ let m4 = a2 * b2;
+ let m3a = a2 * b1;
+ let m3b = a1 * b2;
+ let m2a = a2 * b0;
+ let m2b = a1 * b1;
+ let m2c = a0 * b2;
+ let m1a = a1 * b0;
+ let m1b = a0 * b1;
+ let m0 = a0 * b0;
+
+ let r5 = ((m4 >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+
+ let r4a = ((m4 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r4b = ((m3a >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+ let r4c = ((m3b >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+
+ let r3a = ((m3a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r3b = ((m3b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r3c = ((m2a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+ let r3d = ((m2b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+ let r3e = ((m2c >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+ let r2a = ((m2a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r2b = ((m2b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r2c = ((m2c >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r2d = ((m1a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+ let r2e = ((m1b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+ let r1a = ((m1a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r1b = ((m1b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+ let r1c = ((m0 >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+ let r0a = ((m0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+
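+ // Sum each output column, carrying into the next-more-significant limb.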
+ let (r4, r3_ca) = r4a.overflowing_add(r4b);
+ let (r4, r3_cb) = r4.overflowing_add(r4c);
+ let r3_c = r3_ca as u64 + r3_cb as u64;
+
+ let (r3, r2_ca) = r3a.overflowing_add(r3b);
+ let (r3, r2_cb) = r3.overflowing_add(r3c);
+ let (r3, r2_cc) = r3.overflowing_add(r3d);
+ let (r3, r2_cd) = r3.overflowing_add(r3e);
+ let (r3, r2_ce) = r3.overflowing_add(r3_c);
+ let r2_c = r2_ca as u64 + r2_cb as u64 + r2_cc as u64 + r2_cd as u64 + r2_ce as u64;
+
+ let (r2, r1_ca) = r2a.overflowing_add(r2b);
+ let (r2, r1_cb) = r2.overflowing_add(r2c);
+ let (r2, r1_cc) = r2.overflowing_add(r2d);
+ let (r2, r1_cd) = r2.overflowing_add(r2e);
+ let (r2, r1_ce) = r2.overflowing_add(r2_c);
+ let r1_c = r1_ca as u64 + r1_cb as u64 + r1_cc as u64 + r1_cd as u64 + r1_ce as u64;
+
+ let (r1, r0_ca) = r1a.overflowing_add(r1b);
+ let (r1, r0_cb) = r1.overflowing_add(r1c);
+ let (r1, r0_cc) = r1.overflowing_add(r1_c);
+ let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64;
+
+ let (r0, must_not_overflow) = r0a.overflowing_add(r0_c);
+ debug_assert!(!must_not_overflow);
+
+ [r0, r1, r2, r3, r4, r5]
+}
+
macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
/// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
} }
define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
+define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
}
} }
+// TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
+const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
+
define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
+define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
-#[cfg(fuzzing)]
macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
#[cfg(fuzzing)]
define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
-#[cfg(fuzzing)]
define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
-#[cfg(fuzzing)]
+define_div_rem!(div_rem_6, 6, sub_6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
+#[cfg(debug_assertions)]
define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
+#[cfg(debug_assertions)]
+define_div_rem!(div_rem_12, 12, sub_12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
#[cfg(debug_assertions)]
define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
+macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
+ /// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
+ /// if one exists.
+ const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
+ if slice_equal(a, &[0; $len]) || slice_equal(m, &[0; $len]) { return Err(()); }
+
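+ // The extended Euclidean algorithm, tracking only the Bézout coefficient of `a`. The
+ // coefficient is kept as an unsigned magnitude plus an explicit sign flag, since the limb
+ // arrays themselves are unsigned.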
+ let (mut s, mut old_s) = ([0; $len], [0; $len]);
+ old_s[$len - 1] = 1;
+ let mut r = *m;
+ let mut old_r = *a;
+
+ let (mut old_s_neg, mut s_neg) = (false, false);
+
+ while !slice_equal(&r, &[0; $len]) {
+ let (quot, new_r) = debug_unwrap!($div(&old_r, &r));
+
+ let new_sa = $mul(&quot, &s);
+ debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
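+ // new_s = old_s - quot * s, with the operands' signs handled explicitly.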
+ let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
+ (true, true) => {
+ let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+ debug_assert!(!overflow);
+ (new_s, true)
+ }
+ (false, true) => {
+ let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+ debug_assert!(!overflow);
+ (new_s, false)
+ },
+ (true, false) => {
+ let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+ debug_assert!(!overflow);
+ (new_s, true)
+ },
+ (false, false) => $sub_abs(&old_s, const_subslice(&new_sa, $len, new_sa.len())),
+ };
+
+ old_r = r;
+ r = new_r;
+
+ old_s = s;
+ old_s_neg = s_neg;
+ s = new_s;
+ s_neg = new_s_neg;
+ }
+
+ // At this point old_r contains our GCD and old_s our first Bézout's identity coefficient.
+ if !slice_equal(const_subslice(&old_r, 0, $len - 1), &[0; $len - 1]) || old_r[$len - 1] != 1 {
+ Err(())
+ } else {
+ debug_assert!(slice_greater_than(m, &old_s));
+ if old_s_neg {
+ let (modinv, underflow) = $sub_abs(m, &old_s);
+ debug_assert!(!underflow);
+ debug_assert!(slice_greater_than(m, &modinv));
+ Ok(modinv)
+ } else {
+ Ok(old_s)
+ }
+ }
+ }
+} }
+#[cfg(fuzzing)]
+define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
+define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
+define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
+#[cfg(fuzzing)]
+define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
+
impl U4096 {
/// Constructs a new [`U4096`] from a variable number of big-endian bytes.
pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
}
}
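+// Helpers to read the Nth big-endian u64 out of a byte slice in a `const fn` context (where
+// `TryInto`-style slice-to-array conversions aren't available), panicking on short input.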
+const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 {
+ match b {
+ [a, b, c, d, e, f, g, h, ..] => {
+ ((*a as u64) << 8*7) |
+ ((*b as u64) << 8*6) |
+ ((*c as u64) << 8*5) |
+ ((*d as u64) << 8*4) |
+ ((*e as u64) << 8*3) |
+ ((*f as u64) << 8*2) |
+ ((*g as u64) << 8*1) |
+ ((*h as u64) << 8*0)
+ },
+ _ => panic!(),
+ }
+}
+
+const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 {
+ match b {
+ [_, _, _, _, _, _, _, _,
+ a, b, c, d, e, f, g, h, ..] => {
+ ((*a as u64) << 8*7) |
+ ((*b as u64) << 8*6) |
+ ((*c as u64) << 8*5) |
+ ((*d as u64) << 8*4) |
+ ((*e as u64) << 8*3) |
+ ((*f as u64) << 8*2) |
+ ((*g as u64) << 8*1) |
+ ((*h as u64) << 8*0)
+ },
+ _ => panic!(),
+ }
+}
+
+const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 {
+ match b {
+ [_, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ a, b, c, d, e, f, g, h, ..] => {
+ ((*a as u64) << 8*7) |
+ ((*b as u64) << 8*6) |
+ ((*c as u64) << 8*5) |
+ ((*d as u64) << 8*4) |
+ ((*e as u64) << 8*3) |
+ ((*f as u64) << 8*2) |
+ ((*g as u64) << 8*1) |
+ ((*h as u64) << 8*0)
+ },
+ _ => panic!(),
+ }
+}
+
+const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 {
+ match b {
+ [_, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ a, b, c, d, e, f, g, h, ..] => {
+ ((*a as u64) << 8*7) |
+ ((*b as u64) << 8*6) |
+ ((*c as u64) << 8*5) |
+ ((*d as u64) << 8*4) |
+ ((*e as u64) << 8*3) |
+ ((*f as u64) << 8*2) |
+ ((*g as u64) << 8*1) |
+ ((*h as u64) << 8*0)
+ },
+ _ => panic!(),
+ }
+}
+
+const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 {
+ match b {
+ [_, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ a, b, c, d, e, f, g, h, ..] => {
+ ((*a as u64) << 8*7) |
+ ((*b as u64) << 8*6) |
+ ((*c as u64) << 8*5) |
+ ((*d as u64) << 8*4) |
+ ((*e as u64) << 8*3) |
+ ((*f as u64) << 8*2) |
+ ((*g as u64) << 8*1) |
+ ((*h as u64) << 8*0)
+ },
+ _ => panic!(),
+ }
+}
+
+const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 {
+ match b {
+ [_, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ _, _, _, _, _, _, _, _,
+ a, b, c, d, e, f, g, h, ..] => {
+ ((*a as u64) << 8*7) |
+ ((*b as u64) << 8*6) |
+ ((*c as u64) << 8*5) |
+ ((*d as u64) << 8*4) |
+ ((*e as u64) << 8*3) |
+ ((*f as u64) << 8*2) |
+ ((*g as u64) << 8*1) |
+ ((*h as u64) << 8*0)
+ },
+ _ => panic!(),
+ }
+}
+
+impl U256 {
+ /// Constructs a new [`U256`] from a variable number of big-endian bytes.
+ pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U256, ()> {
+ if bytes.len() > 256/8 { return Err(()); }
+ let u64s = (bytes.len() + 7) / 8;
+ let mut res = [0; WORD_COUNT_256];
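+ // Fill the limbs most-significant-first, right-aligning the input: a short or
+ // non-multiple-of-8-length byte string leaves the upper limbs (and the upper bytes of the
+ // first used limb) zero.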
+ for i in 0..u64s {
+ let mut b = [0; 8];
+ let pos = (u64s - i) * 8;
+ let start = bytes.len().saturating_sub(pos);
+ let end = bytes.len() + 8 - pos;
+ b[8 + start - end..].copy_from_slice(&bytes[start..end]);
+ res[i + WORD_COUNT_256 - u64s] = u64::from_be_bytes(b);
+ }
+ Ok(U256(res))
+ }
+
+ /// Constructs a new [`U256`] from a fixed number of big-endian bytes.
+ pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
+ let res = [
+ u64_from_bytes_a_panicking(bytes),
+ u64_from_bytes_b_panicking(bytes),
+ u64_from_bytes_c_panicking(bytes),
+ u64_from_bytes_d_panicking(bytes),
+ ];
+ U256(res)
+ }
+
+ pub(super) const fn zero() -> U256 { U256([0, 0, 0, 0]) }
+ pub(super) const fn one() -> U256 { U256([0, 0, 0, 1]) }
+ pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
+}
+
+impl<M: PrimeModulus<U256>> U256Mod<M> {
+ const fn mont_reduction(mu: [u64; 8]) -> Self {
+ #[cfg(debug_assertions)] {
+ // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
+ // should be able to do it at compile time alone.
+ let minus_one_mod_r = mul_4(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+ assert!(slice_equal(const_subslice(&minus_one_mod_r, 4, 8), &[0xffff_ffff_ffff_ffff; 4]));
+ }
+
+ #[cfg(debug_assertions)] {
+ // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
+ // should be able to do it at compile time alone.
+ let r_minus_one = [0xffff_ffff_ffff_ffff; 4];
+ let (mut r_mod_prime, _) = sub_4(&r_minus_one, &M::PRIME.0);
+ add_one!(r_mod_prime);
+ let r_squared = sqr_4(&r_mod_prime);
+ let mut prime_extended = [0; 8];
+ let prime = M::PRIME.0;
+ copy_from_slice!(prime_extended, 4, 8, prime);
+ let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_8(&r_squared, &prime_extended) { v } else { panic!() };
+ assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
+ assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
+ }
+
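+ // Montgomery reduction (REDC) with R = 2^256: compute v = mu * -PRIME^-1 mod R, then
+ // t = (mu + v * PRIME) / R, which is congruent to mu * R^-1 mod PRIME and (for reduced
+ // inputs) less than 2 * PRIME, so at most one conditional subtraction is needed.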
+ let mu_mod_r = const_subslice(&mu, 4, 8);
+ let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+ const ZEROS: &[u64; 4] = &[0; 4];
+ copy_from_slice!(v, 0, 4, ZEROS); // mod R
+ let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
+ let (t1, t1_extra_bit) = add_8(&t0, &mu);
+ let t1_on_r = const_subslice(&t1, 0, 4);
+ let mut res = [0; 4];
+ if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
+ let underflow;
+ (res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
+ debug_assert!(t1_extra_bit == underflow);
+ } else {
+ copy_from_slice!(res, 0, 4, t1_on_r);
+ }
+ Self(U256(res), PhantomData)
+ }
+
+ pub(super) const fn from_u256_panicking(v: U256) -> Self {
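+ // The derived `Ord` used by `from_u256` below isn't callable in a `const fn`, so check
+ // `v < PRIME` limb by limb here.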
+ assert!(v.0[0] <= M::PRIME.0[0]);
+ if v.0[0] == M::PRIME.0[0] {
+ assert!(v.0[1] <= M::PRIME.0[1]);
+ if v.0[1] == M::PRIME.0[1] {
+ assert!(v.0[2] <= M::PRIME.0[2]);
+ if v.0[2] == M::PRIME.0[2] {
+ assert!(v.0[3] < M::PRIME.0[3]);
+ }
+ }
+ }
+ assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0 || M::PRIME.0[3] != 0);
+ Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+ }
+
+ pub(super) fn from_u256(mut v: U256) -> Self {
+ debug_assert!(M::PRIME.0 != [0; 4]);
+ debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
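+ // Since PRIME has its top bit set, any 256-bit `v` is less than 2 * PRIME, so this loop
+ // runs at most once.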
+ while v >= M::PRIME {
+ let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
+ debug_assert!(!spurious_underflow);
+ v = U256(new_v);
+ }
+ Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+ }
+
+ pub(super) fn from_modinv_of(v: U256) -> Result<Self, ()> {
+ Ok(Self::from_u256(U256(mod_inv_4(&v.0, &M::PRIME.0)?)))
+ }
+
+ /// Multiplies `self` by `b` mod `m`.
+ ///
+ /// Both operands necessarily share the same compile-time modulus `M`, so no runtime modulus
+ /// check is needed.
+ pub(super) fn mul(&self, b: &Self) -> Self {
+ Self::mont_reduction(mul_4(&self.0.0, &b.0.0))
+ }
+
+ /// Doubles `self` mod `m`.
+ pub(super) fn double(&self) -> Self {
+ let mut res = self.0.0;
+ let overflow = double!(res);
+ if overflow || !slice_greater_than(&M::PRIME.0, &res) {
+ let underflow;
+ (res, underflow) = sub_4(&res, &M::PRIME.0);
+ debug_assert_eq!(overflow, underflow);
+ }
+ Self(U256(res), PhantomData)
+ }
+
+ /// Multiplies `self` by 3 mod `m`.
+ pub(super) fn times_three(&self) -> Self {
+ // TODO: Optimize this a lot
+ self.mul(&U256Mod::from_u256(U256::three()))
+ }
+
+ /// Multiplies `self` by 4 mod `m`.
+ pub(super) fn times_four(&self) -> Self {
+ // TODO: Optimize this somewhat?
+ self.double().double()
+ }
+
+ /// Multiplies `self` by 8 mod `m`.
+ pub(super) fn times_eight(&self) -> Self {
+ // TODO: Optimize this somewhat?
+ self.double().double().double()
+ }
+
+ /// Squares `self` mod `m`.
+ pub(super) fn square(&self) -> Self {
+ Self::mont_reduction(sqr_4(&self.0.0))
+ }
+
+ /// Subtracts `b` from `self` mod `m`.
+ pub(super) fn sub(&self, b: &Self) -> Self {
+ let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
+ if underflow {
+ let overflow;
+ (val, overflow) = add_4(&val, &M::PRIME.0);
+ debug_assert_eq!(overflow, underflow);
+ }
+ Self(U256(val), PhantomData)
+ }
+
+ /// Adds `b` to `self` mod `m`.
+ pub(super) fn add(&self, b: &Self) -> Self {
+ let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
+ if overflow || !slice_greater_than(&M::PRIME.0, &val) {
+ let underflow;
+ (val, underflow) = sub_4(&val, &M::PRIME.0);
+ debug_assert_eq!(overflow, underflow);
+ }
+ Self(U256(val), PhantomData)
+ }
+
+ /// Converts `self` out of Montgomery form, returning the [`U256`] it represents.
+ pub(super) fn into_u256(self) -> U256 {
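+ // `mont_reduction` computes mu * R^-1 mod PRIME, so reducing the Montgomery form (which
+ // carries a factor of R) directly recovers the plain value.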
+ let mut expanded_self = [0; 8];
+ expanded_self[4..].copy_from_slice(&self.0.0);
+ Self::mont_reduction(expanded_self).0
+ }
+}
+
+impl U384 {
+ /// Constructs a new [`U384`] from a variable number of big-endian bytes.
+ pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
+ if bytes.len() > 384/8 { return Err(()); }
+ let u64s = (bytes.len() + 7) / 8;
+ let mut res = [0; WORD_COUNT_384];
+ for i in 0..u64s {
+ let mut b = [0; 8];
+ let pos = (u64s - i) * 8;
+ let start = bytes.len().saturating_sub(pos);
+ let end = bytes.len() + 8 - pos;
+ b[8 + start - end..].copy_from_slice(&bytes[start..end]);
+ res[i + WORD_COUNT_384 - u64s] = u64::from_be_bytes(b);
+ }
+ Ok(U384(res))
+ }
+
+ /// Constructs a new [`U384`] from a fixed number of big-endian bytes.
+ pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
+ let res = [
+ u64_from_bytes_a_panicking(bytes),
+ u64_from_bytes_b_panicking(bytes),
+ u64_from_bytes_c_panicking(bytes),
+ u64_from_bytes_d_panicking(bytes),
+ u64_from_bytes_e_panicking(bytes),
+ u64_from_bytes_f_panicking(bytes),
+ ];
+ U384(res)
+ }
+
+ pub(super) const fn zero() -> U384 { U384([0, 0, 0, 0, 0, 0]) }
+ pub(super) const fn one() -> U384 { U384([0, 0, 0, 0, 0, 1]) }
+ pub(super) const fn three() -> U384 { U384([0, 0, 0, 0, 0, 3]) }
+}
+
+impl<M: PrimeModulus<U384>> U384Mod<M> {
+ const fn mont_reduction(mu: [u64; 12]) -> Self {
+ #[cfg(debug_assertions)] {
+ // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
+ // should be able to do it at compile time alone.
+ let minus_one_mod_r = mul_6(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+ assert!(slice_equal(const_subslice(&minus_one_mod_r, 6, 12), &[0xffff_ffff_ffff_ffff; 6]));
+ }
+
+ #[cfg(debug_assertions)] {
+ // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
+ // should be able to do it at compile time alone.
+ let r_minus_one = [0xffff_ffff_ffff_ffff; 6];
+ let (mut r_mod_prime, _) = sub_6(&r_minus_one, &M::PRIME.0);
+ add_one!(r_mod_prime);
+ let r_squared = sqr_6(&r_mod_prime);
+ let mut prime_extended = [0; 12];
+ let prime = M::PRIME.0;
+ copy_from_slice!(prime_extended, 6, 12, prime);
+ let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_12(&r_squared, &prime_extended) { v } else { panic!() };
+ assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
+ assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
+ }
+
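+ // The same Montgomery (REDC) reduction as in `U256Mod::mont_reduction` above, with six limbs.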
+ let mu_mod_r = const_subslice(&mu, 6, 12);
+ let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+ const ZEROS: &[u64; 6] = &[0; 6];
+ copy_from_slice!(v, 0, 6, ZEROS); // mod R
+ let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
+ let (t1, t1_extra_bit) = add_12(&t0, &mu);
+ let t1_on_r = const_subslice(&t1, 0, 6);
+ let mut res = [0; 6];
+ if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
+ let underflow;
+ (res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);
+ debug_assert!(t1_extra_bit == underflow);
+ } else {
+ copy_from_slice!(res, 0, 6, t1_on_r);
+ }
+ Self(U384(res), PhantomData)
+ }
+
+ pub(super) const fn from_u384_panicking(v: U384) -> Self {
+ assert!(v.0[0] <= M::PRIME.0[0]);
+ if v.0[0] == M::PRIME.0[0] {
+ assert!(v.0[1] <= M::PRIME.0[1]);
+ if v.0[1] == M::PRIME.0[1] {
+ assert!(v.0[2] <= M::PRIME.0[2]);
+ if v.0[2] == M::PRIME.0[2] {
+ assert!(v.0[3] <= M::PRIME.0[3]);
+ if v.0[3] == M::PRIME.0[3] {
+ assert!(v.0[4] <= M::PRIME.0[4]);
+ if v.0[4] == M::PRIME.0[4] {
+ assert!(v.0[5] < M::PRIME.0[5]);
+ }
+ }
+ }
+ }
+ }
+ assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0
+ || M::PRIME.0[3] != 0|| M::PRIME.0[4] != 0|| M::PRIME.0[5] != 0);
+ Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+ }
+
+ pub(super) fn from_u384(mut v: U384) -> Self {
+ debug_assert!(M::PRIME.0 != [0; 6]);
+ debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
+ while v >= M::PRIME {
+ let (new_v, spurious_underflow) = sub_6(&v.0, &M::PRIME.0);
+ debug_assert!(!spurious_underflow);
+ v = U384(new_v);
+ }
+ Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+ }
+
+ pub(super) fn from_modinv_of(v: U384) -> Result<Self, ()> {
+ Ok(Self::from_u384(U384(mod_inv_6(&v.0, &M::PRIME.0)?)))
+ }
+
+ /// Multiplies `self` by `b` mod `m`.
+ ///
+ /// Both operands necessarily share the same compile-time modulus `M`, so no runtime modulus
+ /// check is needed.
+ pub(super) fn mul(&self, b: &Self) -> Self {
+ Self::mont_reduction(mul_6(&self.0.0, &b.0.0))
+ }
+
+ /// Doubles `self` mod `m`.
+ pub(super) fn double(&self) -> Self {
+ let mut res = self.0.0;
+ let overflow = double!(res);
+ if overflow || !slice_greater_than(&M::PRIME.0, &res) {
+ let underflow;
+ (res, underflow) = sub_6(&res, &M::PRIME.0);
+ debug_assert_eq!(overflow, underflow);
+ }
+ Self(U384(res), PhantomData)
+ }
+
+ /// Multiplies `self` by 3 mod `m`.
+ pub(super) fn times_three(&self) -> Self {
+ // TODO: Optimize this a lot
+ self.mul(&U384Mod::from_u384(U384::three()))
+ }
+
+ /// Multiplies `self` by 4 mod `m`.
+ pub(super) fn times_four(&self) -> Self {
+ // TODO: Optimize this somewhat?
+ self.double().double()
+ }
+
+ /// Multiplies `self` by 8 mod `m`.
+ pub(super) fn times_eight(&self) -> Self {
+ // TODO: Optimize this somewhat?
+ self.double().double().double()
+ }
+
+ /// Squares `self` mod `m`.
+ pub(super) fn square(&self) -> Self {
+ Self::mont_reduction(sqr_6(&self.0.0))
+ }
+
+ /// Subtracts `b` from `self` mod `m`.
+ pub(super) fn sub(&self, b: &Self) -> Self {
+ let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
+ if underflow {
+ let overflow;
+ (val, overflow) = add_6(&val, &M::PRIME.0);
+ debug_assert_eq!(overflow, underflow);
+ }
+ Self(U384(val), PhantomData)
+ }
+
+ /// Adds `b` to `self` mod `m`.
+ pub(super) fn add(&self, b: &Self) -> Self {
+ let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
+ if overflow || !slice_greater_than(&M::PRIME.0, &val) {
+ let underflow;
+ (val, underflow) = sub_6(&val, &M::PRIME.0);
+ debug_assert_eq!(overflow, underflow);
+ }
+ Self(U384(val), PhantomData)
+ }
+
+ /// Converts `self` out of Montgomery form, returning the [`U384`] it represents.
+ pub(super) fn into_u384(self) -> U384 {
+ let mut expanded_self = [0; 12];
+ expanded_self[6..].copy_from_slice(&self.0.0);
+ Self::mont_reduction(expanded_self).0
+ }
+}
+
+#[cfg(fuzzing)]
+mod fuzz_moduli {
+ use super::*;
+
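+ // Montgomery parameters for the NIST P-256 (secp256r1) and P-384 (secp384r1) field primes,
+ // used here only to exercise the modular-arithmetic wrappers under fuzzing.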
+ pub struct P256();
+ impl PrimeModulus<U256> for P256 {
+ const PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+ "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"));
+ const R_SQUARED_MOD_PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+ "00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003"));
+ const NEGATIVE_PRIME_INV_MOD_R: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+ "ffffffff00000002000000000000000000000001000000000000000000000001"));
+ }
+
+ pub struct P384();
+ impl PrimeModulus<U384> for P384 {
+ const PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+ "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"));
+ const R_SQUARED_MOD_PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+ "000000000000000000000000000000010000000200000000fffffffe000000000000000200000000fffffffe00000001"));
+ const NEGATIVE_PRIME_INV_MOD_R: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+ "00000014000000140000000c00000002fffffffcfffffffafffffffbfffffffe00000000000000010000000100000001"));
+ }
+}
+
#[cfg(fuzzing)]
extern crate ibig;
#[cfg(fuzzing)]
b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
}
- macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident) => {
+ macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
let res = $mul(&a_u64s, &b_u64s);
let mut res_bytes = Vec::with_capacity(input.len() / 2);
for i in res {
let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
+ assert_eq!(ibig::UBig::from_be_bytes(&quot_bytes), quoti);
assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
+
+ if ai != ibig::UBig::from(0u32) { // ibig provides a spurious modular inverse for 0
+ let ring = ibig::modular::ModuloRing::new(&bi);
+ let ar = ring.from(ai.clone());
+ let invi = ar.inverse().map(|i| i.residue());
+
+ if let Ok(modinv) = $mod_inv(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
+ let mut modinv_bytes = Vec::with_capacity(input.len() / 2);
+ for i in modinv {
+ modinv_bytes.extend_from_slice(&i.to_be_bytes());
+ }
+ assert_eq!(invi.unwrap(), ibig::UBig::from_be_bytes(&modinv_bytes));
+ } else {
+ assert!(invi.is_none());
+ }
+ }
+ } }
+
+ macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
+ // Test the U256/U384Mod wrapper, which operates in Montgomery representation
+ let mut p_extended = [0; $len * 2];
+ p_extended[$len..].copy_from_slice(&$PRIME);
+
+ let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
+ assert_eq!(&amodp_squared[..$len], &[0; $len]);
+ assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
+
+ let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
+ assert_eq!(&abmodp[..$len], &[0; $len]);
+ assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
+
+ let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
+ let mut aplusb_extended = [0; $len * 2];
+ aplusb_extended[$len..].copy_from_slice(&aplusb);
+ if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
+ let aplusbmodp = $div_rem_double(&aplusb_extended, &p_extended).unwrap().1;
+ assert_eq!(&aplusbmodp[..$len], &[0; $len]);
+ assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
+
+ let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
+ if aminusb_underflow {
+ let mut overflow;
+ (aminusb, overflow) = $add(&aminusb, &$PRIME);
+ if !overflow {
+ (aminusb, overflow) = $add(&aminusb, &$PRIME);
+ }
+ assert!(overflow);
+ }
+ let aminusbmodp = $div_rem(&aminusb, &$PRIME).unwrap().1;
+ assert_eq!(&$amodp.sub(&$bmodp).$into().0, &aminusbmodp);
} }
if a_u64s.len() == 2 {
- test!(mul_2, sqr_2, add_2, sub_2, div_rem_2);
+ test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
} else if a_u64s.len() == 4 {
- test!(mul_4, sqr_4, add_4, sub_4, div_rem_4);
+ test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
+ let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
+ let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
+ test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
+ } else if a_u64s.len() == 6 {
+ test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
+ let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
+ let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
+ test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
} else if a_u64s.len() == 8 {
- test!(mul_8, sqr_8, add_8, sub_8, div_rem_8);
+ test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
} else if input.len() == 512*2 + 4 {
let mut e_bytes = [0; 4];
e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);