From 1d725ca4a022415c85072bc763d50738df863d6d Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Fri, 3 May 2024 18:41:46 +0000 Subject: [PATCH] Clean up and better comment math somewhat further --- src/crypto/bigint.rs | 312 +++++++++++++++++++------------------ 1 file changed, 138 insertions(+), 174 deletions(-) diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs index d07bb56..7c187ca 100644 --- a/src/crypto/bigint.rs +++ b/src/crypto/bigint.rs @@ -3,51 +3,9 @@ use alloc::vec::Vec; use core::marker::PhantomData; -const WORD_COUNT_4096: usize = 4096 / 64; -const WORD_COUNT_256: usize = 256 / 64; -const WORD_COUNT_384: usize = 384 / 64; - -// RFC 5702 indicates RSA keys can be up to 4096 bits -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub(super) struct U4096([u64; WORD_COUNT_4096]); - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub(super) struct U256([u64; WORD_COUNT_256]); - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub(super) struct U384([u64; WORD_COUNT_384]); - -pub(super) trait Int: Clone + Ord + Sized { - const ZERO: Self; - const BYTES: usize; - fn from_be_bytes(b: &[u8]) -> Result<Self, ()>; - fn limbs(&self) -> &[u64]; -} -impl Int for U256 { - const ZERO: U256 = U256([0; 4]); - const BYTES: usize = 32; - fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) } - fn limbs(&self) -> &[u64] { &self.0 } -} -impl Int for U384 { - const ZERO: U384 = U384([0; 6]); - const BYTES: usize = 48; - fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) } - fn limbs(&self) -> &[u64] { &self.0 } -} - -/// Defines a *PRIME* Modulus -pub(super) trait PrimeModulus<I: Int> { - const PRIME: I; - const R_SQUARED_MOD_PRIME: I; - const NEGATIVE_PRIME_INV_MOD_R: I; -} - -#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor -pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>); - -#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor -pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>); +// ************************************** +// * Implementations of math primitives * +// ************************************** macro_rules! debug_unwrap { ($v: expr) => { { let v = $v; @@ -522,6 +480,11 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init $($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> { if slice_equal(b, &[0; $len]) { return Err(()); } + // Very naively divide `a` by `b` by calculating all the powers of two times `b` up to `a`, + // then subtracting the powers of two in decreasing order. What's left is the remainder. + // + // This requires storing all the multiples of `b` in `pow2s`, which may be a vec or an + // array. `$pre_push!()` sets up the next element with zeros and then we can overwrite it. let mut b_pow = *b; let mut pow2s = $heap_init; let mut pow2s_count = 0; @@ -538,10 +501,10 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init while pow2 >= 0 { let b_pow = pow2s[pow2 as usize]; let overflow = double!(quot); - debug_assert!(!overflow); + debug_assert!(!overflow, "quotient should be expressible in $len*64 bits"); if slice_greater_than(&rem, &b_pow) { - let (r, carry) = $sub(&rem, &b_pow); - debug_assert!(!carry); + let (r, underflow) = $sub(&rem, &b_pow); + debug_assert!(!underflow, "rem was just checked to be > b_pow, so sub cannot underflow"); rem = r; quot[$len - 1] |= 1; } @@ -549,7 +512,7 @@ macro_rules!
define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init } if slice_equal(&rem, b) { let overflow = add_u64!(quot, 1); - debug_assert!(!overflow); + debug_assert!(!overflow, "quotient should be expressible in $len*64 bits"); Ok((quot, [0; $len])) } else { Ok((quot, rem)) @@ -638,6 +601,56 @@ define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6); #[cfg(fuzzing)] define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8); +// ****************** +// * The public API * +// ****************** + +const WORD_COUNT_4096: usize = 4096 / 64; +const WORD_COUNT_256: usize = 256 / 64; +const WORD_COUNT_384: usize = 384 / 64; + +// RFC 5702 indicates RSA keys can be up to 4096 bits, so we always use 4096-bit integers +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(super) struct U4096([u64; WORD_COUNT_4096]); + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(super) struct U256([u64; WORD_COUNT_256]); + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub(super) struct U384([u64; WORD_COUNT_384]); + +pub(super) trait Int: Clone + Ord + Sized { + const ZERO: Self; + const BYTES: usize; + fn from_be_bytes(b: &[u8]) -> Result<Self, ()>; + fn limbs(&self) -> &[u64]; +} +impl Int for U256 { + const ZERO: U256 = U256([0; 4]); + const BYTES: usize = 32; + fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) } + fn limbs(&self) -> &[u64] { &self.0 } +} +impl Int for U384 { + const ZERO: U384 = U384([0; 6]); + const BYTES: usize = 48; + fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) } + fn limbs(&self) -> &[u64] { &self.0 } +} + +/// Defines a *PRIME* Modulus +pub(super) trait PrimeModulus<I: Int> { + const PRIME: I; + const R_SQUARED_MOD_PRIME: I; + const NEGATIVE_PRIME_INV_MOD_R: I; +} + +#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor +pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>); + +#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor +pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>); + impl U4096 { /// Constructs a new [`U4096`] from a variable number of big-endian bytes. pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> { @@ -694,6 +707,10 @@ impl U4096 { // Because m is not even, using 2^4096 as the Montgomery R value is always safe - it is // guaranteed to be co-prime with any non-even integer. + // We use a single 4096-bit integer type for all our RSA operations, though in most cases + // we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math + // here which debug_assert's the required bits are 0s and then uses faster math primitives. + type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2]; type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2]; type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool); @@ -785,6 +802,7 @@ impl U4096 { (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty) }; + // r is always the even value with one bit set above the word count we're using. let mut r = [0; WORD_COUNT_4096 * 2]; r[WORD_COUNT_4096 * 2 - word_count - 1] = 1; @@ -800,7 +818,7 @@ impl U4096 { } m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0); - // We want the negative modular inverse of m mod R, so subtract m_inv from R. + // `m_inv` is the negative modular inverse of m mod R, so subtract m_inv from R.
let mut m_inv = m_inv_pos; negate!(m_inv); m_inv[..WORD_COUNT_4096 - word_count].fill(0); @@ -808,25 +826,38 @@ impl U4096 { // R - 1 == -1 % R &[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]); - debug_assert_eq!(&m_inv[..WORD_COUNT_4096 - word_count], &[0; WORD_COUNT_4096][..WORD_COUNT_4096 - word_count]); let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] { debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2], &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]); + // Do a Montgomery reduction of `mu` + + // mu % R is just the bottom word_count words of mu let mut mu_mod_r = [0; WORD_COUNT_4096]; mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]); + + // v = ((mu % R) * negative_modulus_inverse) % R let mut v = mul(&mu_mod_r, &m_inv); v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R + + // t_on_r = (mu + v*modulus) / R let t0 = mul(&v[WORD_COUNT_4096..], &m.0); let (t1, t1_extra_bit) = add_double(&t0, &mu); + + // Note that dividing t1 by R is simply a matter of shifting right by word_count words + // We only need to maintain word_count words (plus `t1_extra_bit` which is implicitly + // an extra bit) because t_on_r is guaranteed to be, at max, 2*m - 1. let mut t1_on_r = [0; WORD_COUNT_4096]; debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..], "t1 should be divisible by r"); t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]); + + // The modulus has only word_count words, so if t1_extra_bit is set we are definitely + // larger than the modulus. if t1_extra_bit || t1_on_r >= m.0 { let underflow; (t1_on_r, underflow) = sub(&t1_on_r, &m.0); - debug_assert_eq!(t1_extra_bit, underflow); + debug_assert_eq!(t1_extra_bit, underflow, + "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set."); } t1_on_r }; @@ -868,6 +899,8 @@ impl U4096 { } debug_assert!(r2_mod_m < m.0); + // Finally, actually do the exponentiation... + // Calculate t * R and a * R as mont multiplications by R^2 mod m let mut tr = mont_reduction(mul(&r2_mod_m, &t)); let mut ar = mont_reduction(mul(&r2_mod_m, &self.0)); @@ -893,115 +926,11 @@ impl U4096 { } } -const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 { - match b { - [a, b, c, d, e, f, g, h, ..] => { - ((*a as u64) << 8*7) | - ((*b as u64) << 8*6) | - ((*c as u64) << 8*5) | - ((*d as u64) << 8*4) | - ((*e as u64) << 8*3) | - ((*f as u64) << 8*2) | - ((*g as u64) << 8*1) | - ((*h as u64) << 8*0) - }, - _ => panic!(), - } -} - -const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 { - match b { - [_, _, _, _, _, _, _, _, - a, b, c, d, e, f, g, h, ..] => { - ((*a as u64) << 8*7) | - ((*b as u64) << 8*6) | - ((*c as u64) << 8*5) | - ((*d as u64) << 8*4) | - ((*e as u64) << 8*3) | - ((*f as u64) << 8*2) | - ((*g as u64) << 8*1) | - ((*h as u64) << 8*0) - }, - _ => panic!(), - } -} - -const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 { - match b { - [_, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - a, b, c, d, e, f, g, h, ..]
=> { - ((*a as u64) << 8*7) | - ((*b as u64) << 8*6) | - ((*c as u64) << 8*5) | - ((*d as u64) << 8*4) | - ((*e as u64) << 8*3) | - ((*f as u64) << 8*2) | - ((*g as u64) << 8*1) | - ((*h as u64) << 8*0) - }, - _ => panic!(), - } -} - -const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 { - match b { - [_, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - a, b, c, d, e, f, g, h, ..] => { - ((*a as u64) << 8*7) | - ((*b as u64) << 8*6) | - ((*c as u64) << 8*5) | - ((*d as u64) << 8*4) | - ((*e as u64) << 8*3) | - ((*f as u64) << 8*2) | - ((*g as u64) << 8*1) | - ((*h as u64) << 8*0) - }, - _ => panic!(), - } -} - -const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 { - match b { - [_, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - a, b, c, d, e, f, g, h, ..] => { - ((*a as u64) << 8*7) | - ((*b as u64) << 8*6) | - ((*c as u64) << 8*5) | - ((*d as u64) << 8*4) | - ((*e as u64) << 8*3) | - ((*f as u64) << 8*2) | - ((*g as u64) << 8*1) | - ((*h as u64) << 8*0) - }, - _ => panic!(), - } -} - -const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 { - match b { - [_, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - _, _, _, _, _, _, _, _, - a, b, c, d, e, f, g, h, ..] => { - ((*a as u64) << 8*7) | - ((*b as u64) << 8*6) | - ((*c as u64) << 8*5) | - ((*d as u64) << 8*4) | - ((*e as u64) << 8*3) | - ((*f as u64) << 8*2) | - ((*g as u64) << 8*1) | - ((*h as u64) << 8*0) - }, - _ => panic!(), - } +// In a const context we can't subslice a slice, so instead we pick the eight bytes we want and +// pass them here to build u64s from arrays. +const fn eight_bytes_to_u64_be(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 { + let b = [a, b, c, d, e, f, g, h]; + u64::from_be_bytes(b) } impl U256 { @@ -1024,10 +953,14 @@ impl U256 { /// Constructs a new [`U256`] from a fixed number of big-endian bytes. pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 { let res = [ - u64_from_bytes_a_panicking(bytes), - u64_from_bytes_b_panicking(bytes), - u64_from_bytes_c_panicking(bytes), - u64_from_bytes_d_panicking(bytes), + eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3], + bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]), + eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3], + bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]), + eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3], + bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]), + eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3], + bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]), ]; U256(res) } @@ -1037,6 +970,7 @@ impl U256 { pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) } } +// Values modulo M::PRIME.0, stored in Montgomery form.
impl<M: PrimeModulus<U256>> U256Mod<M> { const fn mont_reduction(mu: [u64; 8]) -> Self { #[cfg(debug_assertions)] { @@ -1061,18 +995,30 @@ impl<M: PrimeModulus<U256>> U256Mod<M> { assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0)); } + // mu % R is just the bottom 4 words of mu let mu_mod_r = const_subslice(&mu, 4, 8); + // v = ((mu % R) * negative_modulus_inverse) % R let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0); const ZEROS: &[u64; 4] = &[0; 4]; copy_from_slice!(v, 0, 4, ZEROS); // mod R + + // t_on_r = (mu + v*modulus) / R let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0); let (t1, t1_extra_bit) = add_8(&t0, &mu); + + // Note that dividing t1 by R is simply a matter of shifting right by 4 words. + // We only need to maintain 4 words (plus `t1_extra_bit` which is implicitly an extra bit) + // because t_on_r is guaranteed to be, at max, 2*m - 1. let t1_on_r = const_subslice(&t1, 0, 4); + let mut res = [0; 4]; + // The modulus is only 4 words, so t1_extra_bit implies we're definitely larger than the + // modulus. if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) { let underflow; (res, underflow) = sub_4(&t1_on_r, &M::PRIME.0); - debug_assert!(t1_extra_bit == underflow); + debug_assert_eq!(t1_extra_bit, underflow, + "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set."); } else { copy_from_slice!(res, 0, 4, t1_on_r); } @@ -1099,7 +1045,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> { debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set"); while v >= M::PRIME { let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0); - debug_assert!(!spurious_underflow); + debug_assert!(!spurious_underflow, "v was > M::PRIME.0"); v = U256(new_v); } Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0)) @@ -1181,6 +1127,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> { } } +// Values modulo M::PRIME.0, stored in Montgomery form. impl U384 { /// Constructs a new [`U384`] from a variable number of big-endian bytes. pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> { @@ -1201,12 +1148,18 @@ impl U384 { /// Constructs a new [`U384`] from a fixed number of big-endian bytes.
pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 { let res = [ - u64_from_bytes_a_panicking(bytes), - u64_from_bytes_b_panicking(bytes), - u64_from_bytes_c_panicking(bytes), - u64_from_bytes_d_panicking(bytes), - u64_from_bytes_e_panicking(bytes), - u64_from_bytes_f_panicking(bytes), + eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3], + bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]), + eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3], + bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]), + eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3], + bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]), + eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3], + bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]), + eight_bytes_to_u64_be(bytes[4*8 + 0], bytes[4*8 + 1], bytes[4*8 + 2], bytes[4*8 + 3], + bytes[4*8 + 4], bytes[4*8 + 5], bytes[4*8 + 6], bytes[4*8 + 7]), + eight_bytes_to_u64_be(bytes[5*8 + 0], bytes[5*8 + 1], bytes[5*8 + 2], bytes[5*8 + 3], + bytes[5*8 + 4], bytes[5*8 + 5], bytes[5*8 + 6], bytes[5*8 + 7]), ]; U384(res) } @@ -1240,14 +1193,25 @@ impl<M: PrimeModulus<U384>> U384Mod<M> { assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0)); } + // mu % R is just the bottom 6 words of mu let mu_mod_r = const_subslice(&mu, 6, 12); + // v = ((mu % R) * negative_modulus_inverse) % R let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0); const ZEROS: &[u64; 6] = &[0; 6]; copy_from_slice!(v, 0, 6, ZEROS); // mod R + + // t_on_r = (mu + v*modulus) / R let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0); let (t1, t1_extra_bit) = add_12(&t0, &mu); + + // Note that dividing t1 by R is simply a matter of shifting right by 6 words. + // We only need to maintain 6 words (plus `t1_extra_bit` which is implicitly an extra bit) + // because t_on_r is guaranteed to be, at max, 2*m - 1. let t1_on_r = const_subslice(&t1, 0, 6); + let mut res = [0; 6]; + // The modulus is only 6 words, so t1_extra_bit implies we're definitely larger than the + // modulus. if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) { let underflow; (res, underflow) = sub_6(&t1_on_r, &M::PRIME.0); -- 2.39.5
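
The comment added to define_div_rem above describes plain shift-and-subtract long division. As a standalone illustration (not part of the diff), the same idea on scalar u128 values might look like the sketch below; `div_rem_u128` is an invented name, and unlike the macro - which compares with slice_greater_than and then special-cases rem == b after the loop - this version simply subtracts whenever the remainder is at least the current power-of-two multiple.

// Sketch only: the shift-and-subtract division define_div_rem implements limb-by-limb,
// written against plain u128 values so the control flow is easy to follow.
fn div_rem_u128(a: u128, b: u128) -> Result<(u128, u128), ()> {
    if b == 0 { return Err(()); }
    // Collect b, 2b, 4b, ... stopping at the largest power-of-two multiple still <= a.
    let mut pow2s = vec![b];
    while let Some(next) = pow2s.last().unwrap().checked_mul(2) {
        if next > a { break; }
        pow2s.push(next);
    }
    // Walk the multiples in decreasing order, subtracting and setting the matching quotient bit.
    let (mut quot, mut rem) = (0u128, a);
    for (bit, b_pow) in pow2s.iter().enumerate().rev() {
        if rem >= *b_pow {
            rem -= *b_pow;
            quot |= 1u128 << bit;
        }
    }
    Ok((quot, rem))
}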
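
The Montgomery-reduction comments added above (in the U4096 mont_reduction closure and the U256Mod/U384Mod versions) all follow the same REDC identity: v = ((mu % R) * -m^-1) % R, then t = (mu + v*m) / R, followed by at most one subtraction of m. A single-limb sketch with R = 2^64 is below, as an illustration only; `redc_one_limb`, `expmod_one_limb`, `neg_m_inv` and `r_squared` are invented names standing in for the limb-array routines and the precomputed NEGATIVE_PRIME_INV_MOD_R / R_SQUARED_MOD_PRIME constants, and the real expmod's exponentiation loop is not shown in this hunk and may be structured differently.

// Sketch only: single-limb Montgomery reduction with R = 2^64. Assumes m is odd and
// m.wrapping_mul(neg_m_inv) == u64::MAX, i.e. neg_m_inv == -(m^-1) mod R.
fn redc_one_limb(mu: u128, m: u64, neg_m_inv: u64) -> u64 {
    debug_assert!(mu < (m as u128) << 64, "mu must be below m*R, e.g. a product of two values < m");
    // v = ((mu % R) * -m^-1) % R
    let v = (mu as u64).wrapping_mul(neg_m_inv);
    // t*R = mu + v*m; the sum can spill one bit past 128 bits, the analog of `t1_extra_bit`.
    let (sum, extra_bit) = mu.overflowing_add(v as u128 * m as u128);
    debug_assert_eq!(sum as u64, 0, "the low limb must cancel to zero");
    // Dividing by R is a shift right by one limb; fold the spilled bit back in at the top.
    let t = (sum >> 64) | ((extra_bit as u128) << 64);
    // t is at most 2m - 1, so a single conditional subtraction lands us in [0, m).
    (if t >= m as u128 { t - m as u128 } else { t }) as u64
}

// Sketch only: how the pieces combine into a modular exponentiation, still single-limb.
// `r_squared` plays the role of R_SQUARED_MOD_PRIME / r2_mod_m, i.e. R^2 mod m.
fn expmod_one_limb(base: u64, exp: u64, m: u64, neg_m_inv: u64, r_squared: u64) -> u64 {
    // Entering Montgomery form is a Montgomery multiplication by R^2: redc(x * R^2) == x*R % m.
    let mut result = redc_one_limb(r_squared as u128, m, neg_m_inv); // Montgomery form of 1
    let base_r = redc_one_limb(base as u128 * r_squared as u128, m, neg_m_inv);
    // Plain square-and-multiply, staying in Montgomery form throughout.
    for i in (0..64).rev() {
        result = redc_one_limb(result as u128 * result as u128, m, neg_m_inv);
        if (exp >> i) & 1 == 1 {
            result = redc_one_limb(result as u128 * base_r as u128, m, neg_m_inv);
        }
    }
    // Leaving Montgomery form is one more reduction: redc(x*R) == x % m.
    redc_one_limb(result as u128, m, neg_m_inv)
}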