1 //! Simple variable-time big integer implementation
/// Number of 64-bit limbs needed to store a 4096-bit integer.
const WORD_COUNT_4096: usize = 4096 / u64::BITS as usize;
7 // RFC 5702 indicates RSA keys can be up to 4096 bits
/// Fixed-width 4096-bit unsigned integer stored as `WORD_COUNT_4096` 64-bit limbs in big-endian
/// limb order: index 0 is the most significant limb (`from_be_bytes` fills the array from the
/// end). With that ordering the derived lexicographic `PartialOrd`/`Ord` agrees with numeric
/// order, which comparisons such as `self > m` in this file rely on.
8 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
9 pub(super) struct U4096([u64; WORD_COUNT_4096]);
/// Unwraps the `Result` expression `$v`. Debug builds assert that it is `Ok`; an `Err` is
/// propagated to the caller with `return Err(e)`. The `Ok` arm is not visible in this listing —
/// presumably it evaluates to the inner value; confirm.
11 macro_rules! debug_unwrap { ($v: expr) => { {
13 debug_assert!(v.is_ok());
16 Err(e) => return Err(e),
20 // Various const versions of existing slice utilities
/// Const-evaluable equivalent of `&a[start..end]`.
///
/// Panics under the same conditions as slice range indexing: `start` or `end` past `a.len()`,
/// or `end < start`.
const fn const_subslice<'a, T>(a: &'a [T], start: usize, end: usize) -> &'a [T] {
    assert!(start <= a.len());
    assert!(end <= a.len());
    assert!(end >= start);
    // SAFETY: `start <= a.len()`, so the offset pointer stays inside `a` or lands exactly one
    // past its last element, both of which `add` permits.
    let first: *const T = unsafe { a.as_ptr().add(start) };
    // SAFETY: `start <= end <= a.len()`, so the `end - start` elements beginning at `first` all
    // lie inside `a` and live for `'a`. For an empty range `first` may point one past the end,
    // which is still non-null and aligned -- all that a zero-length slice requires.
    unsafe { core::slice::from_raw_parts(first, end - start) }
}
35 /// Const version of `dest[dest_start..dest_end].copy_from_slice(source)`
37 /// Once `const_mut_refs` is stable we can convert this to a function
///
/// Panics, like the slice method, if `dest_start..dest_end` is out of bounds for `$dest`, is
/// reversed, or differs in length from `$source`.
38 macro_rules! copy_from_slice {
39 ($dest: ident, $dest_start: expr, $dest_end: expr, $source: ident) => { {
// Evaluate each bound expression exactly once.
40 let dest_start = $dest_start;
41 let dest_end = $dest_end;
42 assert!(dest_start <= $dest.len());
43 assert!(dest_end <= $dest.len());
44 assert!(dest_end >= dest_start);
45 assert!(dest_end - dest_start == $source.len());
// Element-by-element `while` copy: iterator-based `for` loops are not allowed in `const fn`.
47 while i < $source.len() {
48 $dest[i + dest_start] = $source[i];
54 /// Const version of a > b
///
/// `a` and `b` are expected to have equal length (checked only by `debug_assert!`). The result
/// is decided at the first index where the limbs differ — presumably scanning from index 0, the
/// most significant limb; the loop header is not visible in this listing.
55 const fn slice_greater_than(a: &[u64], b: &[u64]) -> bool {
56 debug_assert!(a.len() == b.len());
// Hand-rolled `min`, since `Ord::min` cannot be called in a `const fn`; it also keeps release
// builds in bounds if the lengths ever differ.
57 let len = if a.len() <= b.len() { a.len() } else { b.len() };
60 if a[i] > b[i] { return true; }
61 else if a[i] < b[i] { return false; }
67 /// Const version of a == b
///
/// `a` and `b` are expected to have equal length (checked only by `debug_assert!`); returns
/// `false` at the first differing limb.
68 const fn slice_equal(a: &[u64], b: &[u64]) -> bool {
69 debug_assert!(a.len() == b.len());
// Hand-rolled `min` (`Ord::min` cannot be called in a `const fn`). If the lengths differ in a
// release build, only the shared prefix is compared — assuming the (not visible) loop runs over
// `0..len`.
70 let len = if a.len() <= b.len() { a.len() } else { b.len() };
73 if a[i] != b[i] { return false; }
79 /// Adds one in-place, returning an overflow flag, in which case one out-of-bounds high bit is
80 /// implicitly included in the result.
82 /// Once `const_mut_refs` is stable we can convert this to a function
83 macro_rules! add_one { ($a: ident) => { {
// Start from the last (least significant) limb and move toward index 0.
88 let (v, carry) = $a[len - 1 - i].overflowing_add(1);
// Once a limb absorbs the +1 without wrapping, nothing carries further and there is no
// overflow.
90 if !carry { res = false; break; }
96 /// Negates the given u64 slice.
///
/// This is two's-complement negation modulo 2^(64 * len): the limbs are bit-flipped and one is
/// then added.
98 /// Once `const_mut_refs` is stable we can convert this to a function
99 macro_rules! negate { ($v: ident) => { {
102 $v[i] ^= 0xffff_ffff_ffff_ffff;
// `add_one!` only overflows if every limb is all-ones after the flip, i.e. the input was zero,
// so this debug-asserts a non-zero input.
105 let overflow = add_one!($v);
106 debug_assert!(!overflow);
109 /// Doubles in-place, returning an overflow flag, in which case one out-of-bounds high bit is
110 /// implicitly included in the result.
112 /// Once `const_mut_refs` is stable we can convert this to a function
113 macro_rules! double { ($a: ident) => { {
114 { let _: &[u64] = &$a; } // Force type resolution
116 let mut carry = false;
// The top bit shifted out of this limb becomes the carry into the next, more significant, limb.
119 let mut next_carry = ($a[len - 1 - i] & (1 << 63)) != 0;
120 let (v, next_carry_2) = ($a[len - 1 - i] << 1).overflowing_add(carry as u64);
// `x << 1` is even, so adding the incoming 0/1 carry can never wrap: `next_carry_2` is expected
// to always be false.
122 debug_assert!(!next_carry || !next_carry_2);
123 next_carry |= next_carry_2;
/// Defines `const fn $name`, a $len-limb ripple-carry adder over big-endian limb slices.
130 macro_rules! define_add { ($name: ident, $len: expr) => {
131 /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
132 /// bit, with the same semantics as the std [`u64::overflowing_add`] method.
133 const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
134 debug_assert!(a.len() == $len);
135 debug_assert!(b.len() == $len);
136 let mut r = [0; $len];
137 let mut carry = false;
// Walk from the least significant (last) limb toward index 0.
140 let pos = $len - 1 - i;
141 let (v, mut new_carry) = a[pos].overflowing_add(b[pos]);
142 let (v2, new_new_carry) = v.overflowing_add(carry as u64);
// If `a + b` wrapped, `v` is at most 2^64 - 2, so adding the carry cannot wrap again; at most
// one flag is set and OR-ing them gives this limb's carry out.
143 new_carry |= new_new_carry;
// Widths used by the multiply/square helpers and the Montgomery code below.
152 define_add!(add_2, 2);
153 define_add!(add_4, 4);
154 define_add!(add_8, 8);
155 define_add!(add_16, 16);
156 define_add!(add_32, 32);
157 define_add!(add_64, 64);
158 define_add!(add_128, 128);
/// Defines `const fn $name`, a $len-limb ripple-borrow subtractor over big-endian limb slices.
160 macro_rules! define_sub { ($name: ident, $len: expr) => {
161 /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
162 /// $len-64-bit integer and an overflow bit, with the same semantics as the std
163 /// [`u64::overflowing_sub`] method.
164 const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
165 debug_assert!(a.len() == $len);
166 debug_assert!(b.len() == $len);
167 let mut r = [0; $len];
168 let mut carry = false;
// Walk from the least significant (last) limb toward index 0.
171 let pos = $len - 1 - i;
172 let (v, mut new_carry) = a[pos].overflowing_sub(b[pos]);
173 let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
// If `a - b` wrapped, `v` is at least 1, so subtracting the borrow cannot wrap again; at most
// one flag is set and OR-ing them gives this limb's borrow out.
174 new_carry |= new_new_carry;
183 define_sub!(sub_2, 2);
184 define_sub!(sub_4, 4);
185 define_sub!(sub_8, 8);
186 define_sub!(sub_16, 16);
187 define_sub!(sub_32, 32);
188 define_sub!(sub_64, 64);
// `sub_128` is only needed by the debug-only `div_rem_128`.
189 #[cfg(debug_assertions)]
190 define_sub!(sub_128, 128);
192 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
194 /// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int
195 /// types to do multiplication (potentially) natively.
///
/// Limbs are big-endian: `a[0]`/`b[0]` hold the high 64 bits, and index 0 of the result is its
/// most significant limb.
196 const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
197 debug_assert!(a.len() == 2);
198 debug_assert!(b.len() == 2);
200 // Gradeschool multiplication is way faster here.
201 let (a0, a1) = (a[0] as u128, a[1] as u128);
202 let (b0, b1) = (b[0] as u128, b[1] as u128);
// NOTE(review): the definitions of `z0`, `z2`, `z1i` and `z1j` are not visible in this listing;
// presumably `z1i`/`z1j` are the two cross products, whose sum may carry out (`i_carry`).
206 let (z1, i_carry) = z1i.overflowing_add(z1j);
// Split each 128-bit partial product into its high (`*a`) and low (`*b`) 64-bit halves.
209 let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
210 let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
211 let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
212 let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
213 let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
214 let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
// Sum the overlapping halves limb by limb, tracking carries into the next limb up.
217 let (k, j_carry) = z0a.overflowing_add(z1b);
218 let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
221 (j, new_i_carry) = j.overflowing_add(j_carry as u64);
222 debug_assert!(!second_i_carry || !new_i_carry);
223 second_i_carry |= new_i_carry;
// A 128x128-bit product always fits in 256 bits, so the top limb can never overflow.
226 let mut spurious_overflow;
227 (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
228 debug_assert!(!spurious_overflow);
229 (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
230 debug_assert!(!spurious_overflow);
/// Defines `const fn $name`, a Karatsuba multiply of two $len-limb integers built on the
/// half-width `$submul`/`$subadd`/`$subsub` and full-width `$add`/`$sub` helpers.
235 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
236 /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
237 const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
238 // We could probably get a bit faster doing gradeschool multiplication for some smaller
239 // sizes, but it's easier to just have one variable-length multiplication, so we do
240 // Karatsuba always here.
241 debug_assert!(a.len() == $len);
242 debug_assert!(b.len() == $len);
// Split into high (`*0`) and low (`*1`) halves; limbs are big-endian.
244 let a0 = const_subslice(a, 0, $len / 2);
245 let a1 = const_subslice(a, $len / 2, $len);
246 let b0 = const_subslice(b, 0, $len / 2);
247 let b1 = const_subslice(b, $len / 2, $len);
// z2 = high * high, z0 = low * low.
249 let z2 = $submul(a0, b0);
250 let z0 = $submul(a1, b1);
// Middle term: a0*b1 + a1*b0 = z0 + z2 - (a1 - a0)(b1 - b0). Take |a1 - a0| and |b1 - b0| with
// their signs so only unsigned arithmetic is needed; when the signs match the product is
// non-negative and is subtracted, otherwise its magnitude is added.
252 let (z1a_max, z1a_min, z1a_sign) =
253 if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) };
254 let (z1b_max, z1b_min, z1b_sign) =
255 if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
257 let z1a = $subsub(z1a_max, z1a_min);
258 debug_assert!(!z1a.1);
259 let z1b = $subsub(z1b_max, z1b_min);
260 debug_assert!(!z1b.1);
261 let z1m_sign = z1a_sign == z1b_sign;
263 let z1m = $submul(&z1a.0, &z1b.0);
264 let z1n = $add(&z0, &z2);
265 let mut z1_carry = z1n.1;
266 let z1 = if z1m_sign {
267 let r = $sub(&z1n.0, &z1m);
// A borrow here offsets the carry out of `z0 + z2`; the true middle term is non-negative.
268 if r.1 { z1_carry ^= true; }
271 let r = $add(&z1n.0, &z1m);
272 if r.1 { z1_carry = true; }
// Assemble the result from half-width pieces i, j, k, l (most significant first): l is the low
// half of z0, k and j add in z1 at an offset of $len / 2 limbs, and i is the high half of z2
// plus the accumulated carries.
276 let l = const_subslice(&z0, $len / 2, $len);
277 let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
278 let (mut j, mut i_carry) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
280 let new_i_carry = add_one!(j);
281 debug_assert!(!i_carry || !new_i_carry);
282 i_carry |= new_i_carry;
284 let mut i = [0; $len / 2];
285 let i_source = const_subslice(&z2, 0, $len / 2);
286 copy_from_slice!(i, 0, $len / 2, i_source);
288 let spurious_carry = add_one!(i);
289 debug_assert!(!spurious_carry);
292 let spurious_carry = add_one!(i);
293 debug_assert!(!spurious_carry);
296 let mut res = [0; $len * 2];
297 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
298 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
299 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
300 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
// Each width recurses on the next smaller one, bottoming out in `mul_2`.
305 define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
306 define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
307 define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
308 define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
309 define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
312 /// Squares a 128-bit integer, returning a new 256-bit integer.
314 /// This is the base case for our squaring, taking advantage of Rust's native 128-bit int
315 /// types to do multiplication (potentially) natively.
316 const fn sqr_2(a: &[u64]) -> [u64; 4] {
317 debug_assert!(a.len() == 2);
319 let (a0, a1) = (a[0] as u128, a[1] as u128);
// The cross term is 2 * a0 * a1, so only one cross product is computed. Record the bit that
// doubling will shift out of the 128-bit value (the doubling itself is on a line not visible
// in this listing).
321 let mut z1 = a0 * a1;
322 let i_carry = z1 & (1u128 << 127) != 0;
// Split each 128-bit partial product into its high (`*a`) and low (`*b`) 64-bit halves.
326 let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
327 let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
328 let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
329 let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
330 let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
331 let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
// Sum the overlapping halves limb by limb, tracking carries into the next limb up.
334 let (k, j_carry) = z0a.overflowing_add(z1b);
335 let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
338 (j, new_i_carry) = j.overflowing_add(j_carry as u64);
339 debug_assert!(!second_i_carry || !new_i_carry);
340 second_i_carry |= new_i_carry;
// The square of a 128-bit value always fits in 256 bits, so the top limb cannot overflow.
343 let mut spurious_overflow;
344 (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
345 debug_assert!(!spurious_overflow);
346 (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
347 debug_assert!(!spurious_overflow);
/// Defines `const fn $name`, which squares a $len-limb integer from its halves:
/// (hi * B + lo)^2 = hi^2 * B^2 + 2 * hi * lo * B + lo^2, with B = 2^(32 * $len).
352 macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
353 /// Squares a $len-64-bit integer, returning a new $len*2-64-bit integer.
354 const fn $name(a: &[u64]) -> [u64; $len * 2] {
355 debug_assert!(a.len() == $len);
357 let hi = const_subslice(a, 0, $len / 2);
358 let lo = const_subslice(a, $len / 2, $len);
// v0 = lo^2, v1 = 2 * hi * lo (the doubling's carry-out is kept in `i_carry`), v2 = hi^2.
360 let v0 = $subsqr(lo);
361 let mut v1 = $submul(hi, lo);
362 let i_carry = double!(v1);
363 let v2 = $subsqr(hi);
// Assemble the result from half-width pieces i, j, k, l (most significant first), with v1 added
// in at an offset of $len / 2 limbs.
365 let l = const_subslice(&v0, $len / 2, $len);
366 let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
367 let (mut j, mut i_carry_2) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
369 let mut i = [0; $len / 2];
370 let i_source = const_subslice(&v2, 0, $len / 2);
371 copy_from_slice!(i, 0, $len / 2, i_source);
374 let new_i_carry = add_one!(j);
375 debug_assert!(!i_carry_2 || !new_i_carry);
376 i_carry_2 |= new_i_carry;
379 let spurious_carry = add_one!(i);
380 debug_assert!(!spurious_carry);
383 let spurious_carry = add_one!(i);
384 debug_assert!(!spurious_carry);
387 let mut res = [0; $len * 2];
388 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
389 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
390 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
391 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
// Each width squares its halves with the next smaller squarer and multiplies them with the
// matching `mul_*`, bottoming out in `sqr_2`/`mul_2`.
396 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
397 define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
398 define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
399 define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
400 define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
/// `$pre_push` hook for `define_div_rem!` when the power-of-two table is a pre-sized array:
/// every slot already exists, so this expands to nothing.
macro_rules! dummy_pre_push {
    ($name: ident, $len: expr) => {};
}

/// `$pre_push` hook for `define_div_rem!` when the power-of-two table is a `Vec`: appends a
/// zeroed `$len`-limb entry so the indexed write that follows has a slot to land in.
macro_rules! vec_pre_push {
    ($name: ident, $len: expr) => {
        $name.push([0; $len]);
    };
}
/// Defines `$name`, a binary long-division routine over `$len`-limb big-endian integers.
/// `$heap_init` is the initial storage for the table of `b * 2^k` values, `$pre_push` runs
/// before each table write (to grow a `Vec`, or nothing for a pre-sized array), and the optional
/// `$const_opt` token (e.g. `const`) is placed before `fn`.
406 macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init: expr, $pre_push: ident $(, $const_opt: tt)?) => {
407 /// Divides two $len-64-bit integers, `a` by `b`, returning the quotient and remainder
409 /// Fails iff `b` is zero.
410 $($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> {
411 if slice_equal(b, &[0; $len]) { return Err(()); }
// Record successive doublings of `b_pow` (presumably starting from `b`) while they are below
// `a`, stopping early if a doubling overflows $len limbs.
414 let mut pow2s = $heap_init;
415 let mut pow2s_count = 0;
416 while slice_greater_than(a, &b_pow) {
417 $pre_push!(pow2s, $len);
418 pow2s[pow2s_count] = b_pow;
420 let double_overflow = double!(b_pow);
421 if double_overflow { break; }
423 let mut quot = [0; $len];
// Walk the table from the largest multiple down: shift the quotient left one bit and, whenever
// the remainder strictly exceeds the current multiple, subtract it.
425 let mut pow2 = pow2s_count as isize - 1;
427 let b_pow = pow2s[pow2 as usize];
428 let overflow = double!(quot);
429 debug_assert!(!overflow);
430 if slice_greater_than(&rem, &b_pow) {
431 let (r, carry) = $sub(&rem, &b_pow);
432 debug_assert!(!carry);
// Because the comparisons above are strict, the remainder can end up exactly equal to `b`;
// that counts as one more unit of quotient and a zero remainder.
438 if slice_equal(&rem, b) {
439 let overflow = add_one!(quot);
440 debug_assert!(!overflow);
441 Ok((quot, [0; $len]))
// Table storage is `$len * 64` entries of `$len` limbs; the array-backed variants keep it on
// the stack, the `Vec`-backed ones on the heap.
449 define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
451 define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
453 define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
454 define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
455 #[cfg(debug_assertions)]
456 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
459 /// Constructs a new [`U4096`] from a variable number of big-endian bytes.
///
/// Fails if more than 512 bytes (4096 bits) are supplied. Shorter inputs are zero-extended: they
/// occupy the least significant (highest-index) limbs.
460 pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
461 if bytes.len() > 4096/8 { return Err(()); }
// Number of limbs needed, rounding a partial leading chunk up to a whole limb.
462 let u64s = (bytes.len() + 7) / 8;
463 let mut res = [0; WORD_COUNT_4096];
// Chunk `i` is `bytes[start..end]`. Only the first chunk can be shorter than 8 bytes, and it is
// right-aligned into `b` (presumably a zeroed `[u8; 8]` declared on a line not visible here, with
// `i` counting up from 0).
466 let pos = (u64s - i) * 8;
467 let start = bytes.len().saturating_sub(pos);
468 let end = bytes.len() + 8 - pos;
469 b[8 + start - end..].copy_from_slice(&bytes[start..end]);
470 res[i + WORD_COUNT_4096 - u64s] = u64::from_be_bytes(b);
475 /// Naively multiplies `self` * `b` mod `m`, returning a new [`U4096`].
///
/// Only compiled with debug assertions; `expmod_odd_mod` uses it to cross-check its Montgomery
/// results.
477 /// Fails iff m is 0 or self or b are greater than m.
478 #[cfg(debug_assertions)]
479 fn mulmod_naive(&self, b: &U4096, m: &U4096) -> Result<U4096, ()> {
480 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
481 if self > m || b > m { return Err(()); }
// Full 8192-bit (128-limb) product.
483 let mul = mul_64(&self.0, &b.0);
// Zero-extend m to 128 limbs so the debug-only `div_rem_128` can reduce the full product.
485 let mut m_zeros = [0; 128];
486 m_zeros[WORD_COUNT_4096..].copy_from_slice(&m.0);
487 let (_, rem) = div_rem_128(&mul, &m_zeros)?;
488 let mut res = [0; WORD_COUNT_4096];
// The remainder is below m, so its high 64 limbs must be zero.
489 debug_assert_eq!(&rem[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
490 res.copy_from_slice(&rem[WORD_COUNT_4096..]);
494 /// Calculates `self` ^ `exp` mod `m`, returning a new [`U4096`].
496 /// Fails iff m is 0, even, or self is greater than m.
497 pub(super) fn expmod_odd_mod(&self, mut exp: u32, m: &U4096) -> Result<U4096, ()> {
498 #![allow(non_camel_case_types)]
// (The allow above covers the lowercase `*_ty` function-pointer aliases declared below.)
500 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
501 if m.0[WORD_COUNT_4096 - 1] & 1 == 0 { return Err(()); }
502 if self > m { return Err(()); }
504 let mut t = [0; WORD_COUNT_4096];
// Special case m == 1 (its handling is on lines not visible in this listing).
505 if &m.0[..WORD_COUNT_4096 - 1] == &[0; WORD_COUNT_4096 - 1] && m.0[WORD_COUNT_4096 - 1] == 1 {
// t = 1, the multiplicative identity; `exp == 0` returns it directly.
508 t[WORD_COUNT_4096 - 1] = 1;
509 if exp == 0 { return Ok(U4096(t)); }
511 // Because m is odd, any power of two is co-prime with it, so R = 2^(64 * word_count) (2^1024,
512 // 2^2048 or 2^4096, chosen below to fit m) is always a valid Montgomery R value.
// Function-pointer types for the width-specific arithmetic selected below.
514 type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
515 type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
516 type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
517 type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool);
// Pick the narrowest width (16, 32 or 64 limbs) whose low limbs hold all of m, with matching
// helpers; `log_bits` is log2 of R's bit width (1024 = 2^10, 2048 = 2^11, 4096 = 2^12).
518 let (word_count, log_bits, mul, sqr, add_double, sub) =
519 if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
520 if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] {
// The `*_subarr` adapters run a narrower helper on the low limbs of full-width buffers,
// debug-asserting that the high limbs are zero.
521 fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
522 debug_assert_eq!(a.len(), WORD_COUNT_4096);
523 debug_assert_eq!(b.len(), WORD_COUNT_4096);
524 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
525 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
526 let mut res = [0; WORD_COUNT_4096 * 2];
527 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
528 &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]));
531 fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
532 debug_assert_eq!(a.len(), WORD_COUNT_4096);
533 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
534 let mut res = [0; WORD_COUNT_4096 * 2];
535 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
536 &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
539 fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
540 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
541 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
542 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
543 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
544 let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]);
545 let mut res = [0; WORD_COUNT_4096 * 2];
546 res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
549 fn sub_16_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
550 debug_assert_eq!(a.len(), WORD_COUNT_4096);
551 debug_assert_eq!(b.len(), WORD_COUNT_4096);
552 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
553 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
554 let (sub, underflow) = sub_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]);
555 let mut res = [0; WORD_COUNT_4096];
556 res[WORD_COUNT_4096 * 3 / 4..].copy_from_slice(&sub);
559 (16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty)
561 fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
562 debug_assert_eq!(a.len(), WORD_COUNT_4096);
563 debug_assert_eq!(b.len(), WORD_COUNT_4096);
564 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
565 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
566 let mut res = [0; WORD_COUNT_4096 * 2];
567 res[WORD_COUNT_4096..].copy_from_slice(
568 &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]));
571 fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
572 debug_assert_eq!(a.len(), WORD_COUNT_4096);
573 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
574 let mut res = [0; WORD_COUNT_4096 * 2];
575 res[WORD_COUNT_4096..].copy_from_slice(
576 &sqr_32(&a[WORD_COUNT_4096 / 2..]));
579 fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
580 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
581 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
582 debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
583 debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
584 let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]);
585 let mut res = [0; WORD_COUNT_4096 * 2];
586 res[WORD_COUNT_4096..].copy_from_slice(&add);
589 fn sub_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
590 debug_assert_eq!(a.len(), WORD_COUNT_4096);
591 debug_assert_eq!(b.len(), WORD_COUNT_4096);
592 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
593 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
594 let (sub, underflow) = sub_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]);
595 let mut res = [0; WORD_COUNT_4096];
596 res[WORD_COUNT_4096 / 2..].copy_from_slice(&sub);
599 (32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
602 (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
// r = R = 2^(64 * word_count) in a double-width buffer (not referenced again in the visible
// lines).
605 let mut r = [0; WORD_COUNT_4096 * 2];
606 r[WORD_COUNT_4096 * 2 - word_count - 1] = 1;
// Newton/Hensel iteration for m^-1 mod R: x <- x * (2 - m * x) mod R. Starting from x = 1
// (correct mod 2 since m is odd), each pass doubles the number of correct low bits, so
// `log_bits` passes cover all 2^log_bits bits of R.
608 let mut m_inv_pos = [0; WORD_COUNT_4096];
609 m_inv_pos[WORD_COUNT_4096 - 1] = 1;
610 let mut two = [0; WORD_COUNT_4096];
611 two[WORD_COUNT_4096 - 1] = 2;
612 for _ in 0..log_bits {
613 let mut m_m_inv = mul(&m_inv_pos, &m.0);
614 m_m_inv[..WORD_COUNT_4096 * 2 - word_count].fill(0);
615 let m_inv = mul(&sub(&two, &m_m_inv[WORD_COUNT_4096..]).0, &m_inv_pos);
616 m_inv_pos[WORD_COUNT_4096 - word_count..].copy_from_slice(&m_inv[WORD_COUNT_4096 * 2 - word_count..]);
618 m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0);
620 // We want the negative modular inverse of m mod R, so subtract m_inv from R.
621 let mut m_inv = m_inv_pos;
623 m_inv[..WORD_COUNT_4096 - word_count].fill(0);
// Check: m * m_inv must be -1 mod R, i.e. all ones in the low word_count limbs.
624 debug_assert_eq!(&mul(&m_inv, &m.0)[WORD_COUNT_4096 * 2 - word_count..],
626 &[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]);
628 debug_assert_eq!(&m_inv[..WORD_COUNT_4096 - word_count], &[0; WORD_COUNT_4096][..WORD_COUNT_4096 - word_count]);
// Montgomery reduction (REDC) of a double-width `mu`: returns mu * R^-1 mod m, using
// m_inv = -m^-1 mod R from above.
630 let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] {
631 debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2],
632 &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
633 let mut mu_mod_r = [0; WORD_COUNT_4096];
634 mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
// v = mu * m_inv mod R, chosen so that mu + v * m is divisible by R.
635 let mut v = mul(&mu_mod_r, &m_inv);
636 v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
637 let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
638 let (t1, t1_extra_bit) = add_double(&t0, &mu);
639 let mut t1_on_r = [0; WORD_COUNT_4096];
640 debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..],
641 "t1 should be divisible by r");
// t1 / R: the upper word_count limbs of the 2 * word_count-limb value t1.
642 t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]);
// Assuming mu < m * R (the usual REDC precondition), t1 / R < 2m, so one conditional
// subtraction of m suffices; the add's carry-out bit means the value is certainly >= m.
643 if t1_extra_bit || t1_on_r >= m.0 {
645 (t1_on_r, underflow) = sub(&t1_on_r, &m.0);
646 debug_assert_eq!(t1_extra_bit, underflow);
651 // Calculate R^2 mod m: start from (2^DOUBLES * R) mod m, which is 2^DOUBLES in Montgomery form,
// then Montgomery-square it (log_bits - LOG2_DOUBLES) times. That yields
// 2^(DOUBLES * 2^(log_bits - LOG2_DOUBLES)) = 2^(2^log_bits) = R in Montgomery form, i.e. R^2 mod m.
652 let mut r_minus_one = [0xffff_ffff_ffff_ffffu64; WORD_COUNT_4096];
653 r_minus_one[..WORD_COUNT_4096 - word_count].fill(0);
654 // While we do a full div here, in general R should be less than 2x m (assuming the RSA
655 // modulus used its full bit range and is 1024, 2048, or 4096 bits), so it should be cheap.
656 // In cases with a nonstandard RSA modulus we may end up being pretty slow here, but we'll
658 // If we ever find a problem with this we should reduce R to be tighter on m, as we're
659 // wasting extra bits of calculation if R is too far from m.
660 let (_, mut r_mod_m) = debug_unwrap!(div_rem_64(&r_minus_one, &m.0));
// ((R - 1) mod m) + 1, reduced once more if it reaches m, is R mod m.
661 let r_mod_m_overflow = add_one!(r_mod_m);
662 if r_mod_m_overflow || r_mod_m >= m.0 {
663 (r_mod_m, _) = sub_64(&r_mod_m, &m.0);
666 let mut r2_mod_m: [u64; 64] = r_mod_m;
667 const DOUBLES: usize = 32;
668 const LOG2_DOUBLES: usize = 5;
// Modular doubling DOUBLES times gives (2^DOUBLES * R) mod m. The strict `>` is sufficient:
// the value stays congruent to 2^k * R, which is never 0 mod an odd m > 1, so it cannot equal m.
670 for _ in 0..DOUBLES {
671 let overflow = double!(r2_mod_m);
672 if overflow || r2_mod_m > m.0 {
673 (r2_mod_m, _) = sub_64(&r2_mod_m, &m.0);
676 for _ in 0..log_bits - LOG2_DOUBLES {
677 r2_mod_m = mont_reduction(sqr(&r2_mod_m));
679 // Clear excess high bits
// Zero every bit above m's most significant set bit: whole limbs where m's limb is zero, then
// the leading-zero bits of m's top non-zero limb. For a value already below m this is a no-op.
680 for (m_limb, r2_limb) in m.0.iter().zip(r2_mod_m.iter_mut()) {
681 let clear_bits = m_limb.leading_zeros();
682 if clear_bits == 0 { break; }
683 *r2_limb &= !(0xffff_ffff_ffff_ffffu64 << (64 - clear_bits));
684 if *m_limb != 0 { break; }
686 debug_assert!(r2_mod_m < m.0);
688 // Calculate t * R and a * R as mont multiplications by R^2 mod m
689 let mut tr = mont_reduction(mul(&r2_mod_m, &t));
690 let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
692 #[cfg(debug_assertions)] {
693 debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
694 debug_assert_eq!(&tr, &U4096(t).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
695 debug_assert_eq!(&ar, &self.mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
// NOTE(review): the exponent loop around the next two updates is not visible in this listing;
// presumably a right-to-left square-and-multiply over the bits of `exp` (which is `mut`).
700 tr = mont_reduction(mul(&tr, &ar));
703 ar = mont_reduction(sqr(&ar));
// Final multiply, then leave Montgomery form by reducing ar with a zero high half (ar * R^-1).
706 ar = mont_reduction(mul(&ar, &tr));
707 let mut resr = [0; WORD_COUNT_4096 * 2];
708 resr[WORD_COUNT_4096..].copy_from_slice(&ar);
709 Ok(U4096(mont_reduction(resr)))
716 /// Read some bytes and use them to test bigint math by comparing results against the `ibig` crate.
717 pub fn fuzz_math(input: &[u8]) {
718 if input.len() < 32 || input.len() % 16 != 0 { return; }
// NOTE(review): 512 * 2 + 4 = 1028 is not a multiple of 16, so this early return makes the
// `input.len() == 512*2 + 4` (expmod_odd_mod) branch at the end of this function unreachable.
// Confirm whether that is intended.
719 let split = core::cmp::min(input.len() / 2, 512);
720 let (a, b) = input.split_at(core::cmp::min(input.len() / 2, 512));
// NOTE(review): the line above repeats the `split` computation; `input.split_at(split)` is
// equivalent. For a 1028-byte input `b` would also be 516 bytes, more than
// `U4096::from_be_bytes` accepts, unless it is trimmed on a line not visible in this listing.
723 let ai = ibig::UBig::from_be_bytes(&a);
724 let bi = ibig::UBig::from_be_bytes(&b);
726 let mut a_u64s = Vec::with_capacity(split / 8);
727 for chunk in a.chunks(8) {
728 a_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
730 let mut b_u64s = Vec::with_capacity(split / 8);
731 for chunk in b.chunks(8) {
732 b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
// Checks mul, sqr, add, add_one, double and div_rem of one width against `ibig` on the same
// inputs.
735 macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident) => {
736 let res = $mul(&a_u64s, &b_u64s);
737 let mut res_bytes = Vec::with_capacity(input.len() / 2);
739 res_bytes.extend_from_slice(&i.to_be_bytes());
741 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone());
743 debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
744 debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
746 let (res, carry) = $add(&a_u64s, &b_u64s);
747 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
748 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
750 res_bytes.extend_from_slice(&i.to_be_bytes());
752 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + bi.clone());
754 let mut add_u64s = a_u64s.clone();
755 let carry = add_one!(add_u64s);
756 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
757 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
759 res_bytes.extend_from_slice(&i.to_be_bytes());
761 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + 1);
763 let mut double_u64s = b_u64s.clone();
764 let carry = double!(double_u64s);
765 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
766 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
767 for i in &double_u64s {
768 res_bytes.extend_from_slice(&i.to_be_bytes());
770 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), bi.clone() * 2);
772 let (quot, rem) = if let Ok(res) =
773 $div_rem(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
776 let mut quot_bytes = Vec::with_capacity(input.len() / 2);
778 quot_bytes.extend_from_slice(&i.to_be_bytes());
780 let mut rem_bytes = Vec::with_capacity(input.len() / 2);
782 rem_bytes.extend_from_slice(&i.to_be_bytes());
784 let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
// NOTE(review): the first argument on the next line appears to be `&quot_bytes` mangled by
// HTML-entity decoding (the `&quot` entity became a double-quote character). That leaves an
// unterminated string literal; it should presumably read `&quot_bytes`.
785 assert_eq!(ibig::UBig::from_be_bytes("_bytes), quoti);
786 assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
789 if a_u64s.len() == 2 {
790 test!(mul_2, sqr_2, add_2, sub_2, div_rem_2);
791 } else if a_u64s.len() == 4 {
792 test!(mul_4, sqr_4, add_4, sub_4, div_rem_4);
793 } else if a_u64s.len() == 8 {
794 test!(mul_8, sqr_8, add_8, sub_8, div_rem_8);
795 } else if input.len() == 512*2 + 4 {
796 let mut e_bytes = [0; 4];
797 e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
798 let e = u32::from_le_bytes(e_bytes);
799 let a = U4096::from_be_bytes(&a).unwrap();
800 let b = U4096::from_be_bytes(&b).unwrap();
802 let res = if let Ok(r) = a.expmod_odd_mod(e, &b) { r } else { return };
803 let mut res_bytes = Vec::with_capacity(512);
805 res_bytes.extend_from_slice(&i.to_be_bytes());
808 let ring = ibig::modular::ModuloRing::new(&bi);
809 let ar = ring.from(ai.clone());
810 assert_eq!(ar.pow(&e.into()).residue(), ibig::UBig::from_be_bytes(&res_bytes));
819 fn mul_min_simple_tests() {
822 let res = mul_2(&a, &b);
823 assert_eq!(res, [0, 3, 10, 8]);
// Limb 0 is the most significant: the low limbs multiply to 2424 * 4242 = 10282608, which lands
// in the last output limb.
825 let a = [0x1bad_cafe_dead_beef, 2424];
826 let b = [0x2bad_beef_dead_cafe, 4242];
827 let res = mul_2(&a, &b);
828 assert_eq!(res, [340296855556511776, 15015369169016130186, 4248480538569992542, 10282608]);
830 let a = [0xf6d9_f8eb_8b60_7a6d, 0x4b93_833e_2194_fc2e];
831 let b = [0xfdab_0000_6952_8ab4, 0xd302_0000_8282_0000];
832 let res = mul_2(&a, &b);
833 assert_eq!(res, [17625486516939878681, 18390748118453258282, 2695286104209847530, 1510594524414214144]);
835 let a = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
836 let b = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
837 let res = mul_2(&a, &b);
838 assert_eq!(res, [5481115605507762349, 8230042173354675923, 16737530186064798, 15714555036048702841]);
840 let a = [0x0000_0000_0000_0020, 0x002d_362c_005b_7753];
841 let b = [0x0900_0000_0030_0003, 0xb708_00fe_0000_00cd];
842 let res = mul_2(&a, &b);
843 assert_eq!(res, [1, 2306290405521702946, 17647397529888728169, 10271802099389861239]);
845 let a = [0x0000_0000_7fff_ffff, 0xffff_ffff_0000_0000];
846 let b = [0x0000_0800_0000_0000, 0x0000_1000_0000_00e1];
847 let res = mul_2(&a, &b);
848 assert_eq!(res, [1024, 0, 483183816703, 18446743107341910016]);
850 let a = [0xf6d9_f8eb_ebeb_eb6d, 0x4b93_83a0_bb35_0680];
851 let b = [0xfd02_b9b9_b9b9_b9b9, 0xb9b9_b9b9_b9b9_b9b9];
852 let res = mul_2(&a, &b);
853 assert_eq!(res, [17579814114991930107, 15033987447865175985, 488855932380801351, 5453318140933190272]);
// (2^128 - 1)^2 = 2^256 - 2^129 + 1.
855 let a = [u64::MAX; 2];
856 let b = [u64::MAX; 2];
857 let res = mul_2(&a, &b);
858 assert_eq!(res, [18446744073709551615, 18446744073709551614, 0, 1]);
862 fn add_simple_tests() {
// (2^128 - 1) + (2^128 - 1) = 2^129 - 2: wraps to [u64::MAX, u64::MAX - 1] with the overflow
// flag set.
863 let a = [u64::MAX; 2];
864 let b = [u64::MAX; 2];
865 assert_eq!(add_2(&a, &b), ([18446744073709551615, 18446744073709551614], true));
867 let a = [0x1bad_cafe_dead_beef, 2424];
868 let b = [0x2bad_beef_dead_cafe, 4242];
869 assert_eq!(add_2(&a, &b), ([5141855058045667821, 6666], false));
873 fn mul_4_simple_tests() {
876 assert_eq!(mul_4(&a, &b),
877 [0, 2, 4, 6, 8, 6, 4, 2]);
879 let a = [0x1bad_cafe_dead_beef, 2424, 0x1bad_cafe_dead_beef, 2424];
880 let b = [0x2bad_beef_dead_cafe, 4242, 0x2bad_beef_dead_cafe, 4242];
881 assert_eq!(mul_4(&a, &b),
882 [340296855556511776, 15015369169016130186, 4929074249683016095, 11583994264332991364,
883 8837257932696496860, 15015369169036695402, 4248480538569992542, 10282608]);
// (2^256 - 1)^2 = 2^512 - 2^257 + 1.
885 let a = [u64::MAX; 4];
886 let b = [u64::MAX; 4];
887 assert_eq!(mul_4(&a, &b),
888 [18446744073709551615, 18446744073709551615, 18446744073709551615,
889 18446744073709551614, 0, 0, 0, 1]);
893 fn double_simple_tests() {
894 let mut a = [0xfff5_b32d_01ff_0000, 0x00e7_e7e7_e7e7_e7e7];
896 assert_eq!(a, [18440945635998695424, 130551405668716494]);
898 let mut a = [u64::MAX, u64::MAX];
900 assert_eq!(a, [18446744073709551615, 18446744073709551614]);