Add U256/U384 and mod-const-prime wrapper utilities of both.
authorMatt Corallo <git@bluematt.me>
Mon, 4 Mar 2024 18:43:52 +0000 (18:43 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 3 Apr 2024 09:15:18 +0000 (09:15 +0000)
In the next commit we'll add support for secp256r1 and secp384r1
validation, which require 256-bit and 384-bit integers. To make
their implementation simple, we also add wrapper structs around
the new integers which are modulo a const-prime, storing and
handling the values in montgommery representation.

src/crypto/bigint.rs

index 62beef434bee5290c28f6991dbb23ca799ec4ef8..c9264cbe5a17151eeb13a404229fc7a81b808db4 100644 (file)
@@ -1,13 +1,54 @@
 //! Simple variable-time big integer implementation
 
 use alloc::vec::Vec;
+use core::marker::PhantomData;
 
 const WORD_COUNT_4096: usize = 4096 / 64;
+const WORD_COUNT_256: usize = 256 / 64;
+const WORD_COUNT_384: usize = 384 / 64;
 
 // RFC 5702 indicates RSA keys can be up to 4096 bits
 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
 pub(super) struct U4096([u64; WORD_COUNT_4096]);
 
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U256([u64; WORD_COUNT_256]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U384([u64; WORD_COUNT_384]);
+
+pub(super) trait Int: Clone + Ord + Sized {
+       const ZERO: Self;
+       const BYTES: usize;
+       fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
+       fn limbs(&self) -> &[u64];
+}
+impl Int for U256 {
+       const ZERO: U256 = U256([0; 4]);
+       const BYTES: usize = 32;
+       fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+       fn limbs(&self) -> &[u64] { &self.0 }
+}
+impl Int for U384 {
+       const ZERO: U384 = U384([0; 6]);
+       const BYTES: usize = 48;
+       fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+       fn limbs(&self) -> &[u64] { &self.0 }
+}
+
+/// Defines a *PRIME* Modulus
+pub(super) trait PrimeModulus<I: Int> {
+       const PRIME: I;
+       const R_SQUARED_MOD_PRIME: I;
+       const NEGATIVE_PRIME_INV_MOD_R: I;
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
+pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
+pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
+
 macro_rules! debug_unwrap { ($v: expr) => { {
        let v = $v;
        debug_assert!(v.is_ok());
@@ -150,14 +191,17 @@ macro_rules! define_add { ($name: ident, $len: expr) => {
 } }
 
 define_add!(add_2, 2);
+define_add!(add_3, 3);
 define_add!(add_4, 4);
+define_add!(add_6, 6);
 define_add!(add_8, 8);
+define_add!(add_12, 12);
 define_add!(add_16, 16);
 define_add!(add_32, 32);
 define_add!(add_64, 64);
 define_add!(add_128, 128);
 
-macro_rules! define_sub { ($name: ident, $len: expr) => {
+macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
        /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
        /// $len-64-bit integer and an overflow bit, with the same semantics as the std
        /// [`u64::overflowing_sub`] method.
@@ -178,16 +222,29 @@ macro_rules! define_sub { ($name: ident, $len: expr) => {
                }
                (r, carry)
        }
+
+       /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
+       /// $len-64-bit integer representing the absolute value of the result, as well as a sign bit.
+       #[allow(unused)]
+       const fn $name_abs(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
+               let (mut res, neg) = $name(a, b);
+               if neg {
+                       negate!(res);
+               }
+               (res, neg)
+       }
 } }
 
-define_sub!(sub_2, 2);
-define_sub!(sub_4, 4);
-define_sub!(sub_8, 8);
-define_sub!(sub_16, 16);
-define_sub!(sub_32, 32);
-define_sub!(sub_64, 64);
-#[cfg(debug_assertions)]
-define_sub!(sub_128, 128);
+define_sub!(sub_2, sub_abs_2, 2);
+define_sub!(sub_3, sub_abs_3, 3);
+define_sub!(sub_4, sub_abs_4, 4);
+define_sub!(sub_6, sub_abs_6, 6);
+define_sub!(sub_8, sub_abs_8, 8);
+define_sub!(sub_12, sub_abs_12, 12);
+define_sub!(sub_16, sub_abs_16, 16);
+define_sub!(sub_32, sub_abs_32, 32);
+define_sub!(sub_64, sub_abs_64, 64);
+define_sub!(sub_128, sub_abs_128, 128);
 
 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
 ///
@@ -232,6 +289,76 @@ const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
        [i, j, k, l]
 }
 
+const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
+       debug_assert!(a.len() == 3);
+       debug_assert!(b.len() == 3);
+
+       let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
+       let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
+
+       let m4 = a2 * b2;
+       let m3a = a2 * b1;
+       let m3b = a1 * b2;
+       let m2a = a2 * b0;
+       let m2b = a1 * b1;
+       let m2c = a0 * b2;
+       let m1a = a1 * b0;
+       let m1b = a0 * b1;
+       let m0 = a0 * b0;
+
+       let r5 = ((m4 >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+
+       let r4a = ((m4 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r4b = ((m3a >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+       let r4c = ((m3b >> 0) & 0xffff_ffff_ffff_ffff) as u64;
+
+       let r3a = ((m3a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r3b = ((m3b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r3c = ((m2a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+       let r3d = ((m2b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+       let r3e = ((m2c >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+       let r2a = ((m2a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r2b = ((m2b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r2c = ((m2c >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r2d = ((m1a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+       let r2e = ((m1b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+       let r1a = ((m1a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r1b = ((m1b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+       let r1c = ((m0  >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
+
+       let r0a = ((m0  >> 64) & 0xffff_ffff_ffff_ffff) as u64;
+
+       let (r4, r3_ca) = r4a.overflowing_add(r4b);
+       let (r4, r3_cb) = r4.overflowing_add(r4c);
+       let r3_c = r3_ca as u64 + r3_cb as u64;
+
+       let (r3, r2_ca) = r3a.overflowing_add(r3b);
+       let (r3, r2_cb) = r3.overflowing_add(r3c);
+       let (r3, r2_cc) = r3.overflowing_add(r3d);
+       let (r3, r2_cd) = r3.overflowing_add(r3e);
+       let (r3, r2_ce) = r3.overflowing_add(r3_c);
+       let r2_c = r2_ca as u64 + r2_cb as u64 + r2_cc as u64 + r2_cd as u64 + r2_ce as u64;
+
+       let (r2, r1_ca) = r2a.overflowing_add(r2b);
+       let (r2, r1_cb) = r2.overflowing_add(r2c);
+       let (r2, r1_cc) = r2.overflowing_add(r2d);
+       let (r2, r1_cd) = r2.overflowing_add(r2e);
+       let (r2, r1_ce) = r2.overflowing_add(r2_c);
+       let r1_c = r1_ca as u64 + r1_cb as u64 + r1_cc as u64 + r1_cd as u64 + r1_ce as u64;
+
+       let (r1, r0_ca) = r1a.overflowing_add(r1b);
+       let (r1, r0_cb) = r1.overflowing_add(r1c);
+       let (r1, r0_cc) = r1.overflowing_add(r1_c);
+       let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64;
+
+       let (r0, must_not_overflow) = r0a.overflowing_add(r0_c);
+       debug_assert!(!must_not_overflow);
+
+       [r0, r1, r2, r3, r4, r5]
+}
+
 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
        /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
        const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
@@ -303,6 +430,7 @@ macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident
 } }
 
 define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
+define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
 define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
 define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
 define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
@@ -393,13 +521,16 @@ macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: id
        }
 } }
 
+// TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
+const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
+
 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
+define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
 define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
 define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
 define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
 define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
 
-#[cfg(fuzzing)]
 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
 
@@ -447,14 +578,85 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init
 
 #[cfg(fuzzing)]
 define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
-#[cfg(fuzzing)]
 define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
-#[cfg(fuzzing)]
+define_div_rem!(div_rem_6, 6, sub_6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
+#[cfg(debug_assertions)]
 define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
+#[cfg(debug_assertions)]
+define_div_rem!(div_rem_12, 12, sub_12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
 define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
 #[cfg(debug_assertions)]
 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
 
+macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
+       /// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
+       /// if one exists.
+       const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
+               if slice_equal(a, &[0; $len]) || slice_equal(m, &[0; $len]) { return Err(()); }
+
+               let (mut s, mut old_s) = ([0; $len], [0; $len]);
+               old_s[$len - 1] = 1;
+               let mut r = *m;
+               let mut old_r = *a;
+
+               let (mut old_s_neg, mut s_neg) = (false, false);
+
+               while !slice_equal(&r, &[0; $len]) {
+                       let (quot, new_r) = debug_unwrap!($div(&old_r, &r));
+
+                       let new_sa = $mul(&quot, &s);
+                       debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
+                       let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
+                               (true, true) => {
+                                       let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+                                       debug_assert!(!overflow);
+                                       (new_s, true)
+                               }
+                               (false, true) => {
+                                       let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+                                       debug_assert!(!overflow);
+                                       (new_s, false)
+                               },
+                               (true, false) => {
+                                       let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
+                                       debug_assert!(!overflow);
+                                       (new_s, true)
+                               },
+                               (false, false) => $sub_abs(&old_s, const_subslice(&new_sa, $len, new_sa.len())),
+                       };
+
+                       old_r = r;
+                       r = new_r;
+
+                       old_s = s;
+                       old_s_neg = s_neg;
+                       s = new_s;
+                       s_neg = new_s_neg;
+               }
+
+               // At this point old_r contains our GCD and old_s our first Bézout's identity coefficient.
+               if !slice_equal(const_subslice(&old_r, 0, $len - 1), &[0; $len - 1]) || old_r[$len - 1] != 1 {
+                       Err(())
+               } else {
+                       debug_assert!(slice_greater_than(m, &old_s));
+                       if old_s_neg {
+                               let (modinv, underflow) = $sub_abs(m, &old_s);
+                               debug_assert!(!underflow);
+                               debug_assert!(slice_greater_than(m, &modinv));
+                               Ok(modinv)
+                       } else {
+                               Ok(old_s)
+                       }
+               }
+       }
+} }
+#[cfg(fuzzing)]
+define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
+define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
+define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
+#[cfg(fuzzing)]
+define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
+
 impl U4096 {
        /// Constructs a new [`U4096`] from a variable number of big-endian bytes.
        pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
@@ -710,6 +912,505 @@ impl U4096 {
        }
 }
 
+const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 {
+       match b {
+               [a, b, c, d, e, f, g, h, ..] => {
+                       ((*a as u64) << 8*7) |
+                       ((*b as u64) << 8*6) |
+                       ((*c as u64) << 8*5) |
+                       ((*d as u64) << 8*4) |
+                       ((*e as u64) << 8*3) |
+                       ((*f as u64) << 8*2) |
+                       ((*g as u64) << 8*1) |
+                       ((*h as u64) << 8*0)
+               },
+               _ => panic!(),
+       }
+}
+
+const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 {
+       match b {
+               [_, _, _, _, _, _, _, _,
+                a, b, c, d, e, f, g, h, ..] => {
+                       ((*a as u64) << 8*7) |
+                       ((*b as u64) << 8*6) |
+                       ((*c as u64) << 8*5) |
+                       ((*d as u64) << 8*4) |
+                       ((*e as u64) << 8*3) |
+                       ((*f as u64) << 8*2) |
+                       ((*g as u64) << 8*1) |
+                       ((*h as u64) << 8*0)
+               },
+               _ => panic!(),
+       }
+}
+
+const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 {
+       match b {
+               [_, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                a, b, c, d, e, f, g, h, ..] => {
+                       ((*a as u64) << 8*7) |
+                       ((*b as u64) << 8*6) |
+                       ((*c as u64) << 8*5) |
+                       ((*d as u64) << 8*4) |
+                       ((*e as u64) << 8*3) |
+                       ((*f as u64) << 8*2) |
+                       ((*g as u64) << 8*1) |
+                       ((*h as u64) << 8*0)
+               },
+               _ => panic!(),
+       }
+}
+
+const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 {
+       match b {
+               [_, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                a, b, c, d, e, f, g, h, ..] => {
+                       ((*a as u64) << 8*7) |
+                       ((*b as u64) << 8*6) |
+                       ((*c as u64) << 8*5) |
+                       ((*d as u64) << 8*4) |
+                       ((*e as u64) << 8*3) |
+                       ((*f as u64) << 8*2) |
+                       ((*g as u64) << 8*1) |
+                       ((*h as u64) << 8*0)
+               },
+               _ => panic!(),
+       }
+}
+
+const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 {
+       match b {
+               [_, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                a, b, c, d, e, f, g, h, ..] => {
+                       ((*a as u64) << 8*7) |
+                       ((*b as u64) << 8*6) |
+                       ((*c as u64) << 8*5) |
+                       ((*d as u64) << 8*4) |
+                       ((*e as u64) << 8*3) |
+                       ((*f as u64) << 8*2) |
+                       ((*g as u64) << 8*1) |
+                       ((*h as u64) << 8*0)
+               },
+               _ => panic!(),
+       }
+}
+
+const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 {
+       match b {
+               [_, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                _, _, _, _, _, _, _, _,
+                a, b, c, d, e, f, g, h, ..] => {
+                       ((*a as u64) << 8*7) |
+                       ((*b as u64) << 8*6) |
+                       ((*c as u64) << 8*5) |
+                       ((*d as u64) << 8*4) |
+                       ((*e as u64) << 8*3) |
+                       ((*f as u64) << 8*2) |
+                       ((*g as u64) << 8*1) |
+                       ((*h as u64) << 8*0)
+               },
+               _ => panic!(),
+       }
+}
+
+impl U256 {
+       /// Constructs a new [`U256`] from a variable number of big-endian bytes.
+       pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U256, ()> {
+               if bytes.len() > 256/8 { return Err(()); }
+               let u64s = (bytes.len() + 7) / 8;
+               let mut res = [0; WORD_COUNT_256];
+               for i in 0..u64s {
+                       let mut b = [0; 8];
+                       let pos = (u64s - i) * 8;
+                       let start = bytes.len().saturating_sub(pos);
+                       let end = bytes.len() + 8 - pos;
+                       b[8 + start - end..].copy_from_slice(&bytes[start..end]);
+                       res[i + WORD_COUNT_256 - u64s] = u64::from_be_bytes(b);
+               }
+               Ok(U256(res))
+       }
+
+       /// Constructs a new [`U256`] from a fixed number of big-endian bytes.
+       pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
+               let res = [
+                       u64_from_bytes_a_panicking(bytes),
+                       u64_from_bytes_b_panicking(bytes),
+                       u64_from_bytes_c_panicking(bytes),
+                       u64_from_bytes_d_panicking(bytes),
+               ];
+               U256(res)
+       }
+
+       pub(super) const fn zero() -> U256 { U256([0, 0, 0, 0]) }
+       pub(super) const fn one() -> U256 { U256([0, 0, 0, 1]) }
+       pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
+}
+
+impl<M: PrimeModulus<U256>> U256Mod<M> {
+       const fn mont_reduction(mu: [u64; 8]) -> Self {
+               #[cfg(debug_assertions)] {
+                       // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
+                       // should be able to do it at compile time alone.
+                       let minus_one_mod_r = mul_4(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+                       assert!(slice_equal(const_subslice(&minus_one_mod_r, 4, 8), &[0xffff_ffff_ffff_ffff; 4]));
+               }
+
+               #[cfg(debug_assertions)] {
+                       // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
+                       // should be able to do it at compile time alone.
+                       let r_minus_one = [0xffff_ffff_ffff_ffff; 4];
+                       let (mut r_mod_prime, _) = sub_4(&r_minus_one, &M::PRIME.0);
+                       add_one!(r_mod_prime);
+                       let r_squared = sqr_4(&r_mod_prime);
+                       let mut prime_extended = [0; 8];
+                       let prime = M::PRIME.0;
+                       copy_from_slice!(prime_extended, 4, 8, prime);
+                       let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_8(&r_squared, &prime_extended) { v } else { panic!() };
+                       assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
+                       assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
+               }
+
+               let mu_mod_r = const_subslice(&mu, 4, 8);
+               let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+               const ZEROS: &[u64; 4] = &[0; 4];
+               copy_from_slice!(v, 0, 4, ZEROS); // mod R
+               let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
+               let (t1, t1_extra_bit) = add_8(&t0, &mu);
+               let t1_on_r = const_subslice(&t1, 0, 4);
+               let mut res = [0; 4];
+               if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
+                       let underflow;
+                       (res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
+                       debug_assert!(t1_extra_bit == underflow);
+               } else {
+                       copy_from_slice!(res, 0, 4, t1_on_r);
+               }
+               Self(U256(res), PhantomData)
+       }
+
+       pub(super) const fn from_u256_panicking(v: U256) -> Self {
+               assert!(v.0[0] <= M::PRIME.0[0]);
+               if v.0[0] == M::PRIME.0[0] {
+                       assert!(v.0[1] <= M::PRIME.0[1]);
+                       if v.0[1] == M::PRIME.0[1] {
+                               assert!(v.0[2] <= M::PRIME.0[2]);
+                               if v.0[2] == M::PRIME.0[2] {
+                                       assert!(v.0[3] < M::PRIME.0[3]);
+                               }
+                       }
+               }
+               assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0 || M::PRIME.0[3] != 0);
+               Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+       }
+
+       pub(super) fn from_u256(mut v: U256) -> Self {
+               debug_assert!(M::PRIME.0 != [0; 4]);
+               debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
+               while v >= M::PRIME {
+                       let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
+                       debug_assert!(!spurious_underflow);
+                       v = U256(new_v);
+               }
+               Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+       }
+
+       pub(super) fn from_modinv_of(v: U256) -> Result<Self, ()> {
+               Ok(Self::from_u256(U256(mod_inv_4(&v.0, &M::PRIME.0)?)))
+       }
+
+       /// Multiplies `self` * `b` mod `m`.
+       ///
+       /// Panics if `self`'s modulus is not equal to `b`'s
+       pub(super) fn mul(&self, b: &Self) -> Self {
+               Self::mont_reduction(mul_4(&self.0.0, &b.0.0))
+       }
+
+       /// Doubles `self` mod `m`.
+       pub(super) fn double(&self) -> Self {
+               let mut res = self.0.0;
+               let overflow = double!(res);
+               if overflow || !slice_greater_than(&M::PRIME.0, &res) {
+                       let underflow;
+                       (res, underflow) = sub_4(&res, &M::PRIME.0);
+                       debug_assert_eq!(overflow, underflow);
+               }
+               Self(U256(res), PhantomData)
+       }
+
+       /// Multiplies `self` by 3 mod `m`.
+       pub(super) fn times_three(&self) -> Self {
+               // TODO: Optimize this a lot
+               self.mul(&U256Mod::from_u256(U256::three()))
+       }
+
+       /// Multiplies `self` by 4 mod `m`.
+       pub(super) fn times_four(&self) -> Self {
+               // TODO: Optimize this somewhat?
+               self.double().double()
+       }
+
+       /// Multiplies `self` by 8 mod `m`.
+       pub(super) fn times_eight(&self) -> Self {
+               // TODO: Optimize this somewhat?
+               self.double().double().double()
+       }
+
+       /// Multiplies `self` by 8 mod `m`.
+       pub(super) fn square(&self) -> Self {
+               Self::mont_reduction(sqr_4(&self.0.0))
+       }
+
+       /// Subtracts `b` from `self` % `m`.
+       pub(super) fn sub(&self, b: &Self) -> Self {
+               let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
+               if underflow {
+                       let overflow;
+                       (val, overflow) = add_4(&val, &M::PRIME.0);
+                       debug_assert_eq!(overflow, underflow);
+               }
+               Self(U256(val), PhantomData)
+       }
+
+       /// Adds `b` to `self` % `m`.
+       pub(super) fn add(&self, b: &Self) -> Self {
+               let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
+               if overflow || !slice_greater_than(&M::PRIME.0, &val) {
+                       let underflow;
+                       (val, underflow) = sub_4(&val, &M::PRIME.0);
+                       debug_assert_eq!(overflow, underflow);
+               }
+               Self(U256(val), PhantomData)
+       }
+
+       /// Returns the underlying [`U256`].
+       pub(super) fn into_u256(self) -> U256 {
+               let mut expanded_self = [0; 8];
+               expanded_self[4..].copy_from_slice(&self.0.0);
+               Self::mont_reduction(expanded_self).0
+       }
+}
+
+impl U384 {
+       /// Constructs a new [`U384`] from a variable number of big-endian bytes.
+       pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
+               if bytes.len() > 384/8 { return Err(()); }
+               let u64s = (bytes.len() + 7) / 8;
+               let mut res = [0; WORD_COUNT_384];
+               for i in 0..u64s {
+                       let mut b = [0; 8];
+                       let pos = (u64s - i) * 8;
+                       let start = bytes.len().saturating_sub(pos);
+                       let end = bytes.len() + 8 - pos;
+                       b[8 + start - end..].copy_from_slice(&bytes[start..end]);
+                       res[i + WORD_COUNT_384 - u64s] = u64::from_be_bytes(b);
+               }
+               Ok(U384(res))
+       }
+
+       /// Constructs a new [`U384`] from a fixed number of big-endian bytes.
+       pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
+               let res = [
+                       u64_from_bytes_a_panicking(bytes),
+                       u64_from_bytes_b_panicking(bytes),
+                       u64_from_bytes_c_panicking(bytes),
+                       u64_from_bytes_d_panicking(bytes),
+                       u64_from_bytes_e_panicking(bytes),
+                       u64_from_bytes_f_panicking(bytes),
+               ];
+               U384(res)
+       }
+
+       pub(super) const fn zero() -> U384 { U384([0, 0, 0, 0, 0, 0]) }
+       pub(super) const fn one() -> U384 { U384([0, 0, 0, 0, 0, 1]) }
+       pub(super) const fn three() -> U384 { U384([0, 0, 0, 0, 0, 3]) }
+}
+
+impl<M: PrimeModulus<U384>> U384Mod<M> {
+       const fn mont_reduction(mu: [u64; 12]) -> Self {
+               #[cfg(debug_assertions)] {
+                       // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
+                       // should be able to do it at compile time alone.
+                       let minus_one_mod_r = mul_6(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+                       assert!(slice_equal(const_subslice(&minus_one_mod_r, 6, 12), &[0xffff_ffff_ffff_ffff; 6]));
+               }
+
+               #[cfg(debug_assertions)] {
+                       // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
+                       // should be able to do it at compile time alone.
+                       let r_minus_one = [0xffff_ffff_ffff_ffff; 6];
+                       let (mut r_mod_prime, _) = sub_6(&r_minus_one, &M::PRIME.0);
+                       add_one!(r_mod_prime);
+                       let r_squared = sqr_6(&r_mod_prime);
+                       let mut prime_extended = [0; 12];
+                       let prime = M::PRIME.0;
+                       copy_from_slice!(prime_extended, 6, 12, prime);
+                       let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_12(&r_squared, &prime_extended) { v } else { panic!() };
+                       assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
+                       assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
+               }
+
+               let mu_mod_r = const_subslice(&mu, 6, 12);
+               let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
+               const ZEROS: &[u64; 6] = &[0; 6];
+               copy_from_slice!(v, 0, 6, ZEROS); // mod R
+               let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
+               let (t1, t1_extra_bit) = add_12(&t0, &mu);
+               let t1_on_r = const_subslice(&t1, 0, 6);
+               let mut res = [0; 6];
+               if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
+                       let underflow;
+                       (res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);
+                       debug_assert!(t1_extra_bit == underflow);
+               } else {
+                       copy_from_slice!(res, 0, 6, t1_on_r);
+               }
+               Self(U384(res), PhantomData)
+       }
+
+       pub(super) const fn from_u384_panicking(v: U384) -> Self {
+               assert!(v.0[0] <= M::PRIME.0[0]);
+               if v.0[0] == M::PRIME.0[0] {
+                       assert!(v.0[1] <= M::PRIME.0[1]);
+                       if v.0[1] == M::PRIME.0[1] {
+                               assert!(v.0[2] <= M::PRIME.0[2]);
+                               if v.0[2] == M::PRIME.0[2] {
+                                       assert!(v.0[3] <= M::PRIME.0[3]);
+                                       if v.0[3] == M::PRIME.0[3] {
+                                               assert!(v.0[4] <= M::PRIME.0[4]);
+                                               if v.0[4] == M::PRIME.0[4] {
+                                                       assert!(v.0[5] < M::PRIME.0[5]);
+                                               }
+                                       }
+                               }
+                       }
+               }
+               assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0
+                       || M::PRIME.0[3] != 0|| M::PRIME.0[4] != 0|| M::PRIME.0[5] != 0);
+               Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+       }
+
+       pub(super) fn from_u384(mut v: U384) -> Self {
+               debug_assert!(M::PRIME.0 != [0; 6]);
+               debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
+               while v >= M::PRIME {
+                       let (new_v, spurious_underflow) = sub_6(&v.0, &M::PRIME.0);
+                       debug_assert!(!spurious_underflow);
+                       v = U384(new_v);
+               }
+               Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
+       }
+
+       pub(super) fn from_modinv_of(v: U384) -> Result<Self, ()> {
+               Ok(Self::from_u384(U384(mod_inv_6(&v.0, &M::PRIME.0)?)))
+       }
+
+       /// Multiplies `self` * `b` mod `m`.
+       ///
+       /// Panics if `self`'s modulus is not equal to `b`'s
+       pub(super) fn mul(&self, b: &Self) -> Self {
+               Self::mont_reduction(mul_6(&self.0.0, &b.0.0))
+       }
+
+       /// Doubles `self` mod `m`.
+       pub(super) fn double(&self) -> Self {
+               let mut res = self.0.0;
+               let overflow = double!(res);
+               if overflow || !slice_greater_than(&M::PRIME.0, &res) {
+                       let underflow;
+                       (res, underflow) = sub_6(&res, &M::PRIME.0);
+                       debug_assert_eq!(overflow, underflow);
+               }
+               Self(U384(res), PhantomData)
+       }
+
+       /// Multiplies `self` by 3 mod `m`.
+       pub(super) fn times_three(&self) -> Self {
+               // TODO: Optimize this a lot
+               self.mul(&U384Mod::from_u384(U384::three()))
+       }
+
+       /// Multiplies `self` by 4 mod `m`.
+       pub(super) fn times_four(&self) -> Self {
+               // TODO: Optimize this somewhat?
+               self.double().double()
+       }
+
+       /// Multiplies `self` by 8 mod `m`.
+       pub(super) fn times_eight(&self) -> Self {
+               // TODO: Optimize this somewhat?
+               self.double().double().double()
+       }
+
+       /// Multiplies `self` by 8 mod `m`.
+       pub(super) fn square(&self) -> Self {
+               Self::mont_reduction(sqr_6(&self.0.0))
+       }
+
+       /// Subtracts `b` from `self` % `m`.
+       pub(super) fn sub(&self, b: &Self) -> Self {
+               let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
+               if underflow {
+                       let overflow;
+                       (val, overflow) = add_6(&val, &M::PRIME.0);
+                       debug_assert_eq!(overflow, underflow);
+               }
+               Self(U384(val), PhantomData)
+       }
+
+       /// Adds `b` to `self` % `m`.
+       pub(super) fn add(&self, b: &Self) -> Self {
+               let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
+               if overflow || !slice_greater_than(&M::PRIME.0, &val) {
+                       let underflow;
+                       (val, underflow) = sub_6(&val, &M::PRIME.0);
+                       debug_assert_eq!(overflow, underflow);
+               }
+               Self(U384(val), PhantomData)
+       }
+
+       /// Returns the underlying [`U384`].
+       pub(super) fn into_u384(self) -> U384 {
+               let mut expanded_self = [0; 12];
+               expanded_self[6..].copy_from_slice(&self.0.0);
+               Self::mont_reduction(expanded_self).0
+       }
+}
+
+#[cfg(fuzzing)]
+mod fuzz_moduli {
+       use super::*;
+
+       pub struct P256();
+       impl PrimeModulus<U256> for P256 {
+               const PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+                       "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"));
+               const R_SQUARED_MOD_PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+                       "00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003"));
+               const NEGATIVE_PRIME_INV_MOD_R: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
+                       "ffffffff00000002000000000000000000000001000000000000000000000001"));
+       }
+
+       pub struct P384();
+       impl PrimeModulus<U384> for P384 {
+               const PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+                       "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"));
+               const R_SQUARED_MOD_PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+                       "000000000000000000000000000000010000000200000000fffffffe000000000000000200000000fffffffe00000001"));
+               const NEGATIVE_PRIME_INV_MOD_R: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
+                       "00000014000000140000000c00000002fffffffcfffffffafffffffbfffffffe00000000000000010000000100000001"));
+       }
+}
+
 #[cfg(fuzzing)]
 extern crate ibig;
 #[cfg(fuzzing)]
@@ -732,7 +1433,7 @@ pub fn fuzz_math(input: &[u8]) {
                b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
        }
 
-       macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident) => {
+       macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
                let res = $mul(&a_u64s, &b_u64s);
                let mut res_bytes = Vec::with_capacity(input.len() / 2);
                for i in res {
@@ -784,14 +1485,72 @@ pub fn fuzz_math(input: &[u8]) {
                let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
                assert_eq!(ibig::UBig::from_be_bytes(&quot_bytes), quoti);
                assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
+
+               if ai != ibig::UBig::from(0u32) { // ibig provides a spurious modular inverse for 0
+                       let ring = ibig::modular::ModuloRing::new(&bi);
+                       let ar = ring.from(ai.clone());
+                       let invi = ar.inverse().map(|i| i.residue());
+
+                       if let Ok(modinv) = $mod_inv(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
+                               let mut modinv_bytes = Vec::with_capacity(input.len() / 2);
+                               for i in modinv {
+                                       modinv_bytes.extend_from_slice(&i.to_be_bytes());
+                               }
+                               assert_eq!(invi.unwrap(), ibig::UBig::from_be_bytes(&modinv_bytes));
+                       } else {
+                               assert!(invi.is_none());
+                       }
+               }
+       } }
+
+       macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
+               // Test the U256/U384Mod wrapper, which operates in Montgomery representation
+               let mut p_extended = [0; $len * 2];
+               p_extended[$len..].copy_from_slice(&$PRIME);
+
+               let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
+               assert_eq!(&amodp_squared[..$len], &[0; $len]);
+               assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
+
+               let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
+               assert_eq!(&abmodp[..$len], &[0; $len]);
+               assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
+
+               let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
+               let mut aplusb_extended = [0; $len * 2];
+               aplusb_extended[$len..].copy_from_slice(&aplusb);
+               if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
+               let aplusbmodp = $div_rem_double(&aplusb_extended, &p_extended).unwrap().1;
+               assert_eq!(&aplusbmodp[..$len], &[0; $len]);
+               assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
+
+               let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
+               if aminusb_underflow {
+                       let mut overflow;
+                       (aminusb, overflow) = $add(&aminusb, &$PRIME);
+                       if !overflow {
+                               (aminusb, overflow) = $add(&aminusb, &$PRIME);
+                       }
+                       assert!(overflow);
+               }
+               let aminusbmodp = $div_rem(&aminusb, &$PRIME).unwrap().1;
+               assert_eq!(&$amodp.sub(&$bmodp).$into().0, &aminusbmodp);
        } }
 
        if a_u64s.len() == 2 {
-               test!(mul_2, sqr_2, add_2, sub_2, div_rem_2);
+               test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
        } else if a_u64s.len() == 4 {
-               test!(mul_4, sqr_4, add_4, sub_4, div_rem_4);
+               test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
+               let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
+               let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
+               test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
+       } else if a_u64s.len() == 6 {
+               test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
+               let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
+               let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
+               test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
        } else if a_u64s.len() == 8 {
-               test!(mul_8, sqr_8, add_8, sub_8, div_rem_8);
+               test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
        } else if input.len() == 512*2 + 4 {
                let mut e_bytes = [0; 4];
                e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);