Add some comments about mont reduction to make it a bit clearer
[dnssec-prover] / src / crypto / bigint.rs
index eec0a56f8d380745e61b5e2486fe9b608859674c..ff478eec6fe65eed692575b733b41fbca025cca5 100644 (file)
@@ -836,6 +836,11 @@ impl U4096 {
                                &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
                        // Do a montgomery reduction of `mu`
 
+                       // The definition of REDC (with some names changed):
+                       // v = ((mu % R) * N') mod R
+                       // t = (mu + v*N) / R
+                       // if t >= N { t - N } else { t }
+
                        // mu % R is just the bottom word_count bytes of mu
                        let mut mu_mod_r = [0; WORD_COUNT_4096];
                        mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
@@ -902,16 +907,21 @@ impl U4096 {
                        *r2_limb &= !(0xffff_ffff_ffff_ffffu64 << (64 - clear_bits));
                        if *m_limb != 0 { break; }
                }
+
                debug_assert!(r2_mod_m < m.0);
+               #[cfg(debug_assertions)] {
+                       debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
+               }
 
                // Finally, actually do the exponentiation...
 
                // Calculate t * R and a * R as mont multiplications by R^2 mod m
+               // (i.e. t * R^2 / R and 1 * R^2 / R)
                let mut tr = mont_reduction(mul(&r2_mod_m, &t));
                let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
 
                #[cfg(debug_assertions)] {
-                       debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
+                       // Check that tr/ar match naive multiplication
                        debug_assert_eq!(&tr, &U4096(t).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
                        debug_assert_eq!(&ar, &self.mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
                }
@@ -1001,6 +1011,11 @@ impl<M: PrimeModulus<U256>> U256Mod<M> {
                        assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
                }
 
+               // The definition of REDC (with some names changed):
+               // v = ((mu % R) * N') mod R
+               // t = (mu + v*N) / R
+               // if t >= N { t - N } else { t }
+
                // mu % R is just the bottom 4 bytes of mu
                let mu_mod_r = const_subslice(&mu, 4, 8);
                // v = ((mu % R) * negative_modulus_inverse) % R
@@ -1200,6 +1215,11 @@ impl<M: PrimeModulus<U384>> U384Mod<M> {
                        assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
                }
 
+               // The definition of REDC (with some names changed):
+               // v = ((mu % R) * N') mod R
+               // t = (mu + v*N) / R
+               // if t >= N { t - N } else { t }
+
                // mu % R is just the bottom 4 bytes of mu
                let mu_mod_r = const_subslice(&mu, 6, 12);
                // v = ((mu % R) * negative_modulus_inverse) % R