use alloc::vec::Vec;
use core::marker::PhantomData;
-const WORD_COUNT_4096: usize = 4096 / 64;
-const WORD_COUNT_256: usize = 256 / 64;
-const WORD_COUNT_384: usize = 384 / 64;
-
-// RFC 5702 indicates RSA keys can be up to 4096 bits
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub(super) struct U4096([u64; WORD_COUNT_4096]);
-
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub(super) struct U256([u64; WORD_COUNT_256]);
-
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub(super) struct U384([u64; WORD_COUNT_384]);
-
-pub(super) trait Int: Clone + Ord + Sized {
- const ZERO: Self;
- const BYTES: usize;
- fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
- fn limbs(&self) -> &[u64];
-}
-impl Int for U256 {
- const ZERO: U256 = U256([0; 4]);
- const BYTES: usize = 32;
- fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
- fn limbs(&self) -> &[u64] { &self.0 }
-}
-impl Int for U384 {
- const ZERO: U384 = U384([0; 6]);
- const BYTES: usize = 48;
- fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
- fn limbs(&self) -> &[u64] { &self.0 }
-}
-
-/// Defines a *PRIME* Modulus
-pub(super) trait PrimeModulus<I: Int> {
- const PRIME: I;
- const R_SQUARED_MOD_PRIME: I;
- const NEGATIVE_PRIME_INV_MOD_R: I;
-}
-
-#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
-pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
-
-#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
-pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
+// **************************************
+// * Implementations of math primitives *
+// **************************************
macro_rules! debug_unwrap { ($v: expr) => { {
let v = $v;
$($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> {
if slice_equal(b, &[0; $len]) { return Err(()); }
+ // Very naively divide `a` by `b` by calculating all the powers of two times `b` up to `a`,
+ // then subtracting the powers of two in decreasing order. What's left is the remainder.
+ //
+ // This requires storing all the multiples of `b` in `pow2s`, which may be a vec or an
+ // array. `$pre_push!()` sets up the next element with zeros and then we can overwrite it.
let mut b_pow = *b;
let mut pow2s = $heap_init;
let mut pow2s_count = 0;
while pow2 >= 0 {
let b_pow = pow2s[pow2 as usize];
let overflow = double!(quot);
- debug_assert!(!overflow);
+ debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
if slice_greater_than(&rem, &b_pow) {
- let (r, carry) = $sub(&rem, &b_pow);
- debug_assert!(!carry);
+ let (r, underflow) = $sub(&rem, &b_pow);
+ debug_assert!(!underflow, "rem was just checked to be > b_pow, so sub cannot underflow");
rem = r;
quot[$len - 1] |= 1;
}
}
if slice_equal(&rem, b) {
let overflow = add_u64!(quot, 1);
- debug_assert!(!overflow);
+ debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
Ok((quot, [0; $len]))
} else {
Ok((quot, rem))
#[cfg(fuzzing)]
define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
+// ******************
+// * The public API *
+// ******************
+
+const WORD_COUNT_4096: usize = 4096 / 64;
+const WORD_COUNT_256: usize = 256 / 64;
+const WORD_COUNT_384: usize = 384 / 64;
+
+// RFC 5702 indicates RSA keys can be up to 4096 bits, so we always use 4096-bit integers
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U4096([u64; WORD_COUNT_4096]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U256([u64; WORD_COUNT_256]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U384([u64; WORD_COUNT_384]);
+
+pub(super) trait Int: Clone + Ord + Sized {
+ const ZERO: Self;
+ const BYTES: usize;
+ fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
+ fn limbs(&self) -> &[u64];
+}
+impl Int for U256 {
+ const ZERO: U256 = U256([0; 4]);
+ const BYTES: usize = 32;
+ fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+ fn limbs(&self) -> &[u64] { &self.0 }
+}
+impl Int for U384 {
+ const ZERO: U384 = U384([0; 6]);
+ const BYTES: usize = 48;
+ fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+ fn limbs(&self) -> &[u64] { &self.0 }
+}
+
+/// Defines a *PRIME* Modulus
+pub(super) trait PrimeModulus<I: Int> {
+ const PRIME: I;
+ const R_SQUARED_MOD_PRIME: I;
+ const NEGATIVE_PRIME_INV_MOD_R: I;
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
+pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
+pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
+
impl U4096 {
/// Constructs a new [`U4096`] from a variable number of big-endian bytes.
pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
// Because m is not even, using 2^4096 as the Montgomery R value is always safe - it is
// guaranteed to be co-prime with any non-even integer.
+ // We use a single 4096-bit integer type for all our RSA operations, though in most cases
+ // we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math
+ // here which debug_assert's the required bits are 0s and then uses faster math primitives.
+
type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
(64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
};
+ // r is always the even value with one bit set above the word count we're using.
let mut r = [0; WORD_COUNT_4096 * 2];
r[WORD_COUNT_4096 * 2 - word_count - 1] = 1;
}
m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0);
- // We want the negative modular inverse of m mod R, so subtract m_inv from R.
+ // `m_inv` is the negative modular inverse of m mod R, so subtract m_inv from R.
let mut m_inv = m_inv_pos;
negate!(m_inv);
m_inv[..WORD_COUNT_4096 - word_count].fill(0);
// R - 1 == -1 % R
&[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]);
- debug_assert_eq!(&m_inv[..WORD_COUNT_4096 - word_count], &[0; WORD_COUNT_4096][..WORD_COUNT_4096 - word_count]);
-
let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] {
debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2],
&[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
+ // Do a montgomery reduction of `mu`
+
+ // mu % R is just the bottom word_count bytes of mu
let mut mu_mod_r = [0; WORD_COUNT_4096];
mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
+
+ // v = ((mu % R) * negative_modulus_inverse) % R
let mut v = mul(&mu_mod_r, &m_inv);
v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
+
+ // t_on_r = (mu + v*modulus) / R
let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
let (t1, t1_extra_bit) = add_double(&t0, &mu);
+
+ // Note that dividing t1 by R is simply a matter of shifting right by word_count bytes
+ // We only need to maintain word_count bytes (plus `t1_extra_bit` which is implicitly
+ // an extra bit) because t_on_r is guarantee to be, at max, 2*m - 1.
let mut t1_on_r = [0; WORD_COUNT_4096];
debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..],
"t1 should be divisible by r");
t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]);
+
+ // The modulus has only word_count bytes, so if t1_extra_bit is set we are definitely
+ // larger than the modulus.
if t1_extra_bit || t1_on_r >= m.0 {
let underflow;
(t1_on_r, underflow) = sub(&t1_on_r, &m.0);
- debug_assert_eq!(t1_extra_bit, underflow);
+ debug_assert_eq!(t1_extra_bit, underflow,
+ "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
}
t1_on_r
};
}
debug_assert!(r2_mod_m < m.0);
+ // Finally, actually do the exponentiation...
+
// Calculate t * R and a * R as mont multiplications by R^2 mod m
let mut tr = mont_reduction(mul(&r2_mod_m, &t));
let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
}
}
-const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 {
- match b {
- [a, b, c, d, e, f, g, h, ..] => {
- ((*a as u64) << 8*7) |
- ((*b as u64) << 8*6) |
- ((*c as u64) << 8*5) |
- ((*d as u64) << 8*4) |
- ((*e as u64) << 8*3) |
- ((*f as u64) << 8*2) |
- ((*g as u64) << 8*1) |
- ((*h as u64) << 8*0)
- },
- _ => panic!(),
- }
-}
-
-const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 {
- match b {
- [_, _, _, _, _, _, _, _,
- a, b, c, d, e, f, g, h, ..] => {
- ((*a as u64) << 8*7) |
- ((*b as u64) << 8*6) |
- ((*c as u64) << 8*5) |
- ((*d as u64) << 8*4) |
- ((*e as u64) << 8*3) |
- ((*f as u64) << 8*2) |
- ((*g as u64) << 8*1) |
- ((*h as u64) << 8*0)
- },
- _ => panic!(),
- }
-}
-
-const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 {
- match b {
- [_, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- a, b, c, d, e, f, g, h, ..] => {
- ((*a as u64) << 8*7) |
- ((*b as u64) << 8*6) |
- ((*c as u64) << 8*5) |
- ((*d as u64) << 8*4) |
- ((*e as u64) << 8*3) |
- ((*f as u64) << 8*2) |
- ((*g as u64) << 8*1) |
- ((*h as u64) << 8*0)
- },
- _ => panic!(),
- }
-}
-
-const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 {
- match b {
- [_, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- a, b, c, d, e, f, g, h, ..] => {
- ((*a as u64) << 8*7) |
- ((*b as u64) << 8*6) |
- ((*c as u64) << 8*5) |
- ((*d as u64) << 8*4) |
- ((*e as u64) << 8*3) |
- ((*f as u64) << 8*2) |
- ((*g as u64) << 8*1) |
- ((*h as u64) << 8*0)
- },
- _ => panic!(),
- }
-}
-
-const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 {
- match b {
- [_, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- a, b, c, d, e, f, g, h, ..] => {
- ((*a as u64) << 8*7) |
- ((*b as u64) << 8*6) |
- ((*c as u64) << 8*5) |
- ((*d as u64) << 8*4) |
- ((*e as u64) << 8*3) |
- ((*f as u64) << 8*2) |
- ((*g as u64) << 8*1) |
- ((*h as u64) << 8*0)
- },
- _ => panic!(),
- }
-}
-
-const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 {
- match b {
- [_, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- _, _, _, _, _, _, _, _,
- a, b, c, d, e, f, g, h, ..] => {
- ((*a as u64) << 8*7) |
- ((*b as u64) << 8*6) |
- ((*c as u64) << 8*5) |
- ((*d as u64) << 8*4) |
- ((*e as u64) << 8*3) |
- ((*f as u64) << 8*2) |
- ((*g as u64) << 8*1) |
- ((*h as u64) << 8*0)
- },
- _ => panic!(),
- }
+// In a const context we can't subslice a slice, so instead we pick the eight bytes we want and
+// pass them here to build u64s from arrays.
+const fn eight_bytes_to_u64_be(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 {
+ let b = [a, b, c, d, e, f, g, h];
+ u64::from_be_bytes(b)
}
impl U256 {
/// Constructs a new [`U256`] from a fixed number of big-endian bytes.
pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
let res = [
- u64_from_bytes_a_panicking(bytes),
- u64_from_bytes_b_panicking(bytes),
- u64_from_bytes_c_panicking(bytes),
- u64_from_bytes_d_panicking(bytes),
+ eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3],
+ bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]),
+ eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3],
+ bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]),
+ eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3],
+ bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]),
+ eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3],
+ bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]),
];
U256(res)
}
pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
}
+// Values modulus M::PRIME.0, stored in montgomery form.
impl<M: PrimeModulus<U256>> U256Mod<M> {
const fn mont_reduction(mu: [u64; 8]) -> Self {
#[cfg(debug_assertions)] {
assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
}
+ // mu % R is just the bottom 4 bytes of mu
let mu_mod_r = const_subslice(&mu, 4, 8);
+ // v = ((mu % R) * negative_modulus_inverse) % R
let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
const ZEROS: &[u64; 4] = &[0; 4];
copy_from_slice!(v, 0, 4, ZEROS); // mod R
+
+ // t_on_r = (mu + v*modulus) / R
let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
let (t1, t1_extra_bit) = add_8(&t0, &mu);
+
+ // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
+ // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
+ // because t_on_r is guarantee to be, at max, 2*m - 1.
let t1_on_r = const_subslice(&t1, 0, 4);
+
let mut res = [0; 4];
+ // The modulus is only 4 bytes, so t1_extra_bit implies we're definitely larger than the
+ // modulus.
if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
let underflow;
(res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
- debug_assert!(t1_extra_bit == underflow);
+ debug_assert_eq!(t1_extra_bit, underflow,
+ "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
} else {
copy_from_slice!(res, 0, 4, t1_on_r);
}
debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
while v >= M::PRIME {
let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
- debug_assert!(!spurious_underflow);
+ debug_assert!(!spurious_underflow, "v was > M::PRIME.0");
v = U256(new_v);
}
Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
}
}
+// Values modulus M::PRIME.0, stored in montgomery form.
impl U384 {
/// Constructs a new [`U384`] from a variable number of big-endian bytes.
pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
/// Constructs a new [`U384`] from a fixed number of big-endian bytes.
pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
let res = [
- u64_from_bytes_a_panicking(bytes),
- u64_from_bytes_b_panicking(bytes),
- u64_from_bytes_c_panicking(bytes),
- u64_from_bytes_d_panicking(bytes),
- u64_from_bytes_e_panicking(bytes),
- u64_from_bytes_f_panicking(bytes),
+ eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3],
+ bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]),
+ eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3],
+ bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]),
+ eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3],
+ bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]),
+ eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3],
+ bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]),
+ eight_bytes_to_u64_be(bytes[4*8 + 0], bytes[4*8 + 1], bytes[4*8 + 2], bytes[4*8 + 3],
+ bytes[4*8 + 4], bytes[4*8 + 5], bytes[4*8 + 6], bytes[4*8 + 7]),
+ eight_bytes_to_u64_be(bytes[5*8 + 0], bytes[5*8 + 1], bytes[5*8 + 2], bytes[5*8 + 3],
+ bytes[5*8 + 4], bytes[5*8 + 5], bytes[5*8 + 6], bytes[5*8 + 7]),
];
U384(res)
}
assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
}
+ // mu % R is just the bottom 4 bytes of mu
let mu_mod_r = const_subslice(&mu, 6, 12);
+ // v = ((mu % R) * negative_modulus_inverse) % R
let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
const ZEROS: &[u64; 6] = &[0; 6];
copy_from_slice!(v, 0, 6, ZEROS); // mod R
+
+ // t_on_r = (mu + v*modulus) / R
let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
let (t1, t1_extra_bit) = add_12(&t0, &mu);
+
+ // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
+ // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
+ // because t_on_r is guarantee to be, at max, 2*m - 1.
let t1_on_r = const_subslice(&t1, 0, 6);
+
let mut res = [0; 6];
+ // The modulus is only 4 bytes, so t1_extra_bit implies we're definitely larger than the
+ // modulus.
if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
let underflow;
(res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);