From 6109d8350163516da0947f70dd4efbe02414a1b5 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 7 May 2024 20:26:31 +0000 Subject: [PATCH] Add some comments about mont reduction to make it a bit clearer --- src/crypto/bigint.rs | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/crypto/bigint.rs b/src/crypto/bigint.rs index eec0a56..ff478ee 100644 --- a/src/crypto/bigint.rs +++ b/src/crypto/bigint.rs @@ -836,6 +836,11 @@ impl U4096 { &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]); // Do a montgomery reduction of `mu` + // The definition of REDC (with some names changed): + // v = ((mu % R) * N') mod R + // t = (mu + v*N) / R + // if t >= N { t - N } else { t } + // mu % R is just the bottom word_count bytes of mu let mut mu_mod_r = [0; WORD_COUNT_4096]; mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]); @@ -902,16 +907,21 @@ impl U4096 { *r2_limb &= !(0xffff_ffff_ffff_ffffu64 << (64 - clear_bits)); if *m_limb != 0 { break; } } + debug_assert!(r2_mod_m < m.0); + #[cfg(debug_assertions)] { + debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0); + } // Finally, actually do the exponentiation... // Calculate t * R and a * R as mont multiplications by R^2 mod m + // (i.e. t * R^2 / R and 1 * R^2 / R) let mut tr = mont_reduction(mul(&r2_mod_m, &t)); let mut ar = mont_reduction(mul(&r2_mod_m, &self.0)); #[cfg(debug_assertions)] { - debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0); + // Check that tr/ar match naive multiplication debug_assert_eq!(&tr, &U4096(t).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0); debug_assert_eq!(&ar, &self.mulmod_naive(&U4096(r_mod_m), &m).unwrap().0); } @@ -1001,6 +1011,11 @@ impl> U256Mod { assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0)); } + // The definition of REDC (with some names changed): + // v = ((mu % R) * N') mod R + // t = (mu + v*N) / R + // if t >= N { t - N } else { t } + // mu % R is just the bottom 4 bytes of mu let mu_mod_r = const_subslice(&mu, 4, 8); // v = ((mu % R) * negative_modulus_inverse) % R @@ -1200,6 +1215,11 @@ impl> U384Mod { assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0)); } + // The definition of REDC (with some names changed): + // v = ((mu % R) * N') mod R + // t = (mu + v*N) / R + // if t >= N { t - N } else { t } + // mu % R is just the bottom 4 bytes of mu let mu_mod_r = const_subslice(&mu, 6, 12); // v = ((mu % R) * negative_modulus_inverse) % R -- 2.39.5