X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=src%2Fcrypto%2Fbigint.rs;h=c9264cbe5a17151eeb13a404229fc7a81b808db4;hb=9c6775ca84026a2139d1845f061670f581088fa3;hp=62beef434bee5290c28f6991dbb23ca799ec4ef8;hpb=ed08985212345ac11904b2db5b7f78be0d8885e3;p=dnssec-prover

diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs
index 62beef4..c9264cb 100644
--- a/src/crypto/bigint.rs
+++ b/src/crypto/bigint.rs
@@ -1,13 +1,54 @@
 //! Simple variable-time big integer implementation

 use alloc::vec::Vec;
+use core::marker::PhantomData;

 const WORD_COUNT_4096: usize = 4096 / 64;
+const WORD_COUNT_256: usize = 256 / 64;
+const WORD_COUNT_384: usize = 384 / 64;

 // RFC 5702 indicates RSA keys can be up to 4096 bits
 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 pub(super) struct U4096([u64; WORD_COUNT_4096]);

+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U256([u64; WORD_COUNT_256]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U384([u64; WORD_COUNT_384]);
+
+pub(super) trait Int: Clone + Ord + Sized {
+	const ZERO: Self;
+	const BYTES: usize;
+	fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
+	fn limbs(&self) -> &[u64];
+}
+impl Int for U256 {
+	const ZERO: U256 = U256([0; 4]);
+	const BYTES: usize = 32;
+	fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+	fn limbs(&self) -> &[u64] { &self.0 }
+}
+impl Int for U384 {
+	const ZERO: U384 = U384([0; 6]);
+	const BYTES: usize = 48;
+	fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+	fn limbs(&self) -> &[u64] { &self.0 }
+}
+
+/// Defines a *PRIME* Modulus
+pub(super) trait PrimeModulus<I: Int> {
+	const PRIME: I;
+	const R_SQUARED_MOD_PRIME: I;
+	const NEGATIVE_PRIME_INV_MOD_R: I;
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense because we have an R factor
+pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense because we have an R factor
+pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
+
 macro_rules! debug_unwrap { ($v: expr) => { {
 	let v = $v;
 	debug_assert!(v.is_ok());
@@ -150,14 +191,17 @@ macro_rules! define_add { ($name: ident, $len: expr) => {
 } }

 define_add!(add_2, 2);
+define_add!(add_3, 3);
 define_add!(add_4, 4);
+define_add!(add_6, 6);
 define_add!(add_8, 8);
+define_add!(add_12, 12);
 define_add!(add_16, 16);
 define_add!(add_32, 32);
 define_add!(add_64, 64);
 define_add!(add_128, 128);

-macro_rules! define_sub { ($name: ident, $len: expr) => {
+macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
 	/// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
 	/// $len-64-bit integer and an overflow bit, with the same semantics as the std
 	/// [`u64::overflowing_sub`] method.
@@ -178,16 +222,29 @@ macro_rules! define_sub { ($name: ident, $len: expr) => {
 		}
 		(r, carry)
 	}
+
+	/// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
+	/// $len-64-bit integer representing the absolute value of the result, as well as a sign bit.
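+	///
+	/// For example (illustrative, in the two-limb case): subtracting `[0, 2]` from `[0, 1]`
+	/// conceptually computes `1 - 2 = -1`, so the return value is `([0, 1], true)` -- the
+	/// magnitude `1` together with a sign bit indicating the difference was negative.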
+	#[allow(unused)]
+	const fn $name_abs(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
+		let (mut res, neg) = $name(a, b);
+		if neg {
+			negate!(res);
+		}
+		(res, neg)
+	}
 } }

-define_sub!(sub_2, 2);
-define_sub!(sub_4, 4);
-define_sub!(sub_8, 8);
-define_sub!(sub_16, 16);
-define_sub!(sub_32, 32);
-define_sub!(sub_64, 64);
-#[cfg(debug_assertions)]
-define_sub!(sub_128, 128);
+define_sub!(sub_2, sub_abs_2, 2);
+define_sub!(sub_3, sub_abs_3, 3);
+define_sub!(sub_4, sub_abs_4, 4);
+define_sub!(sub_6, sub_abs_6, 6);
+define_sub!(sub_8, sub_abs_8, 8);
+define_sub!(sub_12, sub_abs_12, 12);
+define_sub!(sub_16, sub_abs_16, 16);
+define_sub!(sub_32, sub_abs_32, 32);
+define_sub!(sub_64, sub_abs_64, 64);
+define_sub!(sub_128, sub_abs_128, 128);

 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
 ///
@@ -232,6 +289,76 @@ const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
 	[i, j, k, l]
 }

+const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
+	debug_assert!(a.len() == 3);
+	debug_assert!(b.len() == 3);
+
+	let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
+	let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
+
+	let m4 = a2 * b2;
+	let m3a = a2 * b1;
+	let m3b = a1 * b2;
+	let m2a = a2 * b0;
+	let m2b = a1 * b1;
+	let m2c = a0 * b2;
+	let m1a = a1 * b0;
+	let m1b = a0 * b1;
+	let m0 = a0 * b0;
+
+	let r5 = ((m4 >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+
+	let r4a = ((m4 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r4b = ((m3a >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+	let r4c = ((m3b >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+
+	let r3a = ((m3a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r3b = ((m3b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r3c = ((m2a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+	let r3d = ((m2b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+	let r3e = ((m2c >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+	let r2a = ((m2a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r2b = ((m2b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r2c = ((m2c >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r2d = ((m1a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+	let r2e = ((m1b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+	let r1a = ((m1a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r1b = ((m1b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+	let r1c = ((m0 >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+	let r0a = ((m0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+
+	let (r4, r3_ca) = r4a.overflowing_add(r4b);
+	let (r4, r3_cb) = r4.overflowing_add(r4c);
+	let r3_c = r3_ca as u64 + r3_cb as u64;
+
+	let (r3, r2_ca) = r3a.overflowing_add(r3b);
+	let (r3, r2_cb) = r3.overflowing_add(r3c);
+	let (r3, r2_cc) = r3.overflowing_add(r3d);
+	let (r3, r2_cd) = r3.overflowing_add(r3e);
+	let (r3, r2_ce) = r3.overflowing_add(r3_c);
+	let r2_c = r2_ca as u64 + r2_cb as u64 + r2_cc as u64 + r2_cd as u64 + r2_ce as u64;

+	let (r2, r1_ca) = r2a.overflowing_add(r2b);
+	let (r2, r1_cb) = r2.overflowing_add(r2c);
+	let (r2, r1_cc) = r2.overflowing_add(r2d);
+	let (r2, r1_cd) = r2.overflowing_add(r2e);
+	let (r2, r1_ce) = r2.overflowing_add(r2_c);
+	let r1_c = r1_ca as u64 + r1_cb as u64 + r1_cc as u64 + r1_cd as u64 + r1_ce as u64;
+
+	let (r1, r0_ca) = r1a.overflowing_add(r1b);
+	let (r1, r0_cb) = r1.overflowing_add(r1c);
+	let (r1, r0_cc) = r1.overflowing_add(r1_c);
+	let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64;
+
+	let (r0, must_not_overflow) = r0a.overflowing_add(r0_c);
+	debug_assert!(!must_not_overflow);
+
+	[r0, r1, r2, r3, r4, r5]
+}
+
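+// Sanity example for `mul_3` (illustrative): with a = [0, 0, 2] and b = [0, 0, 3] -- i.e.
+// the values 2 and 3 in big-endian 64-bit limbs -- every partial product except
+// m4 = a2 * b2 = 6 is zero, so the result is [0, 0, 0, 0, 0, 6], the 384-bit value 6.
+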
 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
 	/// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
 	const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
@@ -303,6 +430,7 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
 } }

 define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
+define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
 define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
 define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
 define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
@@ -393,13 +521,16 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
 	}
 }
 }

+// TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
+const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
+
 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
+define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
 define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
 define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
 define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
 define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);

-#[cfg(fuzzing)]
 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
@@ -447,14 +578,85 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init

 #[cfg(fuzzing)]
 define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
-#[cfg(fuzzing)]
 define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
-#[cfg(fuzzing)]
+define_div_rem!(div_rem_6, 6, sub_6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
+#[cfg(debug_assertions)]
 define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
+#[cfg(debug_assertions)]
+define_div_rem!(div_rem_12, 12, sub_12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
 define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
 #[cfg(debug_assertions)]
 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap

+macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
+	/// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
+	/// if one exists.
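+	///
+	/// This is the extended Euclidean algorithm: the loop maintains (up to the tracked
+	/// sign bits) the invariant `old_s * a == old_r (mod m)`, so when `old_r` reaches
+	/// `gcd(a, m) == 1`, `old_s` is the modular inverse of `a`.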
+	const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
+		if slice_equal(a, &[0; $len]) || slice_equal(m, &[0; $len]) { return Err(()); }
+
+		let (mut s, mut old_s) = ([0; $len], [0; $len]);
+		old_s[$len - 1] = 1;
+		let mut r = *m;
+		let mut old_r = *a;
+
+		let (mut old_s_neg, mut s_neg) = (false, false);
+
+		while !slice_equal(&r, &[0; $len]) {
+			let (quot, new_r) = debug_unwrap!($div(&old_r, &r));
+
+			let new_sa = $mul(&quot, &s);
+			debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
+			let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
+				(true, true) => {
+					let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+					debug_assert!(!overflow);
+					(new_s, true)
+				}
+				(false, true) => {
+					let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+					debug_assert!(!overflow);
+					(new_s, false)
+				},
+				(true, false) => {
+					let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+					debug_assert!(!overflow);
+					(new_s, true)
+				},
+				(false, false) => $sub_abs(&old_s, const_subslice(&new_sa, $len, new_sa.len())),
+			};
+
+			old_r = r;
+			r = new_r;
+
+			old_s = s;
+			old_s_neg = s_neg;
+			s = new_s;
+			s_neg = new_s_neg;
+		}
+
+		// At this point old_r contains our GCD and old_s our first Bézout's identity coefficient.
+		if !slice_equal(const_subslice(&old_r, 0, $len - 1), &[0; $len - 1]) || old_r[$len - 1] != 1 {
+			Err(())
+		} else {
+			debug_assert!(slice_greater_than(m, &old_s));
+			if old_s_neg {
+				let (modinv, underflow) = $sub_abs(m, &old_s);
+				debug_assert!(!underflow);
+				debug_assert!(slice_greater_than(m, &modinv));
+				Ok(modinv)
+			} else {
+				Ok(old_s)
+			}
+		}
+	}
+} }
+
+#[cfg(fuzzing)]
+define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
+define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
+define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
+#[cfg(fuzzing)]
+define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
+
 impl U4096 {
 	/// Constructs a new [`U4096`] from a variable number of big-endian bytes.
 	pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
@@ -710,6 +912,505 @@ impl U4096 {
 	}
 }

+const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 {
+	match b {
+		[a, b, c, d, e, f, g, h, ..] => {
+			((*a as u64) << 8*7) |
+			((*b as u64) << 8*6) |
+			((*c as u64) << 8*5) |
+			((*d as u64) << 8*4) |
+			((*e as u64) << 8*3) |
+			((*f as u64) << 8*2) |
+			((*g as u64) << 8*1) |
+			((*h as u64) << 8*0)
+		},
+		_ => panic!(),
+	}
+}
+
+const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 {
+	match b {
+		[_, _, _, _, _, _, _, _,
+		 a, b, c, d, e, f, g, h, ..] => {
+			((*a as u64) << 8*7) |
+			((*b as u64) << 8*6) |
+			((*c as u64) << 8*5) |
+			((*d as u64) << 8*4) |
+			((*e as u64) << 8*3) |
+			((*f as u64) << 8*2) |
+			((*g as u64) << 8*1) |
+			((*h as u64) << 8*0)
+		},
+		_ => panic!(),
+	}
+}
+
+const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 {
+	match b {
+		[_, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 a, b, c, d, e, f, g, h, ..] => {
+			((*a as u64) << 8*7) |
+			((*b as u64) << 8*6) |
+			((*c as u64) << 8*5) |
+			((*d as u64) << 8*4) |
+			((*e as u64) << 8*3) |
+			((*f as u64) << 8*2) |
+			((*g as u64) << 8*1) |
+			((*h as u64) << 8*0)
+		},
+		_ => panic!(),
+	}
+}
+
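+// Each `u64_from_bytes_*_panicking` helper reads one big-endian `u64` at a fixed byte
+// offset (0, 8, 16, ... bytes in). Matching on a fixed-size prefix keeps them usable in
+// `const` context, where slice-to-array conversions like `try_into` are not available.
+// Illustrative: `u64_from_bytes_a_panicking(&[0, 0, 0, 0, 0, 0, 0, 1, 0xff])` is `1`,
+// while any slice shorter than the matched prefix hits the `_ => panic!()` arm.
+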
+const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 {
+	match b {
+		[_, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 a, b, c, d, e, f, g, h, ..] => {
+			((*a as u64) << 8*7) |
+			((*b as u64) << 8*6) |
+			((*c as u64) << 8*5) |
+			((*d as u64) << 8*4) |
+			((*e as u64) << 8*3) |
+			((*f as u64) << 8*2) |
+			((*g as u64) << 8*1) |
+			((*h as u64) << 8*0)
+		},
+		_ => panic!(),
+	}
+}
+
+const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 {
+	match b {
+		[_, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 a, b, c, d, e, f, g, h, ..] => {
+			((*a as u64) << 8*7) |
+			((*b as u64) << 8*6) |
+			((*c as u64) << 8*5) |
+			((*d as u64) << 8*4) |
+			((*e as u64) << 8*3) |
+			((*f as u64) << 8*2) |
+			((*g as u64) << 8*1) |
+			((*h as u64) << 8*0)
+		},
+		_ => panic!(),
+	}
+}
+
+const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 {
+	match b {
+		[_, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 _, _, _, _, _, _, _, _,
+		 a, b, c, d, e, f, g, h, ..] => {
+			((*a as u64) << 8*7) |
+			((*b as u64) << 8*6) |
+			((*c as u64) << 8*5) |
+			((*d as u64) << 8*4) |
+			((*e as u64) << 8*3) |
+			((*f as u64) << 8*2) |
+			((*g as u64) << 8*1) |
+			((*h as u64) << 8*0)
+		},
+		_ => panic!(),
+	}
+}
+
+impl U256 {
+	/// Constructs a new [`U256`] from a variable number of big-endian bytes.
+	pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U256, ()> {
+		if bytes.len() > 256/8 { return Err(()); }
+		let u64s = (bytes.len() + 7) / 8;
+		let mut res = [0; WORD_COUNT_256];
+		for i in 0..u64s {
+			let mut b = [0; 8];
+			let pos = (u64s - i) * 8;
+			let start = bytes.len().saturating_sub(pos);
+			let end = bytes.len() + 8 - pos;
+			b[8 + start - end..].copy_from_slice(&bytes[start..end]);
+			res[i + WORD_COUNT_256 - u64s] = u64::from_be_bytes(b);
+		}
+		Ok(U256(res))
+	}
+
+	/// Constructs a new [`U256`] from a fixed number of big-endian bytes.
+	pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
+		let res = [
+			u64_from_bytes_a_panicking(bytes),
+			u64_from_bytes_b_panicking(bytes),
+			u64_from_bytes_c_panicking(bytes),
+			u64_from_bytes_d_panicking(bytes),
+		];
+		U256(res)
+	}
+
+	pub(super) const fn zero() -> U256 { U256([0, 0, 0, 0]) }
+	pub(super) const fn one() -> U256 { U256([0, 0, 0, 1]) }
+	pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
+}
+
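+// Montgomery-representation note (illustrative): a `U256Mod<M>` stores `x * R mod PRIME`
+// with `R = 2^256`. `mont_reduction(mu)` computes `mu * R^-1 mod PRIME`, so a product of
+// two stored values keeps exactly one `R` factor (`xR * yR * R^-1 = x*y*R`), while
+// `into_u256` strips the factor by reducing the value zero-extended to eight limbs.
+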
+impl<M: PrimeModulus<U256>> U256Mod<M> {
+	const fn mont_reduction(mu: [u64; 8]) -> Self {
+		#[cfg(debug_assertions)] {
+			// Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
+			// should be able to do it at compile time alone.
+			let minus_one_mod_r = mul_4(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+			assert!(slice_equal(const_subslice(&minus_one_mod_r, 4, 8), &[0xffff_ffff_ffff_ffff; 4]));
+		}
+
+		#[cfg(debug_assertions)] {
+			// Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
+			// should be able to do it at compile time alone.
+			let r_minus_one = [0xffff_ffff_ffff_ffff; 4];
+			let (mut r_mod_prime, _) = sub_4(&r_minus_one, &M::PRIME.0);
+			add_one!(r_mod_prime);
+			let r_squared = sqr_4(&r_mod_prime);
+			let mut prime_extended = [0; 8];
+			let prime = M::PRIME.0;
+			copy_from_slice!(prime_extended, 4, 8, prime);
+			let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_8(&r_squared, &prime_extended) { v } else { panic!() };
+			assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
+			assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
+		}
+
+		let mu_mod_r = const_subslice(&mu, 4, 8);
+		let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+		const ZEROS: &[u64; 4] = &[0; 4];
+		copy_from_slice!(v, 0, 4, ZEROS); // mod R
+		let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
+		let (t1, t1_extra_bit) = add_8(&t0, &mu);
+		let t1_on_r = const_subslice(&t1, 0, 4);
+		let mut res = [0; 4];
+		if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
+			let underflow;
+			(res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
+			debug_assert!(t1_extra_bit == underflow);
+		} else {
+			copy_from_slice!(res, 0, 4, t1_on_r);
+		}
+		Self(U256(res), PhantomData)
+	}
+
+	pub(super) const fn from_u256_panicking(v: U256) -> Self {
+		assert!(v.0[0] <= M::PRIME.0[0]);
+		if v.0[0] == M::PRIME.0[0] {
+			assert!(v.0[1] <= M::PRIME.0[1]);
+			if v.0[1] == M::PRIME.0[1] {
+				assert!(v.0[2] <= M::PRIME.0[2]);
+				if v.0[2] == M::PRIME.0[2] {
+					assert!(v.0[3] < M::PRIME.0[3]);
+				}
+			}
+		}
+		assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0 || M::PRIME.0[3] != 0);
+		Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+	}
+
+	pub(super) fn from_u256(mut v: U256) -> Self {
+		debug_assert!(M::PRIME.0 != [0; 4]);
+		debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
+		while v >= M::PRIME {
+			let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
+			debug_assert!(!spurious_underflow);
+			v = U256(new_v);
+		}
+		Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+	}
+
+	pub(super) fn from_modinv_of(v: U256) -> Result<Self, ()> {
+		Ok(Self::from_u256(U256(mod_inv_4(&v.0, &M::PRIME.0)?)))
+	}
+
+	/// Multiplies `self` * `b` mod `m`.
+	///
+	/// Panics if `self`'s modulus is not equal to `b`'s
+	pub(super) fn mul(&self, b: &Self) -> Self {
+		Self::mont_reduction(mul_4(&self.0.0, &b.0.0))
+	}
+
+	/// Doubles `self` mod `m`.
+	pub(super) fn double(&self) -> Self {
+		let mut res = self.0.0;
+		let overflow = double!(res);
+		if overflow || !slice_greater_than(&M::PRIME.0, &res) {
+			let underflow;
+			(res, underflow) = sub_4(&res, &M::PRIME.0);
+			debug_assert_eq!(overflow, underflow);
+		}
+		Self(U256(res), PhantomData)
+	}
+
+	/// Multiplies `self` by 3 mod `m`.
+	pub(super) fn times_three(&self) -> Self {
+		// TODO: Optimize this a lot
+		self.mul(&U256Mod::from_u256(U256::three()))
+	}
+
+	/// Multiplies `self` by 4 mod `m`.
+	pub(super) fn times_four(&self) -> Self {
+		// TODO: Optimize this somewhat?
+		self.double().double()
+	}
+
+	/// Multiplies `self` by 8 mod `m`.
+	pub(super) fn times_eight(&self) -> Self {
+		// TODO: Optimize this somewhat?
+		self.double().double().double()
+	}
+
+	/// Squares `self` mod `m`.
+	pub(super) fn square(&self) -> Self {
+		Self::mont_reduction(sqr_4(&self.0.0))
+	}
+
+	/// Subtracts `b` from `self` % `m`.
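+	///
+	/// (Both operands are kept fully reduced, so the difference lies in `(-PRIME, PRIME)`
+	/// and a single conditional add of `PRIME` on underflow re-reduces it.)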
+	pub(super) fn sub(&self, b: &Self) -> Self {
+		let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
+		if underflow {
+			let overflow;
+			(val, overflow) = add_4(&val, &M::PRIME.0);
+			debug_assert_eq!(overflow, underflow);
+		}
+		Self(U256(val), PhantomData)
+	}
+
+	/// Adds `b` to `self` % `m`.
+	pub(super) fn add(&self, b: &Self) -> Self {
+		let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
+		if overflow || !slice_greater_than(&M::PRIME.0, &val) {
+			let underflow;
+			(val, underflow) = sub_4(&val, &M::PRIME.0);
+			debug_assert_eq!(overflow, underflow);
+		}
+		Self(U256(val), PhantomData)
+	}
+
+	/// Returns the underlying [`U256`].
+	pub(super) fn into_u256(self) -> U256 {
+		let mut expanded_self = [0; 8];
+		expanded_self[4..].copy_from_slice(&self.0.0);
+		Self::mont_reduction(expanded_self).0
+	}
+}
+
+impl U384 {
+	/// Constructs a new [`U384`] from a variable number of big-endian bytes.
+	pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
+		if bytes.len() > 384/8 { return Err(()); }
+		let u64s = (bytes.len() + 7) / 8;
+		let mut res = [0; WORD_COUNT_384];
+		for i in 0..u64s {
+			let mut b = [0; 8];
+			let pos = (u64s - i) * 8;
+			let start = bytes.len().saturating_sub(pos);
+			let end = bytes.len() + 8 - pos;
+			b[8 + start - end..].copy_from_slice(&bytes[start..end]);
+			res[i + WORD_COUNT_384 - u64s] = u64::from_be_bytes(b);
+		}
+		Ok(U384(res))
+	}
+
+	/// Constructs a new [`U384`] from a fixed number of big-endian bytes.
+	pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
+		let res = [
+			u64_from_bytes_a_panicking(bytes),
+			u64_from_bytes_b_panicking(bytes),
+			u64_from_bytes_c_panicking(bytes),
+			u64_from_bytes_d_panicking(bytes),
+			u64_from_bytes_e_panicking(bytes),
+			u64_from_bytes_f_panicking(bytes),
+		];
+		U384(res)
+	}
+
+	pub(super) const fn zero() -> U384 { U384([0, 0, 0, 0, 0, 0]) }
+	pub(super) const fn one() -> U384 { U384([0, 0, 0, 0, 0, 1]) }
+	pub(super) const fn three() -> U384 { U384([0, 0, 0, 0, 0, 3]) }
+}
+
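+// The 384-bit Montgomery wrapper below mirrors `U256Mod` with `R = 2^384`; only the limb
+// count changes, with six-limb primitives (`mul_6`, `sqr_6`, `add_12`, `div_rem_12`)
+// replacing their four-limb counterparts.
+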
+impl<M: PrimeModulus<U384>> U384Mod<M> {
+	const fn mont_reduction(mu: [u64; 12]) -> Self {
+		#[cfg(debug_assertions)] {
+			// Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
+			// should be able to do it at compile time alone.
+			let minus_one_mod_r = mul_6(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+			assert!(slice_equal(const_subslice(&minus_one_mod_r, 6, 12), &[0xffff_ffff_ffff_ffff; 6]));
+		}
+
+		#[cfg(debug_assertions)] {
+			// Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
+			// should be able to do it at compile time alone.
+			let r_minus_one = [0xffff_ffff_ffff_ffff; 6];
+			let (mut r_mod_prime, _) = sub_6(&r_minus_one, &M::PRIME.0);
+			add_one!(r_mod_prime);
+			let r_squared = sqr_6(&r_mod_prime);
+			let mut prime_extended = [0; 12];
+			let prime = M::PRIME.0;
+			copy_from_slice!(prime_extended, 6, 12, prime);
+			let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_12(&r_squared, &prime_extended) { v } else { panic!() };
+			assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
+			assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
+		}
+
+		let mu_mod_r = const_subslice(&mu, 6, 12);
+		let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+		const ZEROS: &[u64; 6] = &[0; 6];
+		copy_from_slice!(v, 0, 6, ZEROS); // mod R
+		let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
+		let (t1, t1_extra_bit) = add_12(&t0, &mu);
+		let t1_on_r = const_subslice(&t1, 0, 6);
+		let mut res = [0; 6];
+		if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
+			let underflow;
+			(res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);
+			debug_assert!(t1_extra_bit == underflow);
+		} else {
+			copy_from_slice!(res, 0, 6, t1_on_r);
+		}
+		Self(U384(res), PhantomData)
+	}
+
+	pub(super) const fn from_u384_panicking(v: U384) -> Self {
+		assert!(v.0[0] <= M::PRIME.0[0]);
+		if v.0[0] == M::PRIME.0[0] {
+			assert!(v.0[1] <= M::PRIME.0[1]);
+			if v.0[1] == M::PRIME.0[1] {
+				assert!(v.0[2] <= M::PRIME.0[2]);
+				if v.0[2] == M::PRIME.0[2] {
+					assert!(v.0[3] <= M::PRIME.0[3]);
+					if v.0[3] == M::PRIME.0[3] {
+						assert!(v.0[4] <= M::PRIME.0[4]);
+						if v.0[4] == M::PRIME.0[4] {
+							assert!(v.0[5] < M::PRIME.0[5]);
+						}
+					}
+				}
+			}
+		}
+		assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0
+			|| M::PRIME.0[3] != 0 || M::PRIME.0[4] != 0 || M::PRIME.0[5] != 0);
+		Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+	}
+
+	pub(super) fn from_u384(mut v: U384) -> Self {
+		debug_assert!(M::PRIME.0 != [0; 6]);
+		debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
+		while v >= M::PRIME {
+			let (new_v, spurious_underflow) = sub_6(&v.0, &M::PRIME.0);
+			debug_assert!(!spurious_underflow);
+			v = U384(new_v);
+		}
+		Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+	}
+
+	pub(super) fn from_modinv_of(v: U384) -> Result<Self, ()> {
+		Ok(Self::from_u384(U384(mod_inv_6(&v.0, &M::PRIME.0)?)))
+	}
+
+	/// Multiplies `self` * `b` mod `m`.
+	///
+	/// Panics if `self`'s modulus is not equal to `b`'s
+	pub(super) fn mul(&self, b: &Self) -> Self {
+		Self::mont_reduction(mul_6(&self.0.0, &b.0.0))
+	}
+
+	/// Doubles `self` mod `m`.
+	pub(super) fn double(&self) -> Self {
+		let mut res = self.0.0;
+		let overflow = double!(res);
+		if overflow || !slice_greater_than(&M::PRIME.0, &res) {
+			let underflow;
+			(res, underflow) = sub_6(&res, &M::PRIME.0);
+			debug_assert_eq!(overflow, underflow);
+		}
+		Self(U384(res), PhantomData)
+	}
+
+	/// Multiplies `self` by 3 mod `m`.
+	pub(super) fn times_three(&self) -> Self {
+		// TODO: Optimize this a lot
+		self.mul(&U384Mod::from_u384(U384::three()))
+	}
+
+	/// Multiplies `self` by 4 mod `m`.
+	pub(super) fn times_four(&self) -> Self {
+		// TODO: Optimize this somewhat?
+		self.double().double()
+	}
+
+	/// Multiplies `self` by 8 mod `m`.
+	pub(super) fn times_eight(&self) -> Self {
+		// TODO: Optimize this somewhat?
+		self.double().double().double()
+	}
+
+	/// Squares `self` mod `m`.
+	pub(super) fn square(&self) -> Self {
+		Self::mont_reduction(sqr_6(&self.0.0))
+	}
+
+	/// Subtracts `b` from `self` % `m`.
+	pub(super) fn sub(&self, b: &Self) -> Self {
+		let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
+		if underflow {
+			let overflow;
+			(val, overflow) = add_6(&val, &M::PRIME.0);
+			debug_assert_eq!(overflow, underflow);
+		}
+		Self(U384(val), PhantomData)
+	}
+
+	/// Adds `b` to `self` % `m`.
+	pub(super) fn add(&self, b: &Self) -> Self {
+		let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
+		if overflow || !slice_greater_than(&M::PRIME.0, &val) {
+			let underflow;
+			(val, underflow) = sub_6(&val, &M::PRIME.0);
+			debug_assert_eq!(overflow, underflow);
+		}
+		Self(U384(val), PhantomData)
+	}
+
+	/// Returns the underlying [`U384`].
+	pub(super) fn into_u384(self) -> U384 {
+		let mut expanded_self = [0; 12];
+		expanded_self[6..].copy_from_slice(&self.0.0);
+		Self::mont_reduction(expanded_self).0
+	}
+}
+
+#[cfg(fuzzing)]
+mod fuzz_moduli {
+	use super::*;
+
+	pub struct P256();
+	impl PrimeModulus<U256> for P256 {
+		const PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+			"ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"));
+		const R_SQUARED_MOD_PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+			"00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003"));
+		const NEGATIVE_PRIME_INV_MOD_R: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+			"ffffffff00000002000000000000000000000001000000000000000000000001"));
+	}
+
+	pub struct P384();
+	impl PrimeModulus<U384> for P384 {
+		const PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+			"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"));
+		const R_SQUARED_MOD_PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+			"000000000000000000000000000000010000000200000000fffffffe000000000000000200000000fffffffe00000001"));
+		const NEGATIVE_PRIME_INV_MOD_R: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+			"00000014000000140000000c00000002fffffffcfffffffafffffffbfffffffe00000000000000010000000100000001"));
+	}
+}
+
 #[cfg(fuzzing)]
 extern crate ibig;
 #[cfg(fuzzing)]
@@ -732,7 +1433,7 @@ pub fn fuzz_math(input: &[u8]) {
 		b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
 	}

-	macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident) => {
+	macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
 		let res = $mul(&a_u64s, &b_u64s);
 		let mut res_bytes = Vec::with_capacity(input.len() / 2);
 		for i in res {
@@ -784,14 +1485,72 @@ pub fn fuzz_math(input: &[u8]) {
 		let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
 		assert_eq!(ibig::UBig::from_be_bytes(&quot_bytes), quoti);
 		assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
+
+		if ai != ibig::UBig::from(0u32) { // ibig provides a spurious modular inverse for 0
+			let ring = ibig::modular::ModuloRing::new(&bi);
+			let ar = ring.from(ai.clone());
+			let invi = ar.inverse().map(|i| i.residue());
+
+			if let Ok(modinv) = $mod_inv(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
+				let mut modinv_bytes = Vec::with_capacity(input.len() / 2);
+				for i in modinv {
+					modinv_bytes.extend_from_slice(&i.to_be_bytes());
+				}
+				assert_eq!(invi.unwrap(), ibig::UBig::from_be_bytes(&modinv_bytes));
+			} else {
+				assert!(invi.is_none());
+			}
+		}
 	} }

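+	// `test_mod` below cross-checks the Montgomery wrappers against the plain limb math:
+	// e.g. (illustrative) `amodp.square().into_u256()` must equal `a * a mod PRIME` as
+	// computed by `$div_rem_double` on the double-width product.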
+	macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
+		// Test the U256/U384Mod wrapper, which operates in Montgomery representation
+		let mut p_extended = [0; $len * 2];
+		p_extended[$len..].copy_from_slice(&$PRIME);
+
+		let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
+		assert_eq!(&amodp_squared[..$len], &[0; $len]);
+		assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
+
+		let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
+		assert_eq!(&abmodp[..$len], &[0; $len]);
+		assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
+
+		let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
+		let mut aplusb_extended = [0; $len * 2];
+		aplusb_extended[$len..].copy_from_slice(&aplusb);
+		if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
+		let aplusbmodp = $div_rem_double(&aplusb_extended, &p_extended).unwrap().1;
+		assert_eq!(&aplusbmodp[..$len], &[0; $len]);
+		assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
+
+		let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
+		if aminusb_underflow {
+			let mut overflow;
+			(aminusb, overflow) = $add(&aminusb, &$PRIME);
+			if !overflow {
+				(aminusb, overflow) = $add(&aminusb, &$PRIME);
+			}
+			assert!(overflow);
+		}
+		let aminusbmodp = $div_rem(&aminusb, &$PRIME).unwrap().1;
+		assert_eq!(&$amodp.sub(&$bmodp).$into().0, &aminusbmodp);
+	} }

 	if a_u64s.len() == 2 {
-		test!(mul_2, sqr_2, add_2, sub_2, div_rem_2);
+		test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
 	} else if a_u64s.len() == 4 {
-		test!(mul_4, sqr_4, add_4, sub_4, div_rem_4);
+		test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
+		let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
+		let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
+		test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
+	} else if a_u64s.len() == 6 {
+		test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
+		let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
+		let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
+		test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
 	} else if a_u64s.len() == 8 {
-		test!(mul_8, sqr_8, add_8, sub_8, div_rem_8);
+		test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
 	} else if input.len() == 512*2 + 4 {
 		let mut e_bytes = [0; 4];
 		e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);