Clean up and better comment math somewhat further
authorMatt Corallo <git@bluematt.me>
Fri, 3 May 2024 18:41:46 +0000 (18:41 +0000)
committerMatt Corallo <git@bluematt.me>
Fri, 3 May 2024 18:48:43 +0000 (18:48 +0000)
src/crypto/bigint.rs

index d07bb5622789c1ad45085516af5f1d9cfd8f1914..7c187ca272284d5df13d60ad0e34db14a37efc39 100644 (file)
@@ -3,51 +3,9 @@
 use alloc::vec::Vec;
 use core::marker::PhantomData;
 
-const WORD_COUNT_4096: usize = 4096 / 64;
-const WORD_COUNT_256: usize = 256 / 64;
-const WORD_COUNT_384: usize = 384 / 64;
-
-// RFC 5702 indicates RSA keys can be up to 4096 bits
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub(super) struct U4096([u64; WORD_COUNT_4096]);
-
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub(super) struct U256([u64; WORD_COUNT_256]);
-
-#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
-pub(super) struct U384([u64; WORD_COUNT_384]);
-
-pub(super) trait Int: Clone + Ord + Sized {
-       const ZERO: Self;
-       const BYTES: usize;
-       fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
-       fn limbs(&self) -> &[u64];
-}
-impl Int for U256 {
-       const ZERO: U256 = U256([0; 4]);
-       const BYTES: usize = 32;
-       fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
-       fn limbs(&self) -> &[u64] { &self.0 }
-}
-impl Int for U384 {
-       const ZERO: U384 = U384([0; 6]);
-       const BYTES: usize = 48;
-       fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
-       fn limbs(&self) -> &[u64] { &self.0 }
-}
-
-/// Defines a *PRIME* Modulus
-pub(super) trait PrimeModulus<I: Int> {
-       const PRIME: I;
-       const R_SQUARED_MOD_PRIME: I;
-       const NEGATIVE_PRIME_INV_MOD_R: I;
-}
-
-#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
-pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
-
-#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
-pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
+// **************************************
+// * Implementations of math primitives *
+// **************************************
 
 macro_rules! debug_unwrap { ($v: expr) => { {
        let v = $v;
@@ -522,6 +480,11 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init
        $($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> {
                if slice_equal(b, &[0; $len]) { return Err(()); }
 
+               // Very naively divide `a` by `b` by calculating all the powers of two times `b` up to `a`,
+               // then subtracting the powers of two in decreasing order. What's left is the remainder.
+               //
+               // This requires storing all the multiples of `b` in `pow2s`, which may be a vec or an
+               // array. `$pre_push!()` sets up the next element with zeros and then we can overwrite it.
                let mut b_pow = *b;
                let mut pow2s = $heap_init;
                let mut pow2s_count = 0;
@@ -538,10 +501,10 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init
                while pow2 >= 0 {
                        let b_pow = pow2s[pow2 as usize];
                        let overflow = double!(quot);
-                       debug_assert!(!overflow);
+                       debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
                        if slice_greater_than(&rem, &b_pow) {
-                               let (r, carry) = $sub(&rem, &b_pow);
-                               debug_assert!(!carry);
+                               let (r, underflow) = $sub(&rem, &b_pow);
+                               debug_assert!(!underflow, "rem was just checked to be > b_pow, so sub cannot underflow");
                                rem = r;
                                quot[$len - 1] |= 1;
                        }
@@ -549,7 +512,7 @@ macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init
                }
                if slice_equal(&rem, b) {
                        let overflow = add_u64!(quot, 1);
-                       debug_assert!(!overflow);
+                       debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
                        Ok((quot, [0; $len]))
                } else {
                        Ok((quot, rem))
@@ -638,6 +601,56 @@ define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
 #[cfg(fuzzing)]
 define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
 
+// ******************
+// * The public API *
+// ******************
+
+const WORD_COUNT_4096: usize = 4096 / 64;
+const WORD_COUNT_256: usize = 256 / 64;
+const WORD_COUNT_384: usize = 384 / 64;
+
+// RFC 5702 indicates RSA keys can be up to 4096 bits, so we always use 4096-bit integers
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U4096([u64; WORD_COUNT_4096]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U256([u64; WORD_COUNT_256]);
+
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub(super) struct U384([u64; WORD_COUNT_384]);
+
+pub(super) trait Int: Clone + Ord + Sized {
+       const ZERO: Self;
+       const BYTES: usize;
+       fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
+       fn limbs(&self) -> &[u64];
+}
+impl Int for U256 {
+       const ZERO: U256 = U256([0; 4]);
+       const BYTES: usize = 32;
+       fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+       fn limbs(&self) -> &[u64] { &self.0 }
+}
+impl Int for U384 {
+       const ZERO: U384 = U384([0; 6]);
+       const BYTES: usize = 48;
+       fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
+       fn limbs(&self) -> &[u64] { &self.0 }
+}
+
+/// Defines a *PRIME* Modulus
+pub(super) trait PrimeModulus<I: Int> {
+       const PRIME: I;
+       const R_SQUARED_MOD_PRIME: I;
+       const NEGATIVE_PRIME_INV_MOD_R: I;
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
+pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
+
+#[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
+pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
+
 impl U4096 {
        /// Constructs a new [`U4096`] from a variable number of big-endian bytes.
        pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
@@ -694,6 +707,10 @@ impl U4096 {
                // Because m is not even, using 2^4096 as the Montgomery R value is always safe - it is
                // guaranteed to be co-prime with any non-even integer.
 
+               // We use a single 4096-bit integer type for all our RSA operations, though in most cases
+               // we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math
+               // here which debug_assert's the required bits are 0s and then uses faster math primitives.
+
                type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
                type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
                type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
@@ -785,6 +802,7 @@ impl U4096 {
                                (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
                        };
 
+               // r is always the even value with one bit set above the word count we're using.
                let mut r = [0; WORD_COUNT_4096 * 2];
                r[WORD_COUNT_4096 * 2 - word_count - 1] = 1;
 
@@ -800,7 +818,7 @@ impl U4096 {
                }
                m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0);
 
-               // We want the negative modular inverse of m mod R, so subtract m_inv from R.
+               // `m_inv` is the negative modular inverse of m mod R, so subtract m_inv from R.
                let mut m_inv = m_inv_pos;
                negate!(m_inv);
                m_inv[..WORD_COUNT_4096 - word_count].fill(0);
@@ -808,25 +826,38 @@ impl U4096 {
                        // R - 1 == -1 % R
                        &[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]);
 
-               debug_assert_eq!(&m_inv[..WORD_COUNT_4096 - word_count], &[0; WORD_COUNT_4096][..WORD_COUNT_4096 - word_count]);
-
                let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] {
                        debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2],
                                &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
+                       // Do a montgomery reduction of `mu`
+
+                       // mu % R is just the bottom word_count bytes of mu
                        let mut mu_mod_r = [0; WORD_COUNT_4096];
                        mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
+
+                       // v = ((mu % R) * negative_modulus_inverse) % R
                        let mut v = mul(&mu_mod_r, &m_inv);
                        v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
+
+                       // t_on_r = (mu + v*modulus) / R
                        let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
                        let (t1, t1_extra_bit) = add_double(&t0, &mu);
+
+                       // Note that dividing t1 by R is simply a matter of shifting right by word_count bytes
+                       // We only need to maintain word_count bytes (plus `t1_extra_bit` which is implicitly
+                       // an extra bit) because t_on_r is guarantee to be, at max, 2*m - 1.
                        let mut t1_on_r = [0; WORD_COUNT_4096];
                        debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..],
                                "t1 should be divisible by r");
                        t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]);
+
+                       // The modulus has only word_count bytes, so if t1_extra_bit is set we are definitely
+                       // larger than the modulus.
                        if t1_extra_bit || t1_on_r >= m.0 {
                                let underflow;
                                (t1_on_r, underflow) = sub(&t1_on_r, &m.0);
-                               debug_assert_eq!(t1_extra_bit, underflow);
+                               debug_assert_eq!(t1_extra_bit, underflow,
+                                       "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
                        }
                        t1_on_r
                };
@@ -868,6 +899,8 @@ impl U4096 {
                }
                debug_assert!(r2_mod_m < m.0);
 
+               // Finally, actually do the exponentiation...
+
                // Calculate t * R and a * R as mont multiplications by R^2 mod m
                let mut tr = mont_reduction(mul(&r2_mod_m, &t));
                let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
@@ -893,115 +926,11 @@ impl U4096 {
        }
 }
 
-const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 {
-       match b {
-               [a, b, c, d, e, f, g, h, ..] => {
-                       ((*a as u64) << 8*7) |
-                       ((*b as u64) << 8*6) |
-                       ((*c as u64) << 8*5) |
-                       ((*d as u64) << 8*4) |
-                       ((*e as u64) << 8*3) |
-                       ((*f as u64) << 8*2) |
-                       ((*g as u64) << 8*1) |
-                       ((*h as u64) << 8*0)
-               },
-               _ => panic!(),
-       }
-}
-
-const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 {
-       match b {
-               [_, _, _, _, _, _, _, _,
-                a, b, c, d, e, f, g, h, ..] => {
-                       ((*a as u64) << 8*7) |
-                       ((*b as u64) << 8*6) |
-                       ((*c as u64) << 8*5) |
-                       ((*d as u64) << 8*4) |
-                       ((*e as u64) << 8*3) |
-                       ((*f as u64) << 8*2) |
-                       ((*g as u64) << 8*1) |
-                       ((*h as u64) << 8*0)
-               },
-               _ => panic!(),
-       }
-}
-
-const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 {
-       match b {
-               [_, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                a, b, c, d, e, f, g, h, ..] => {
-                       ((*a as u64) << 8*7) |
-                       ((*b as u64) << 8*6) |
-                       ((*c as u64) << 8*5) |
-                       ((*d as u64) << 8*4) |
-                       ((*e as u64) << 8*3) |
-                       ((*f as u64) << 8*2) |
-                       ((*g as u64) << 8*1) |
-                       ((*h as u64) << 8*0)
-               },
-               _ => panic!(),
-       }
-}
-
-const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 {
-       match b {
-               [_, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                a, b, c, d, e, f, g, h, ..] => {
-                       ((*a as u64) << 8*7) |
-                       ((*b as u64) << 8*6) |
-                       ((*c as u64) << 8*5) |
-                       ((*d as u64) << 8*4) |
-                       ((*e as u64) << 8*3) |
-                       ((*f as u64) << 8*2) |
-                       ((*g as u64) << 8*1) |
-                       ((*h as u64) << 8*0)
-               },
-               _ => panic!(),
-       }
-}
-
-const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 {
-       match b {
-               [_, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                a, b, c, d, e, f, g, h, ..] => {
-                       ((*a as u64) << 8*7) |
-                       ((*b as u64) << 8*6) |
-                       ((*c as u64) << 8*5) |
-                       ((*d as u64) << 8*4) |
-                       ((*e as u64) << 8*3) |
-                       ((*f as u64) << 8*2) |
-                       ((*g as u64) << 8*1) |
-                       ((*h as u64) << 8*0)
-               },
-               _ => panic!(),
-       }
-}
-
-const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 {
-       match b {
-               [_, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                _, _, _, _, _, _, _, _,
-                a, b, c, d, e, f, g, h, ..] => {
-                       ((*a as u64) << 8*7) |
-                       ((*b as u64) << 8*6) |
-                       ((*c as u64) << 8*5) |
-                       ((*d as u64) << 8*4) |
-                       ((*e as u64) << 8*3) |
-                       ((*f as u64) << 8*2) |
-                       ((*g as u64) << 8*1) |
-                       ((*h as u64) << 8*0)
-               },
-               _ => panic!(),
-       }
+// In a const context we can't subslice a slice, so instead we pick the eight bytes we want and
+// pass them here to build u64s from arrays.
+const fn eight_bytes_to_u64_be(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 {
+       let b = [a, b, c, d, e, f, g, h];
+       u64::from_be_bytes(b)
 }
 
 impl U256 {
@@ -1024,10 +953,14 @@ impl U256 {
        /// Constructs a new [`U256`] from a fixed number of big-endian bytes.
        pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
                let res = [
-                       u64_from_bytes_a_panicking(bytes),
-                       u64_from_bytes_b_panicking(bytes),
-                       u64_from_bytes_c_panicking(bytes),
-                       u64_from_bytes_d_panicking(bytes),
+                       eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3],
+                                             bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3],
+                                             bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3],
+                                             bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3],
+                                             bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]),
                ];
                U256(res)
        }
@@ -1037,6 +970,7 @@ impl U256 {
        pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
 }
 
+// Values modulus M::PRIME.0, stored in montgomery form.
 impl<M: PrimeModulus<U256>> U256Mod<M> {
        const fn mont_reduction(mu: [u64; 8]) -> Self {
                #[cfg(debug_assertions)] {
@@ -1061,18 +995,30 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                        assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
                }
 
+               // mu % R is just the bottom 4 bytes of mu
                let mu_mod_r = const_subslice(&mu, 4, 8);
+               // v = ((mu % R) * negative_modulus_inverse) % R
                let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
                const ZEROS: &[u64; 4] = &[0; 4];
                copy_from_slice!(v, 0, 4, ZEROS); // mod R
+
+               // t_on_r = (mu + v*modulus) / R
                let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
                let (t1, t1_extra_bit) = add_8(&t0, &mu);
+
+               // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
+               // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
+               // because t_on_r is guarantee to be, at max, 2*m - 1.
                let t1_on_r = const_subslice(&t1, 0, 4);
+
                let mut res = [0; 4];
+               // The modulus is only 4 bytes, so t1_extra_bit implies we're definitely larger than the
+               // modulus.
                if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
                        let underflow;
                        (res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
-                       debug_assert!(t1_extra_bit == underflow);
+                       debug_assert_eq!(t1_extra_bit, underflow,
+                               "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
                } else {
                        copy_from_slice!(res, 0, 4, t1_on_r);
                }
@@ -1099,7 +1045,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
                while v >= M::PRIME {
                        let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
-                       debug_assert!(!spurious_underflow);
+                       debug_assert!(!spurious_underflow, "v was > M::PRIME.0");
                        v = U256(new_v);
                }
                Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
@@ -1181,6 +1127,7 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
        }
 }
 
+// Values modulus M::PRIME.0, stored in montgomery form.
 impl U384 {
        /// Constructs a new [`U384`] from a variable number of big-endian bytes.
        pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
@@ -1201,12 +1148,18 @@ impl U384 {
        /// Constructs a new [`U384`] from a fixed number of big-endian bytes.
        pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
                let res = [
-                       u64_from_bytes_a_panicking(bytes),
-                       u64_from_bytes_b_panicking(bytes),
-                       u64_from_bytes_c_panicking(bytes),
-                       u64_from_bytes_d_panicking(bytes),
-                       u64_from_bytes_e_panicking(bytes),
-                       u64_from_bytes_f_panicking(bytes),
+                       eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3],
+                                             bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3],
+                                             bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3],
+                                             bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3],
+                                             bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[4*8 + 0], bytes[4*8 + 1], bytes[4*8 + 2], bytes[4*8 + 3],
+                                             bytes[4*8 + 4], bytes[4*8 + 5], bytes[4*8 + 6], bytes[4*8 + 7]),
+                       eight_bytes_to_u64_be(bytes[5*8 + 0], bytes[5*8 + 1], bytes[5*8 + 2], bytes[5*8 + 3],
+                                             bytes[5*8 + 4], bytes[5*8 + 5], bytes[5*8 + 6], bytes[5*8 + 7]),
                ];
                U384(res)
        }
@@ -1240,14 +1193,25 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                        assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
                }
 
+               // mu % R is just the bottom 4 bytes of mu
                let mu_mod_r = const_subslice(&mu, 6, 12);
+               // v = ((mu % R) * negative_modulus_inverse) % R
                let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
                const ZEROS: &[u64; 6] = &[0; 6];
                copy_from_slice!(v, 0, 6, ZEROS); // mod R
+
+               // t_on_r = (mu + v*modulus) / R
                let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
                let (t1, t1_extra_bit) = add_12(&t0, &mu);
+
+               // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
+               // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
+               // because t_on_r is guarantee to be, at max, 2*m - 1.
                let t1_on_r = const_subslice(&t1, 0, 6);
+
                let mut res = [0; 6];
+               // The modulus is only 4 bytes, so t1_extra_bit implies we're definitely larger than the
+               // modulus.
                if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
                        let underflow;
                        (res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);