62beef434bee5290c28f6991dbb23ca799ec4ef8
[dnssec-prover] / src / crypto / bigint.rs
1 //! Simple variable-time big integer implementation
2
3 use alloc::vec::Vec;
4
5 const WORD_COUNT_4096: usize = 4096 / 64;
6
7 // RFC 5702 indicates RSA keys can be up to 4096 bits
8 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
9 pub(super) struct U4096([u64; WORD_COUNT_4096]);
10
11 macro_rules! debug_unwrap { ($v: expr) => { {
12         let v = $v;
13         debug_assert!(v.is_ok());
14         match v {
15                 Ok(r) => r,
16                 Err(e) => return Err(e),
17         }
18 } } }
19
20 // Various const versions of existing slice utilities
21 /// Const version of `&a[start..end]`
22 const fn const_subslice<'a, T>(a: &'a [T], start: usize, end: usize) -> &'a [T] {
23         assert!(start <= a.len());
24         assert!(end <= a.len());
25         assert!(end >= start);
26         let mut startptr = a.as_ptr();
27         startptr = unsafe { startptr.add(start) };
28         let len = end - start;
29         // The docs for from_raw_parts do not mention any requirements that the pointer be valid if the
30         // length is zero, aside from requiring proper alignment (which is met here). Thus,
31         // one-past-the-end should be an acceptable pointer for a 0-length slice.
32         unsafe { alloc::slice::from_raw_parts(startptr, len) }
33 }
34
35 /// Const version of `dest[dest_start..dest_end].copy_from_slice(source)`
36 ///
37 /// Once `const_mut_refs` is stable we can convert this to a function
38 macro_rules! copy_from_slice {
39         ($dest: ident, $dest_start: expr, $dest_end: expr, $source: ident) => { {
40                 let dest_start = $dest_start;
41                 let dest_end = $dest_end;
42                 assert!(dest_start <= $dest.len());
43                 assert!(dest_end <= $dest.len());
44                 assert!(dest_end >= dest_start);
45                 assert!(dest_end - dest_start == $source.len());
46                 let mut i = 0;
47                 while i < $source.len() {
48                         $dest[i + dest_start] = $source[i];
49                         i += 1;
50                 }
51         } }
52 }
53
54 /// Const version of a > b
55 const fn slice_greater_than(a: &[u64], b: &[u64]) -> bool {
56         debug_assert!(a.len() == b.len());
57         let len = if a.len() <= b.len() { a.len() } else { b.len() };
58         let mut i = 0;
59         while i < len {
60                 if a[i] > b[i] { return true; }
61                 else if a[i] < b[i] { return false; }
62                 i += 1;
63         }
64         false // Equal
65 }
66
67 /// Const version of a == b
68 const fn slice_equal(a: &[u64], b: &[u64]) -> bool {
69         debug_assert!(a.len() == b.len());
70         let len = if a.len() <= b.len() { a.len() } else { b.len() };
71         let mut i = 0;
72         while i < len {
73                 if a[i] != b[i] { return false; }
74                 i += 1;
75         }
76         true
77 }
78
79 /// Adds one in-place, returning an overflow flag, in which case one out-of-bounds high bit is
80 /// implicitly included in the result.
81 ///
82 /// Once `const_mut_refs` is stable we can convert this to a function
83 macro_rules! add_one { ($a: ident) => { {
84         let len = $a.len();
85         let mut i = 0;
86         let mut res = true;
87         while i < len {
88                 let (v, carry) = $a[len - 1 - i].overflowing_add(1);
89                 $a[len - 1 - i] = v;
90                 if !carry { res = false; break; }
91                 i += 1;
92         }
93         res
94 } } }
95
96 /// Negates the given u64 slice.
97 ///
98 /// Once `const_mut_refs` is stable we can convert this to a function
99 macro_rules! negate { ($v: ident) => { {
100         let mut i = 0;
101         while i < $v.len() {
102                 $v[i] ^= 0xffff_ffff_ffff_ffff;
103                 i += 1;
104         }
105         let overflow = add_one!($v);
106         debug_assert!(!overflow);
107 } } }
108
109 /// Doubles in-place, returning an overflow flag, in which case one out-of-bounds high bit is
110 /// implicitly included in the result.
111 ///
112 /// Once `const_mut_refs` is stable we can convert this to a function
113 macro_rules! double { ($a: ident) => { {
114         { let _: &[u64] = &$a; } // Force type resolution
115         let len = $a.len();
116         let mut carry = false;
117         let mut i = 0;
118         while i < len {
119                 let mut next_carry = ($a[len - 1 - i] & (1 << 63)) != 0;
120                 let (v, next_carry_2) = ($a[len - 1 - i] << 1).overflowing_add(carry as u64);
121                 $a[len - 1 - i] = v;
122                 debug_assert!(!next_carry || !next_carry_2);
123                 next_carry |= next_carry_2;
124                 carry = next_carry;
125                 i += 1;
126         }
127         carry
128 } } }
129
130 macro_rules! define_add { ($name: ident, $len: expr) => {
131         /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
132         /// bit, with the same semantics as the std [`u64::overflowing_add`] method.
133         const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
134                 debug_assert!(a.len() == $len);
135                 debug_assert!(b.len() == $len);
136                 let mut r = [0; $len];
137                 let mut carry = false;
138                 let mut i = 0;
139                 while i < $len {
140                         let pos = $len - 1 - i;
141                         let (v, mut new_carry) = a[pos].overflowing_add(b[pos]);
142                         let (v2, new_new_carry) = v.overflowing_add(carry as u64);
143                         new_carry |= new_new_carry;
144                         r[pos] = v2;
145                         carry = new_carry;
146                         i += 1;
147                 }
148                 (r, carry)
149         }
150 } }
151
152 define_add!(add_2, 2);
153 define_add!(add_4, 4);
154 define_add!(add_8, 8);
155 define_add!(add_16, 16);
156 define_add!(add_32, 32);
157 define_add!(add_64, 64);
158 define_add!(add_128, 128);
159
160 macro_rules! define_sub { ($name: ident, $len: expr) => {
161         /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
162         /// $len-64-bit integer and an overflow bit, with the same semantics as the std
163         /// [`u64::overflowing_sub`] method.
164         const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
165                 debug_assert!(a.len() == $len);
166                 debug_assert!(b.len() == $len);
167                 let mut r = [0; $len];
168                 let mut carry = false;
169                 let mut i = 0;
170                 while i < $len {
171                         let pos = $len - 1 - i;
172                         let (v, mut new_carry) = a[pos].overflowing_sub(b[pos]);
173                         let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
174                         new_carry |= new_new_carry;
175                         r[pos] = v2;
176                         carry = new_carry;
177                         i += 1;
178                 }
179                 (r, carry)
180         }
181 } }
182
183 define_sub!(sub_2, 2);
184 define_sub!(sub_4, 4);
185 define_sub!(sub_8, 8);
186 define_sub!(sub_16, 16);
187 define_sub!(sub_32, 32);
188 define_sub!(sub_64, 64);
189 #[cfg(debug_assertions)]
190 define_sub!(sub_128, 128);
191
192 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
193 ///
194 /// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int
195 /// types to do multiplication (potentially) natively.
196 const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
197         debug_assert!(a.len() == 2);
198         debug_assert!(b.len() == 2);
199
200         // Gradeschool multiplication is way faster here.
201         let (a0, a1) = (a[0] as u128, a[1] as u128);
202         let (b0, b1) = (b[0] as u128, b[1] as u128);
203         let z2 = a0 * b0;
204         let z1i = a0 * b1;
205         let z1j = b0 * a1;
206         let (z1, i_carry) = z1i.overflowing_add(z1j);
207         let z0 = a1 * b1;
208
209         let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
210         let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
211         let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
212         let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
213         let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
214         let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
215
216         let l = z0b;
217         let (k, j_carry) = z0a.overflowing_add(z1b);
218         let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
219
220         let new_i_carry;
221         (j, new_i_carry) = j.overflowing_add(j_carry as u64);
222         debug_assert!(!second_i_carry || !new_i_carry);
223         second_i_carry |= new_i_carry;
224
225         let mut i = z2a;
226         let mut spurious_overflow;
227         (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
228         debug_assert!(!spurious_overflow);
229         (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
230         debug_assert!(!spurious_overflow);
231
232         [i, j, k, l]
233 }
234
235 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
236         /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
237         const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
238                 // We could probably get a bit faster doing gradeschool multiplication for some smaller
239                 // sizes, but its easier to just have one variable-length multiplication, so we do
240                 // Karatsuba always here.
241                 debug_assert!(a.len() == $len);
242                 debug_assert!(b.len() == $len);
243
244                 let a0 = const_subslice(a, 0, $len / 2);
245                 let a1 = const_subslice(a, $len / 2, $len);
246                 let b0 = const_subslice(b, 0, $len / 2);
247                 let b1 = const_subslice(b, $len / 2, $len);
248
249                 let z2 = $submul(a0, b0);
250                 let z0 = $submul(a1, b1);
251
252                 let (z1a_max, z1a_min, z1a_sign) =
253                         if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) };
254                 let (z1b_max, z1b_min, z1b_sign) =
255                         if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
256
257                 let z1a = $subsub(z1a_max, z1a_min);
258                 debug_assert!(!z1a.1);
259                 let z1b = $subsub(z1b_max, z1b_min);
260                 debug_assert!(!z1b.1);
261                 let z1m_sign = z1a_sign == z1b_sign;
262
263                 let z1m = $submul(&z1a.0, &z1b.0);
264                 let z1n = $add(&z0, &z2);
265                 let mut z1_carry = z1n.1;
266                 let z1 = if z1m_sign {
267                         let r = $sub(&z1n.0, &z1m);
268                         if r.1 { z1_carry ^= true; }
269                         r.0
270                 } else {
271                         let r = $add(&z1n.0, &z1m);
272                         if r.1 { z1_carry = true; }
273                         r.0
274                 };
275
276                 let l = const_subslice(&z0, $len / 2, $len);
277                 let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
278                 let (mut j, mut i_carry) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
279                 if j_carry {
280                         let new_i_carry = add_one!(j);
281                         debug_assert!(!i_carry || !new_i_carry);
282                         i_carry |= new_i_carry;
283                 }
284                 let mut i = [0; $len / 2];
285                 let i_source = const_subslice(&z2, 0, $len / 2);
286                 copy_from_slice!(i, 0, $len / 2, i_source);
287                 if i_carry {
288                         let spurious_carry = add_one!(i);
289                         debug_assert!(!spurious_carry);
290                 }
291                 if z1_carry {
292                         let spurious_carry = add_one!(i);
293                         debug_assert!(!spurious_carry);
294                 }
295
296                 let mut res = [0; $len * 2];
297                 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
298                 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
299                 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
300                 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
301                 res
302         }
303 } }
304
305 define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
306 define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
307 define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
308 define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
309 define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
310
311
312 /// Squares a 128-bit integer, returning a new 256-bit integer.
313 ///
314 /// This is the base case for our squaring, taking advantage of Rust's native 128-bit int
315 /// types to do multiplication (potentially) natively.
316 const fn sqr_2(a: &[u64]) -> [u64; 4] {
317         debug_assert!(a.len() == 2);
318
319         let (a0, a1) = (a[0] as u128, a[1] as u128);
320         let z2 = a0 * a0;
321         let mut z1 = a0 * a1;
322         let i_carry = z1 & (1u128 << 127) != 0;
323         z1 <<= 1;
324         let z0 = a1 * a1;
325
326         let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
327         let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
328         let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
329         let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
330         let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
331         let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
332
333         let l = z0b;
334         let (k, j_carry) = z0a.overflowing_add(z1b);
335         let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
336
337         let new_i_carry;
338         (j, new_i_carry) = j.overflowing_add(j_carry as u64);
339         debug_assert!(!second_i_carry || !new_i_carry);
340         second_i_carry |= new_i_carry;
341
342         let mut i = z2a;
343         let mut spurious_overflow;
344         (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
345         debug_assert!(!spurious_overflow);
346         (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
347         debug_assert!(!spurious_overflow);
348
349         [i, j, k, l]
350 }
351
352 macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
353         /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
354         const fn $name(a: &[u64]) -> [u64; $len * 2] {
355                 debug_assert!(a.len() == $len);
356
357                 let hi = const_subslice(a, 0, $len / 2);
358                 let lo = const_subslice(a, $len / 2, $len);
359
360                 let v0 = $subsqr(lo);
361                 let mut v1 = $submul(hi, lo);
362                 let i_carry  = double!(v1);
363                 let v2 = $subsqr(hi);
364
365                 let l = const_subslice(&v0, $len / 2, $len);
366                 let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
367                 let (mut j, mut i_carry_2) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
368
369                 let mut i = [0; $len / 2];
370                 let i_source = const_subslice(&v2, 0, $len / 2);
371                 copy_from_slice!(i, 0, $len / 2, i_source);
372
373                 if j_carry {
374                         let new_i_carry = add_one!(j);
375                         debug_assert!(!i_carry_2 || !new_i_carry);
376                         i_carry_2 |= new_i_carry;
377                 }
378                 if i_carry {
379                         let spurious_carry = add_one!(i);
380                         debug_assert!(!spurious_carry);
381                 }
382                 if i_carry_2 {
383                         let spurious_carry = add_one!(i);
384                         debug_assert!(!spurious_carry);
385                 }
386
387                 let mut res = [0; $len * 2];
388                 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
389                 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
390                 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
391                 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
392                 res
393         }
394 } }
395
396 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
397 define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
398 define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
399 define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
400 define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
401
402 #[cfg(fuzzing)]
403 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
404 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
405
406 macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init: expr, $pre_push: ident $(, $const_opt: tt)?) => {
407         /// Divides two $len-64-bit integers, `a` by `b`, returning the quotient and remainder
408         ///
409         /// Fails iff `b` is zero.
410         $($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> {
411                 if slice_equal(b, &[0; $len]) { return Err(()); }
412
413                 let mut b_pow = *b;
414                 let mut pow2s = $heap_init;
415                 let mut pow2s_count = 0;
416                 while slice_greater_than(a, &b_pow) {
417                         $pre_push!(pow2s, $len);
418                         pow2s[pow2s_count] = b_pow;
419                         pow2s_count += 1;
420                         let double_overflow = double!(b_pow);
421                         if double_overflow { break; }
422                 }
423                 let mut quot = [0; $len];
424                 let mut rem = *a;
425                 let mut pow2 = pow2s_count as isize - 1;
426                 while pow2 >= 0 {
427                         let b_pow = pow2s[pow2 as usize];
428                         let overflow = double!(quot);
429                         debug_assert!(!overflow);
430                         if slice_greater_than(&rem, &b_pow) {
431                                 let (r, carry) = $sub(&rem, &b_pow);
432                                 debug_assert!(!carry);
433                                 rem = r;
434                                 quot[$len - 1] |= 1;
435                         }
436                         pow2 -= 1;
437                 }
438                 if slice_equal(&rem, b) {
439                         let overflow = add_one!(quot);
440                         debug_assert!(!overflow);
441                         Ok((quot, [0; $len]))
442                 } else {
443                         Ok((quot, rem))
444                 }
445         }
446 } }
447
448 #[cfg(fuzzing)]
449 define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
450 #[cfg(fuzzing)]
451 define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
452 #[cfg(fuzzing)]
453 define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
454 define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
455 #[cfg(debug_assertions)]
456 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
457
458 impl U4096 {
459         /// Constructs a new [`U4096`] from a variable number of big-endian bytes.
460         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
461                 if bytes.len() > 4096/8 { return Err(()); }
462                 let u64s = (bytes.len() + 7) / 8;
463                 let mut res = [0; WORD_COUNT_4096];
464                 for i in 0..u64s {
465                         let mut b = [0; 8];
466                         let pos = (u64s - i) * 8;
467                         let start = bytes.len().saturating_sub(pos);
468                         let end = bytes.len() + 8 - pos;
469                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
470                         res[i + WORD_COUNT_4096 - u64s] = u64::from_be_bytes(b);
471                 }
472                 Ok(U4096(res))
473         }
474
475         /// Naively multiplies `self` * `b` mod `m`, returning a new [`U4096`].
476         ///
477         /// Fails iff m is 0 or self or b are greater than m.
478         #[cfg(debug_assertions)]
479         fn mulmod_naive(&self, b: &U4096, m: &U4096) -> Result<U4096, ()> {
480                 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
481                 if self > m || b > m { return Err(()); }
482
483                 let mul = mul_64(&self.0, &b.0);
484
485                 let mut m_zeros = [0; 128];
486                 m_zeros[WORD_COUNT_4096..].copy_from_slice(&m.0);
487                 let (_, rem) = div_rem_128(&mul, &m_zeros)?;
488                 let mut res = [0; WORD_COUNT_4096];
489                 debug_assert_eq!(&rem[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
490                 res.copy_from_slice(&rem[WORD_COUNT_4096..]);
491                 Ok(U4096(res))
492         }
493
494         /// Calculates `self` ^ `exp` mod `m`, returning a new [`U4096`].
495         ///
496         /// Fails iff m is 0, even, or self or b are greater than m.
497         pub(super) fn expmod_odd_mod(&self, mut exp: u32, m: &U4096) -> Result<U4096, ()> {
498                 #![allow(non_camel_case_types)]
499
500                 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
501                 if m.0[WORD_COUNT_4096 - 1] & 1 == 0 { return Err(()); }
502                 if self > m { return Err(()); }
503
504                 let mut t = [0; WORD_COUNT_4096];
505                 if &m.0[..WORD_COUNT_4096 - 1] == &[0; WORD_COUNT_4096 - 1] && m.0[WORD_COUNT_4096 - 1] == 1 {
506                         return Ok(U4096(t));
507                 }
508                 t[WORD_COUNT_4096 - 1] = 1;
509                 if exp == 0 { return Ok(U4096(t)); }
510
511                 // Because m is not even, using 2^4096 as the Montgomery R value is always safe - it is
512                 // guaranteed to be co-prime with any non-even integer.
513
514                 type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
515                 type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
516                 type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
517                 type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool);
518                 let (word_count, log_bits, mul, sqr, add_double, sub) =
519                         if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
520                                 if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] {
521                                         fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
522                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
523                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
524                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
525                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
526                                                 let mut res = [0; WORD_COUNT_4096 * 2];
527                                                 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
528                                                         &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]));
529                                                 res
530                                         }
531                                         fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
532                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
533                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
534                                                 let mut res = [0; WORD_COUNT_4096 * 2];
535                                                 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
536                                                         &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
537                                                 res
538                                         }
539                                         fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
540                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
541                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
542                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
543                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
544                                                 let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]);
545                                                 let mut res = [0; WORD_COUNT_4096 * 2];
546                                                 res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
547                                                 (res, overflow)
548                                         }
549                                         fn sub_16_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
550                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
551                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
552                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
553                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
554                                                 let (sub, underflow) = sub_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]);
555                                                 let mut res = [0; WORD_COUNT_4096];
556                                                 res[WORD_COUNT_4096 * 3 / 4..].copy_from_slice(&sub);
557                                                 (res, underflow)
558                                         }
559                                         (16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty)
560                                 } else {
561                                         fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
562                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
563                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
564                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
565                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
566                                                 let mut res = [0; WORD_COUNT_4096 * 2];
567                                                 res[WORD_COUNT_4096..].copy_from_slice(
568                                                         &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]));
569                                                 res
570                                         }
571                                         fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
572                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
573                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
574                                                 let mut res = [0; WORD_COUNT_4096 * 2];
575                                                 res[WORD_COUNT_4096..].copy_from_slice(
576                                                         &sqr_32(&a[WORD_COUNT_4096 / 2..]));
577                                                 res
578                                         }
579                                         fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
580                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
581                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
582                                                 debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
583                                                 debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
584                                                 let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]);
585                                                 let mut res = [0; WORD_COUNT_4096 * 2];
586                                                 res[WORD_COUNT_4096..].copy_from_slice(&add);
587                                                 (res, overflow)
588                                         }
589                                         fn sub_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
590                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
591                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
592                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
593                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
594                                                 let (sub, underflow) = sub_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]);
595                                                 let mut res = [0; WORD_COUNT_4096];
596                                                 res[WORD_COUNT_4096 / 2..].copy_from_slice(&sub);
597                                                 (res, underflow)
598                                         }
599                                         (32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
600                                 }
601                         } else {
602                                 (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
603                         };
604
605                 let mut r = [0; WORD_COUNT_4096 * 2];
606                 r[WORD_COUNT_4096 * 2 - word_count - 1] = 1;
607
608                 let mut m_inv_pos = [0; WORD_COUNT_4096];
609                 m_inv_pos[WORD_COUNT_4096 - 1] = 1;
610                 let mut two = [0; WORD_COUNT_4096];
611                 two[WORD_COUNT_4096 - 1] = 2;
612                 for _ in 0..log_bits {
613                         let mut m_m_inv = mul(&m_inv_pos, &m.0);
614                         m_m_inv[..WORD_COUNT_4096 * 2 - word_count].fill(0);
615                         let m_inv = mul(&sub(&two, &m_m_inv[WORD_COUNT_4096..]).0, &m_inv_pos);
616                         m_inv_pos[WORD_COUNT_4096 - word_count..].copy_from_slice(&m_inv[WORD_COUNT_4096 * 2 - word_count..]);
617                 }
618                 m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0);
619
620                 // We want the negative modular inverse of m mod R, so subtract m_inv from R.
621                 let mut m_inv = m_inv_pos;
622                 negate!(m_inv);
623                 m_inv[..WORD_COUNT_4096 - word_count].fill(0);
624                 debug_assert_eq!(&mul(&m_inv, &m.0)[WORD_COUNT_4096 * 2 - word_count..],
625                         // R - 1 == -1 % R
626                         &[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]);
627
628                 debug_assert_eq!(&m_inv[..WORD_COUNT_4096 - word_count], &[0; WORD_COUNT_4096][..WORD_COUNT_4096 - word_count]);
629
630                 let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] {
631                         debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2],
632                                 &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
633                         let mut mu_mod_r = [0; WORD_COUNT_4096];
634                         mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
635                         let mut v = mul(&mu_mod_r, &m_inv);
636                         v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
637                         let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
638                         let (t1, t1_extra_bit) = add_double(&t0, &mu);
639                         let mut t1_on_r = [0; WORD_COUNT_4096];
640                         debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..],
641                                 "t1 should be divisible by r");
642                         t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]);
643                         if t1_extra_bit || t1_on_r >= m.0 {
644                                 let underflow;
645                                 (t1_on_r, underflow) = sub(&t1_on_r, &m.0);
646                                 debug_assert_eq!(t1_extra_bit, underflow);
647                         }
648                         t1_on_r
649                 };
650
651                 // Calculate R^2 mod m as ((2^DOUBLES * R) mod m)^(log_bits - LOG2_DOUBLES) mod R
652                 let mut r_minus_one = [0xffff_ffff_ffff_ffffu64; WORD_COUNT_4096];
653                 r_minus_one[..WORD_COUNT_4096 - word_count].fill(0);
654                 // While we do a full div here, in general R should be less than 2x m (assuming the RSA
655                 // modulus used its full bit range and is 1024, 2048, or 4096 bits), so it should be cheap.
656                 // In cases with a nonstandard RSA modulus we may end up being pretty slow here, but we'll
657                 // survive.
658                 // If we ever find a problem with this we should reduce R to be tigher on m, as we're
659                 // wasting extra bits of calculation if R is too far from m.
660                 let (_, mut r_mod_m) = debug_unwrap!(div_rem_64(&r_minus_one, &m.0));
661                 let r_mod_m_overflow = add_one!(r_mod_m);
662                 if r_mod_m_overflow || r_mod_m >= m.0 {
663                         (r_mod_m, _) = sub_64(&r_mod_m, &m.0);
664                 }
665
666                 let mut r2_mod_m: [u64; 64] = r_mod_m;
667                 const DOUBLES: usize = 32;
668                 const LOG2_DOUBLES: usize = 5;
669
670                 for _ in 0..DOUBLES {
671                         let overflow = double!(r2_mod_m);
672                         if overflow || r2_mod_m > m.0 {
673                                 (r2_mod_m, _) = sub_64(&r2_mod_m, &m.0);
674                         }
675                 }
676                 for _ in 0..log_bits - LOG2_DOUBLES {
677                         r2_mod_m = mont_reduction(sqr(&r2_mod_m));
678                 }
679                 // Clear excess high bits
680                 for (m_limb, r2_limb) in m.0.iter().zip(r2_mod_m.iter_mut()) {
681                         let clear_bits = m_limb.leading_zeros();
682                         if clear_bits == 0 { break; }
683                         *r2_limb &= !(0xffff_ffff_ffff_ffffu64 << (64 - clear_bits));
684                         if *m_limb != 0 { break; }
685                 }
686                 debug_assert!(r2_mod_m < m.0);
687
688                 // Calculate t * R and a * R as mont multiplications by R^2 mod m
689                 let mut tr = mont_reduction(mul(&r2_mod_m, &t));
690                 let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
691
692                 #[cfg(debug_assertions)] {
693                         debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
694                         debug_assert_eq!(&tr, &U4096(t).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
695                         debug_assert_eq!(&ar, &self.mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
696                 }
697
698                 while exp != 1 {
699                         if exp % 2 == 1 {
700                                 tr = mont_reduction(mul(&tr, &ar));
701                                 exp -= 1;
702                         }
703                         ar = mont_reduction(sqr(&ar));
704                         exp /= 2;
705                 }
706                 ar = mont_reduction(mul(&ar, &tr));
707                 let mut resr = [0; WORD_COUNT_4096 * 2];
708                 resr[WORD_COUNT_4096..].copy_from_slice(&ar);
709                 Ok(U4096(mont_reduction(resr)))
710         }
711 }
712
713 #[cfg(fuzzing)]
714 extern crate ibig;
715 #[cfg(fuzzing)]
716 /// Read some bytes and use them to test bigint math by comparing results against the `ibig` crate.
717 pub fn fuzz_math(input: &[u8]) {
718         if input.len() < 32 || input.len() % 16 != 0 { return; }
719         let split = core::cmp::min(input.len() / 2, 512);
720         let (a, b) = input.split_at(core::cmp::min(input.len() / 2, 512));
721         let b = &b[..split];
722
723         let ai = ibig::UBig::from_be_bytes(&a);
724         let bi = ibig::UBig::from_be_bytes(&b);
725
726         let mut a_u64s = Vec::with_capacity(split / 8);
727         for chunk in a.chunks(8) {
728                 a_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
729         }
730         let mut b_u64s = Vec::with_capacity(split / 8);
731         for chunk in b.chunks(8) {
732                 b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
733         }
734
735         macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident) => {
736                 let res = $mul(&a_u64s, &b_u64s);
737                 let mut res_bytes = Vec::with_capacity(input.len() / 2);
738                 for i in res {
739                         res_bytes.extend_from_slice(&i.to_be_bytes());
740                 }
741                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone());
742
743                 debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
744                 debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
745
746                 let (res, carry) = $add(&a_u64s, &b_u64s);
747                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
748                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
749                 for i in res {
750                         res_bytes.extend_from_slice(&i.to_be_bytes());
751                 }
752                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + bi.clone());
753
754                 let mut add_u64s = a_u64s.clone();
755                 let carry = add_one!(add_u64s);
756                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
757                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
758                 for i in &add_u64s {
759                         res_bytes.extend_from_slice(&i.to_be_bytes());
760                 }
761                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + 1);
762
763                 let mut double_u64s = b_u64s.clone();
764                 let carry = double!(double_u64s);
765                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
766                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
767                 for i in &double_u64s {
768                         res_bytes.extend_from_slice(&i.to_be_bytes());
769                 }
770                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), bi.clone() * 2);
771
772                 let (quot, rem) = if let Ok(res) =
773                         $div_rem(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
774                                 res
775                         } else { return };
776                 let mut quot_bytes = Vec::with_capacity(input.len() / 2);
777                 for i in quot {
778                         quot_bytes.extend_from_slice(&i.to_be_bytes());
779                 }
780                 let mut rem_bytes = Vec::with_capacity(input.len() / 2);
781                 for i in rem {
782                         rem_bytes.extend_from_slice(&i.to_be_bytes());
783                 }
784                 let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
785                 assert_eq!(ibig::UBig::from_be_bytes(&quot_bytes), quoti);
786                 assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
787         } }
788
789         if a_u64s.len() == 2 {
790                 test!(mul_2, sqr_2, add_2, sub_2, div_rem_2);
791         } else if a_u64s.len() == 4 {
792                 test!(mul_4, sqr_4, add_4, sub_4, div_rem_4);
793         } else if a_u64s.len() == 8 {
794                 test!(mul_8, sqr_8, add_8, sub_8, div_rem_8);
795         } else if input.len() == 512*2 + 4 {
796                 let mut e_bytes = [0; 4];
797                 e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
798                 let e = u32::from_le_bytes(e_bytes);
799                 let a = U4096::from_be_bytes(&a).unwrap();
800                 let b = U4096::from_be_bytes(&b).unwrap();
801
802                 let res = if let Ok(r) = a.expmod_odd_mod(e, &b) { r } else { return };
803                 let mut res_bytes = Vec::with_capacity(512);
804                 for i in res.0 {
805                         res_bytes.extend_from_slice(&i.to_be_bytes());
806                 }
807
808                 let ring = ibig::modular::ModuloRing::new(&bi);
809                 let ar = ring.from(ai.clone());
810                 assert_eq!(ar.pow(&e.into()).residue(), ibig::UBig::from_be_bytes(&res_bytes));
811         }
812 }
813
814 #[cfg(test)]
815 mod tests {
816         use super::*;
817
818         #[test]
819         fn mul_min_simple_tests() {
820                 let a = [1, 2];
821                 let b = [3, 4];
822                 let res = mul_2(&a, &b);
823                 assert_eq!(res, [0, 3, 10, 8]);
824
825                 let a = [0x1bad_cafe_dead_beef, 2424];
826                 let b = [0x2bad_beef_dead_cafe, 4242];
827                 let res = mul_2(&a, &b);
828                 assert_eq!(res, [340296855556511776, 15015369169016130186, 4248480538569992542, 10282608]);
829
830                 let a = [0xf6d9_f8eb_8b60_7a6d, 0x4b93_833e_2194_fc2e];
831                 let b = [0xfdab_0000_6952_8ab4, 0xd302_0000_8282_0000];
832                 let res = mul_2(&a, &b);
833                 assert_eq!(res, [17625486516939878681, 18390748118453258282, 2695286104209847530, 1510594524414214144]);
834
835                 let a = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
836                 let b = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
837                 let res = mul_2(&a, &b);
838                 assert_eq!(res, [5481115605507762349, 8230042173354675923, 16737530186064798, 15714555036048702841]);
839
840                 let a = [0x0000_0000_0000_0020, 0x002d_362c_005b_7753];
841                 let b = [0x0900_0000_0030_0003, 0xb708_00fe_0000_00cd];
842                 let res = mul_2(&a, &b);
843                 assert_eq!(res, [1, 2306290405521702946, 17647397529888728169, 10271802099389861239]);
844
845                 let a = [0x0000_0000_7fff_ffff, 0xffff_ffff_0000_0000];
846                 let b = [0x0000_0800_0000_0000, 0x0000_1000_0000_00e1];
847                 let res = mul_2(&a, &b);
848                 assert_eq!(res, [1024, 0, 483183816703, 18446743107341910016]);
849
850                 let a = [0xf6d9_f8eb_ebeb_eb6d, 0x4b93_83a0_bb35_0680];
851                 let b = [0xfd02_b9b9_b9b9_b9b9, 0xb9b9_b9b9_b9b9_b9b9];
852                 let res = mul_2(&a, &b);
853                 assert_eq!(res, [17579814114991930107, 15033987447865175985, 488855932380801351, 5453318140933190272]);
854
855                 let a = [u64::MAX; 2];
856                 let b = [u64::MAX; 2];
857                 let res = mul_2(&a, &b);
858                 assert_eq!(res, [18446744073709551615, 18446744073709551614, 0, 1]);
859         }
860
861         #[test]
862         fn add_simple_tests() {
863                 let a = [u64::MAX; 2];
864                 let b = [u64::MAX; 2];
865                 assert_eq!(add_2(&a, &b), ([18446744073709551615, 18446744073709551614], true));
866
867                 let a = [0x1bad_cafe_dead_beef, 2424];
868                 let b = [0x2bad_beef_dead_cafe, 4242];
869                 assert_eq!(add_2(&a, &b), ([5141855058045667821, 6666], false));
870         }
871
872         #[test]
873         fn mul_4_simple_tests() {
874                 let a = [1; 4];
875                 let b = [2; 4];
876                 assert_eq!(mul_4(&a, &b),
877                         [0, 2, 4, 6, 8, 6, 4, 2]);
878
879                 let a = [0x1bad_cafe_dead_beef, 2424, 0x1bad_cafe_dead_beef, 2424];
880                 let b = [0x2bad_beef_dead_cafe, 4242, 0x2bad_beef_dead_cafe, 4242];
881                 assert_eq!(mul_4(&a, &b),
882                         [340296855556511776, 15015369169016130186, 4929074249683016095, 11583994264332991364,
883                          8837257932696496860, 15015369169036695402, 4248480538569992542, 10282608]);
884
885                 let a = [u64::MAX; 4];
886                 let b = [u64::MAX; 4];
887                 assert_eq!(mul_4(&a, &b),
888                         [18446744073709551615, 18446744073709551615, 18446744073709551615,
889                          18446744073709551614, 0, 0, 0, 1]);
890         }
891
892         #[test]
893         fn double_simple_tests() {
894                 let mut a = [0xfff5_b32d_01ff_0000, 0x00e7_e7e7_e7e7_e7e7];
895                 assert!(double!(a));
896                 assert_eq!(a, [18440945635998695424, 130551405668716494]);
897
898                 let mut a = [u64::MAX, u64::MAX];
899                 assert!(double!(a));
900                 assert_eq!(a, [18446744073709551615, 18446744073709551614]);
901         }
902 }