Bump version to 0.6.5 for new less-code-size feature
[dnssec-prover] / src / crypto / bigint.rs
1 //! Simple variable-time big integer implementation
2
3 use alloc::vec::Vec;
4 use core::marker::PhantomData;
5
6 #[allow(clippy::needless_lifetimes)] // lifetimes improve readability
7 #[allow(clippy::needless_borrow)] // borrows indicate read-only/non-move
8 #[allow(clippy::too_many_arguments)] // sometimes we don't have an option
9 #[allow(clippy::identity_op)] // sometimes identities improve readability for repeated actions
10
11 // **************************************
12 // * Implementations of math primitives *
13 // **************************************
14
15 macro_rules! debug_unwrap { ($v: expr) => { {
16         let v = $v;
17         debug_assert!(v.is_ok());
18         match v {
19                 Ok(r) => r,
20                 Err(e) => return Err(e),
21         }
22 } } }
23
24 // Various const versions of existing slice utilities
25 /// Const version of `&a[start..end]`
26 const fn const_subslice<'a, T>(a: &'a [T], start: usize, end: usize) -> &'a [T] {
27         assert!(start <= a.len());
28         assert!(end <= a.len());
29         assert!(end >= start);
30         let mut startptr = a.as_ptr();
31         startptr = unsafe { startptr.add(start) };
32         let len: usize = end - start;
33         // We should use alloc::slice::from_raw_parts here, but sadly it was only stabilized as const
34         // in 1.64, whereas we need an MSRV of 1.63. Instead we rely on that which "you should not rely
35         // on" - that slices are laid out as a simple tuple of pointer + length.
36         //
37         // The Rust language reference doesn't specify the layout of slices at all, but does give us a
38         // hint, saying
39         //   Note: Though you should not rely on this, all pointers to DSTs are currently twice the
40         //   size of the size of usize and have the same alignment.
41         // This leaves only two possibilities (for today's rust) - `(length, pointer)` and
42         // `(pointer, length)`. Today, in practice, this seems to always be `(pointer, length)`, so we
43         // just assume it and hope to move to a later MSRV soon.
44         unsafe { core::mem::transmute((startptr, len)) }
45         /*// The docs for from_raw_parts do not mention any requirements that the pointer be valid if the
46         // length is zero, aside from requiring proper alignment (which is met here). Thus,
47         // one-past-the-end should be an acceptable pointer for a 0-length slice.
48         unsafe { alloc::slice::from_raw_parts(startptr, len) }*/
49 }
50
51 /// Const version of `dest[dest_start..dest_end].copy_from_slice(source)`
52 ///
53 /// Once `const_mut_refs` is stable we can convert this to a function
54 macro_rules! copy_from_slice {
55         ($dest: ident, $dest_start: expr, $dest_end: expr, $source: ident) => { {
56                 let dest_start = $dest_start;
57                 let dest_end = $dest_end;
58                 assert!(dest_start <= $dest.len());
59                 assert!(dest_end <= $dest.len());
60                 assert!(dest_end >= dest_start);
61                 assert!(dest_end - dest_start == $source.len());
62                 let mut i = 0;
63                 while i < $source.len() {
64                         $dest[i + dest_start] = $source[i];
65                         i += 1;
66                 }
67         } }
68 }
69
70 /// Const version of a > b
71 const fn slice_greater_than(a: &[u64], b: &[u64]) -> bool {
72         debug_assert!(a.len() == b.len());
73         let len = if a.len() <= b.len() { a.len() } else { b.len() };
74         let mut i = 0;
75         while i < len {
76                 if a[i] > b[i] { return true; }
77                 else if a[i] < b[i] { return false; }
78                 i += 1;
79         }
80         false // Equal
81 }
82
83 /// Const version of a == b
84 const fn slice_equal(a: &[u64], b: &[u64]) -> bool {
85         debug_assert!(a.len() == b.len());
86         let len = if a.len() <= b.len() { a.len() } else { b.len() };
87         let mut i = 0;
88         while i < len {
89                 if a[i] != b[i] { return false; }
90                 i += 1;
91         }
92         true
93 }
94
95 /// Adds a single u64 valuein-place, returning an overflow flag, in which case one out-of-bounds
96 /// high bit is implicitly included in the result.
97 ///
98 /// Once `const_mut_refs` is stable we can convert this to a function
99 macro_rules! add_u64 { ($a: ident, $b: expr) => { {
100         let len = $a.len();
101         let mut i = len - 1;
102         let mut add = $b;
103         loop {
104                 let (v, carry) = $a[i].overflowing_add(add);
105                 $a[i] = v;
106                 add = carry as u64;
107                 if add == 0 { break; }
108
109                 if i == 0 { break; }
110                 i -= 1;
111         }
112         add != 0
113 } } }
114
115 /// Negates the given u64 slice.
116 ///
117 /// Once `const_mut_refs` is stable we can convert this to a function
118 macro_rules! negate { ($v: ident) => { {
119         let mut i = 0;
120         while i < $v.len() {
121                 $v[i] ^= 0xffff_ffff_ffff_ffff;
122                 i += 1;
123         }
124         let _ = add_u64!($v, 1);
125 } } }
126
127 /// Doubles in-place, returning an overflow flag, in which case one out-of-bounds high bit is
128 /// implicitly included in the result.
129 ///
130 /// Once `const_mut_refs` is stable we can convert this to a function
131 macro_rules! double { ($a: ident) => { {
132         { let _: &[u64] = &$a; } // Force type resolution
133         let len = $a.len();
134         let mut carry = false;
135         let mut i = len - 1;
136         loop {
137                 let next_carry = ($a[i] & (1 << 63)) != 0;
138                 let (v, _next_carry_2) = ($a[i] << 1).overflowing_add(carry as u64);
139                 if !next_carry {
140                         debug_assert!(!_next_carry_2, "Adding one to 0x7ffff..*2 is only 0xffff..");
141                 }
142                 // Note that we can ignore _next_carry_2 here as we never need it - it cannot be set if
143                 // next_carry is not set and at max 0xffff..*2 + 1 is only 0x1ffff.. (i.e. we can not need
144                 // a double-carry).
145                 $a[i] = v;
146                 carry = next_carry;
147
148                 if i == 0 { break; }
149                 i -= 1;
150         }
151         carry
152 } } }
153
154 macro_rules! define_add { ($name: ident, $len: expr) => {
155         /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
156         /// bit, with the same semantics as the std [`u64::overflowing_add`] method.
157         const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
158                 debug_assert!(a.len() == $len);
159                 debug_assert!(b.len() == $len);
160                 let mut r = [0; $len];
161                 let mut carry = false;
162                 let mut i = $len - 1;
163                 loop {
164                         let (v, mut new_carry) = a[i].overflowing_add(b[i]);
165                         let (v2, new_new_carry) = v.overflowing_add(carry as u64);
166                         new_carry |= new_new_carry;
167                         r[i] = v2;
168                         carry = new_carry;
169
170                         if i == 0 { break; }
171                         i -= 1;
172                 }
173                 (r, carry)
174         }
175 } }
176
177 define_add!(add_2, 2);
178 define_add!(add_3, 3);
179 define_add!(add_4, 4);
180 define_add!(add_6, 6);
181 define_add!(add_8, 8);
182 define_add!(add_12, 12);
183 define_add!(add_16, 16);
184 define_add!(add_32, 32);
185 define_add!(add_64, 64);
186 define_add!(add_128, 128);
187
188 macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
189         /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
190         /// $len-64-bit integer and an overflow bit, with the same semantics as the std
191         /// [`u64::overflowing_sub`] method.
192         const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
193                 debug_assert!(a.len() == $len);
194                 debug_assert!(b.len() == $len);
195                 let mut r = [0; $len];
196                 let mut carry = false;
197                 let mut i = $len - 1;
198                 loop {
199                         let (v, mut new_carry) = a[i].overflowing_sub(b[i]);
200                         let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
201                         new_carry |= new_new_carry;
202                         r[i] = v2;
203                         carry = new_carry;
204
205                         if i == 0 { break; }
206                         i -= 1;
207                 }
208                 (r, carry)
209         }
210
211         /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
212         /// $len-64-bit integer representing the absolute value of the result, as well as a sign bit.
213         #[allow(unused)]
214         const fn $name_abs(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
215                 let (mut res, neg) = $name(a, b);
216                 if neg {
217                         negate!(res);
218                 }
219                 (res, neg)
220         }
221 } }
222
223 define_sub!(sub_2, sub_abs_2, 2);
224 define_sub!(sub_3, sub_abs_3, 3);
225 define_sub!(sub_4, sub_abs_4, 4);
226 define_sub!(sub_6, sub_abs_6, 6);
227 define_sub!(sub_8, sub_abs_8, 8);
228 define_sub!(sub_12, sub_abs_12, 12);
229 define_sub!(sub_16, sub_abs_16, 16);
230 define_sub!(sub_32, sub_abs_32, 32);
231 define_sub!(sub_64, sub_abs_64, 64);
232 define_sub!(sub_128, sub_abs_128, 128);
233
234 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
235 ///
236 /// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int
237 /// types to do multiplication (potentially) natively.
238 const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
239         debug_assert!(a.len() == 2);
240         debug_assert!(b.len() == 2);
241
242         // Gradeschool multiplication is way faster here.
243         let (a0, a1) = (a[0] as u128, a[1] as u128);
244         let (b0, b1) = (b[0] as u128, b[1] as u128);
245         let z2 = a0 * b0;
246         let z1i = a0 * b1;
247         let z1j = b0 * a1;
248         let (z1, i_carry_a) = z1i.overflowing_add(z1j);
249         let z0 = a1 * b1;
250
251         add_mul_2_parts(z2, z1, z0, i_carry_a)
252 }
253
254 /// Adds the gradeschool multiplication intermediate parts to a final 256-bit result
255 const fn add_mul_2_parts(z2: u128, z1: u128, z0: u128, i_carry_a: bool) -> [u64; 4] {
256         let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
257         let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
258         let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
259         let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
260         let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
261         let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
262
263         let l = z0b;
264
265         let (k, j_carry) = z0a.overflowing_add(z1b);
266
267         let (mut j, i_carry_b) = z1a.overflowing_add(z2b);
268         let i_carry_c;
269         (j, i_carry_c) = j.overflowing_add(j_carry as u64);
270
271         let i_carry = i_carry_a as u64 + i_carry_b as u64 + i_carry_c as u64;
272         let (i, must_not_overflow) = z2a.overflowing_add(i_carry);
273         debug_assert!(!must_not_overflow, "Two 2*64 bit numbers, multiplied, will not use more than 4*64 bits");
274
275         [i, j, k, l]
276 }
277
278 const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
279         debug_assert!(a.len() == 3);
280         debug_assert!(b.len() == 3);
281
282         let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
283         let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
284
285         let m4 = a2 * b2;
286         let m3a = a2 * b1;
287         let m3b = a1 * b2;
288         let m2a = a2 * b0;
289         let m2b = a1 * b1;
290         let m2c = a0 * b2;
291         let m1a = a1 * b0;
292         let m1b = a0 * b1;
293         let m0 = a0 * b0;
294
295         let r5 = ((m4 >> 0) & 0xffff_ffff_ffff_ffff) as u64;
296
297         let r4a = ((m4 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
298         let r4b = ((m3a >> 0) & 0xffff_ffff_ffff_ffff) as u64;
299         let r4c = ((m3b >> 0) & 0xffff_ffff_ffff_ffff) as u64;
300
301         let r3a = ((m3a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
302         let r3b = ((m3b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
303         let r3c = ((m2a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
304         let r3d = ((m2b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
305         let r3e = ((m2c >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
306
307         let r2a = ((m2a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
308         let r2b = ((m2b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
309         let r2c = ((m2c >> 64) & 0xffff_ffff_ffff_ffff) as u64;
310         let r2d = ((m1a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
311         let r2e = ((m1b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
312
313         let r1a = ((m1a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
314         let r1b = ((m1b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
315         let r1c = ((m0  >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
316
317         let r0a = ((m0  >> 64) & 0xffff_ffff_ffff_ffff) as u64;
318
319         let (r4, r3_ca) = r4a.overflowing_add(r4b);
320         let (r4, r3_cb) = r4.overflowing_add(r4c);
321         let r3_c = r3_ca as u64 + r3_cb as u64;
322
323         let (r3, r2_ca) = r3a.overflowing_add(r3b);
324         let (r3, r2_cb) = r3.overflowing_add(r3c);
325         let (r3, r2_cc) = r3.overflowing_add(r3d);
326         let (r3, r2_cd) = r3.overflowing_add(r3e);
327         let (r3, r2_ce) = r3.overflowing_add(r3_c);
328         let r2_c = r2_ca as u64 + r2_cb as u64 + r2_cc as u64 + r2_cd as u64 + r2_ce as u64;
329
330         let (r2, r1_ca) = r2a.overflowing_add(r2b);
331         let (r2, r1_cb) = r2.overflowing_add(r2c);
332         let (r2, r1_cc) = r2.overflowing_add(r2d);
333         let (r2, r1_cd) = r2.overflowing_add(r2e);
334         let (r2, r1_ce) = r2.overflowing_add(r2_c);
335         let r1_c = r1_ca as u64 + r1_cb as u64 + r1_cc as u64 + r1_cd as u64 + r1_ce as u64;
336
337         let (r1, r0_ca) = r1a.overflowing_add(r1b);
338         let (r1, r0_cb) = r1.overflowing_add(r1c);
339         let (r1, r0_cc) = r1.overflowing_add(r1_c);
340         let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64;
341
342         let (r0, must_not_overflow) = r0a.overflowing_add(r0_c);
343         debug_assert!(!must_not_overflow, "Two 3*64 bit numbers, multiplied, will not use more than 6*64 bits");
344
345         [r0, r1, r2, r3, r4, r5]
346 }
347
348 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
349         /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
350         const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
351                 // We could probably get a bit faster doing gradeschool multiplication for some smaller
352                 // sizes, but its easier to just have one variable-length multiplication, so we do
353                 // Karatsuba always here.
354                 debug_assert!(a.len() == $len);
355                 debug_assert!(b.len() == $len);
356
357                 let a0 = const_subslice(a, 0, $len / 2);
358                 let a1 = const_subslice(a, $len / 2, $len);
359                 let b0 = const_subslice(b, 0, $len / 2);
360                 let b1 = const_subslice(b, $len / 2, $len);
361
362                 let z2 = $submul(a0, b0);
363                 let z0 = $submul(a1, b1);
364
365                 let (z1a_max, z1a_min, z1a_sign) =
366                         if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) };
367                 let (z1b_max, z1b_min, z1b_sign) =
368                         if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
369
370                 let z1a = $subsub(z1a_max, z1a_min);
371                 debug_assert!(!z1a.1, "z1a_max was selected to be greater than z1a_min");
372                 let z1b = $subsub(z1b_max, z1b_min);
373                 debug_assert!(!z1b.1, "z1b_max was selected to be greater than z1b_min");
374                 let z1m_sign = z1a_sign == z1b_sign;
375
376                 let z1m = $submul(&z1a.0, &z1b.0);
377                 let z1n = $add(&z0, &z2);
378                 let mut z1_carry = z1n.1;
379                 let z1 = if z1m_sign {
380                         let r = $sub(&z1n.0, &z1m);
381                         if r.1 { z1_carry ^= true; }
382                         r.0
383                 } else {
384                         let r = $add(&z1n.0, &z1m);
385                         if r.1 { z1_carry = true; }
386                         r.0
387                 };
388
389                 let l = const_subslice(&z0, $len / 2, $len);
390                 let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
391                 let (mut j, i_carry_a) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
392                 let mut i_carry_b = false;
393                 if j_carry {
394                         i_carry_b = add_u64!(j, 1);
395                 }
396                 let mut i = [0; $len / 2];
397                 let i_source = const_subslice(&z2, 0, $len / 2);
398                 copy_from_slice!(i, 0, $len / 2, i_source);
399                 let i_carry = i_carry_a as u64 + i_carry_b as u64 + z1_carry as u64;
400                 if i_carry != 0 {
401                         let must_not_overflow = add_u64!(i, i_carry);
402                         debug_assert!(!must_not_overflow, "Two N*64 bit numbers, multiplied, will not use more than 2*N*64 bits");
403                 }
404
405                 let mut res = [0; $len * 2];
406                 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
407                 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
408                 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
409                 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
410                 res
411         }
412 } }
413
414 define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
415 define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
416 define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
417 define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
418 define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
419 define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
420
421
422 /// Squares a 128-bit integer, returning a new 256-bit integer.
423 ///
424 /// This is the base case for our squaring, taking advantage of Rust's native 128-bit int
425 /// types to do multiplication (potentially) natively.
426 const fn sqr_2(a: &[u64]) -> [u64; 4] {
427         debug_assert!(a.len() == 2);
428
429         let (a0, a1) = (a[0] as u128, a[1] as u128);
430         let z2 = a0 * a0;
431         let mut z1 = a0 * a1;
432         let i_carry_a = z1 & (1u128 << 127) != 0;
433         z1 <<= 1;
434         let z0 = a1 * a1;
435
436         add_mul_2_parts(z2, z1, z0, i_carry_a)
437 }
438
439 macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
440         /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
441         const fn $name(a: &[u64]) -> [u64; $len * 2] {
442                 // Squaring is only 3 half-length multiplies/squares in gradeschool math, so use that.
443                 debug_assert!(a.len() == $len);
444
445                 let hi = const_subslice(a, 0, $len / 2);
446                 let lo = const_subslice(a, $len / 2, $len);
447
448                 let v0 = $subsqr(lo);
449                 let mut v1 = $submul(hi, lo);
450                 let i_carry_a  = double!(v1);
451                 let v2 = $subsqr(hi);
452
453                 let l = const_subslice(&v0, $len / 2, $len);
454                 let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
455                 let (mut j, i_carry_b) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
456
457                 let mut i = [0; $len / 2];
458                 let i_source = const_subslice(&v2, 0, $len / 2);
459                 copy_from_slice!(i, 0, $len / 2, i_source);
460
461                 let mut i_carry_c = false;
462                 if j_carry {
463                         i_carry_c = add_u64!(j, 1);
464                 }
465                 let i_carry = i_carry_a as u64 + i_carry_b as u64 + i_carry_c as u64;
466                 if i_carry != 0 {
467                         let must_not_overflow = add_u64!(i, i_carry);
468                         debug_assert!(!must_not_overflow, "Two N*64 bit numbers, multiplied, will not use more than 2*N*64 bits");
469                 }
470
471                 let mut res = [0; $len * 2];
472                 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
473                 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
474                 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
475                 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
476                 res
477         }
478 } }
479
480 // TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
481 const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
482
483 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
484 define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
485 define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
486 define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
487 define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
488 define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
489
490 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
491 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
492
493 macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init: expr, $pre_push: ident $(, $const_opt: tt)?) => {
494         /// Divides two $len-64-bit integers, `a` by `b`, returning the quotient and remainder
495         ///
496         /// Fails iff `b` is zero.
497         $($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> {
498                 if slice_equal(b, &[0; $len]) { return Err(()); }
499
500                 // Very naively divide `a` by `b` by calculating all the powers of two times `b` up to `a`,
501                 // then subtracting the powers of two in decreasing order. What's left is the remainder.
502                 //
503                 // This requires storing all the multiples of `b` in `pow2s`, which may be a vec or an
504                 // array. `$pre_push!()` sets up the next element with zeros and then we can overwrite it.
505                 let mut b_pow = *b;
506                 let mut pow2s = $heap_init;
507                 let mut pow2s_count = 0;
508                 while slice_greater_than(a, &b_pow) {
509                         $pre_push!(pow2s, $len);
510                         pow2s[pow2s_count] = b_pow;
511                         pow2s_count += 1;
512                         let double_overflow = double!(b_pow);
513                         if double_overflow { break; }
514                 }
515                 let mut quot = [0; $len];
516                 let mut rem = *a;
517                 let mut pow2 = pow2s_count as isize - 1;
518                 while pow2 >= 0 {
519                         let b_pow = pow2s[pow2 as usize];
520                         let overflow = double!(quot);
521                         debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
522                         if slice_greater_than(&rem, &b_pow) {
523                                 let (r, underflow) = $sub(&rem, &b_pow);
524                                 debug_assert!(!underflow, "rem was just checked to be > b_pow, so sub cannot underflow");
525                                 rem = r;
526                                 quot[$len - 1] |= 1;
527                         }
528                         pow2 -= 1;
529                 }
530                 if slice_equal(&rem, b) {
531                         let overflow = add_u64!(quot, 1);
532                         debug_assert!(!overflow, "quotient should be expressible in $len*64 bits");
533                         Ok((quot, [0; $len]))
534                 } else {
535                         Ok((quot, rem))
536                 }
537         }
538 } }
539
540 #[cfg(fuzzing)]
541 define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
542 define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
543 define_div_rem!(div_rem_6, 6, sub_6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
544 #[cfg(debug_assertions)]
545 define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
546 #[cfg(debug_assertions)]
547 define_div_rem!(div_rem_12, 12, sub_12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
548 define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
549 #[cfg(debug_assertions)]
550 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
551
552 macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
553         /// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
554         /// if one exists.
555         const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
556                 if slice_equal(a, &[0; $len]) || slice_equal(m, &[0; $len]) { return Err(()); }
557
558                 let (mut s, mut old_s) = ([0; $len], [0; $len]);
559                 old_s[$len - 1] = 1;
560                 let mut r = *m;
561                 let mut old_r = *a;
562
563                 let (mut old_s_neg, mut s_neg) = (false, false);
564
565                 while !slice_equal(&r, &[0; $len]) {
566                         let (quot, new_r) = debug_unwrap!($div(&old_r, &r));
567
568                         let new_sa = $mul(&quot, &s);
569                         debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
570                         let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
571                                 (true, true) => {
572                                         let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
573                                         debug_assert!(!overflow);
574                                         (new_s, true)
575                                 }
576                                 (false, true) => {
577                                         let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
578                                         debug_assert!(!overflow);
579                                         (new_s, false)
580                                 },
581                                 (true, false) => {
582                                         let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
583                                         debug_assert!(!overflow);
584                                         (new_s, true)
585                                 },
586                                 (false, false) => $sub_abs(&old_s, const_subslice(&new_sa, $len, new_sa.len())),
587                         };
588
589                         old_r = r;
590                         r = new_r;
591
592                         old_s = s;
593                         old_s_neg = s_neg;
594                         s = new_s;
595                         s_neg = new_s_neg;
596                 }
597
598                 // At this point old_r contains our GCD and old_s our first Bézout's identity coefficient.
599                 if !slice_equal(const_subslice(&old_r, 0, $len - 1), &[0; $len - 1]) || old_r[$len - 1] != 1 {
600                         Err(())
601                 } else {
602                         debug_assert!(slice_greater_than(m, &old_s));
603                         if old_s_neg {
604                                 let (modinv, underflow) = $sub_abs(m, &old_s);
605                                 debug_assert!(!underflow);
606                                 debug_assert!(slice_greater_than(m, &modinv));
607                                 Ok(modinv)
608                         } else {
609                                 Ok(old_s)
610                         }
611                 }
612         }
613 } }
614 #[cfg(fuzzing)]
615 define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
616 define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
617 define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
618 #[cfg(fuzzing)]
619 define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
620
621 // ******************
622 // * The public API *
623 // ******************
624
625 const WORD_COUNT_4096: usize = 4096 / 64;
626 const WORD_COUNT_256: usize = 256 / 64;
627 const WORD_COUNT_384: usize = 384 / 64;
628
629 // RFC 5702 indicates RSA keys can be up to 4096 bits, so we always use 4096-bit integers
630 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
631 pub(super) struct U4096([u64; WORD_COUNT_4096]);
632
633 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
634 pub(super) struct U256([u64; WORD_COUNT_256]);
635
636 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
637 pub(super) struct U384([u64; WORD_COUNT_384]);
638
639 pub(super) trait Int: Clone + Ord + Sized {
640         const ZERO: Self;
641         const BYTES: usize;
642         fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
643         fn limbs(&self) -> &[u64];
644 }
645 impl Int for U256 {
646         const ZERO: U256 = U256([0; 4]);
647         const BYTES: usize = 32;
648         fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
649         fn limbs(&self) -> &[u64] { &self.0 }
650 }
651 impl Int for U384 {
652         const ZERO: U384 = U384([0; 6]);
653         const BYTES: usize = 48;
654         fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
655         fn limbs(&self) -> &[u64] { &self.0 }
656 }
657
658 /// Defines a *PRIME* Modulus
659 pub(super) trait PrimeModulus<I: Int> {
660         const PRIME: I;
661         const R_SQUARED_MOD_PRIME: I;
662         const NEGATIVE_PRIME_INV_MOD_R: I;
663 }
664
665 #[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
666 pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
667
668 #[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
669 pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
670
671 impl U4096 {
672         /// Constructs a new [`U4096`] from a variable number of big-endian bytes.
673         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
674                 if bytes.len() > 4096/8 { return Err(()); }
675                 let u64s = (bytes.len() + 7) / 8;
676                 let mut res = [0; WORD_COUNT_4096];
677                 for i in 0..u64s {
678                         let mut b = [0; 8];
679                         let pos = (u64s - i) * 8;
680                         let start = bytes.len().saturating_sub(pos);
681                         let end = bytes.len() + 8 - pos;
682                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
683                         res[i + WORD_COUNT_4096 - u64s] = u64::from_be_bytes(b);
684                 }
685                 Ok(U4096(res))
686         }
687
688         /// Naively multiplies `self` * `b` mod `m`, returning a new [`U4096`].
689         ///
690         /// Fails iff m is 0 or self or b are greater than m.
691         #[cfg(debug_assertions)]
692         fn mulmod_naive(&self, b: &U4096, m: &U4096) -> Result<U4096, ()> {
693                 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
694                 if self > m || b > m { return Err(()); }
695
696                 let mul = mul_64(&self.0, &b.0);
697
698                 let mut m_zeros = [0; 128];
699                 m_zeros[WORD_COUNT_4096..].copy_from_slice(&m.0);
700                 let (_, rem) = div_rem_128(&mul, &m_zeros)?;
701                 let mut res = [0; WORD_COUNT_4096];
702                 debug_assert_eq!(&rem[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
703                 res.copy_from_slice(&rem[WORD_COUNT_4096..]);
704                 Ok(U4096(res))
705         }
706
707         /// Calculates `self` ^ `exp` mod `m`, returning a new [`U4096`].
708         ///
709         /// Fails iff m is 0, even, or self or b are greater than m.
710         pub(super) fn expmod_odd_mod(&self, mut exp: u32, m: &U4096) -> Result<U4096, ()> {
711                 #![allow(non_camel_case_types)]
712
713                 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
714                 if m.0[WORD_COUNT_4096 - 1] & 1 == 0 { return Err(()); }
715                 if self > m { return Err(()); }
716
717                 let mut t = [0; WORD_COUNT_4096];
718                 if m.0[..WORD_COUNT_4096 - 1] == [0; WORD_COUNT_4096 - 1] && m.0[WORD_COUNT_4096 - 1] == 1 {
719                         return Ok(U4096(t));
720                 }
721                 t[WORD_COUNT_4096 - 1] = 1;
722                 if exp == 0 { return Ok(U4096(t)); }
723
724                 // Because m is not even, using 2^4096 as the Montgomery R value is always safe - it is
725                 // guaranteed to be co-prime with any non-even integer.
726
727                 // We use a single 4096-bit integer type for all our RSA operations, though in most cases
728                 // we're actually dealing with 1024-bit or 2048-bit ints. Thus, we define sub-array math
729                 // here which debug_assert's the required bits are 0s and then uses faster math primitives.
730
731                 type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
732                 type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
733                 type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
734                 type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool);
735                 let (word_count, log_bits, mul, sqr, add_double, sub) =
736                         if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
737                                 if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] {
738                                         fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
739                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
740                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
741                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
742                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
743                                                 let mut res = [0; WORD_COUNT_4096 * 2];
744                                                 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
745                                                         &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]));
746                                                 res
747                                         }
748                                         fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
749                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
750                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
751                                                 let mut res = [0; WORD_COUNT_4096 * 2];
752                                                 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
753                                                         &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
754                                                 res
755                                         }
756                                         fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
757                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
758                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
759                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
760                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
761                                                 let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]);
762                                                 let mut res = [0; WORD_COUNT_4096 * 2];
763                                                 res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
764                                                 (res, overflow)
765                                         }
766                                         fn sub_16_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
767                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
768                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
769                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
770                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
771                                                 let (sub, underflow) = sub_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]);
772                                                 let mut res = [0; WORD_COUNT_4096];
773                                                 res[WORD_COUNT_4096 * 3 / 4..].copy_from_slice(&sub);
774                                                 (res, underflow)
775                                         }
776                                         (16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty)
777                                 } else {
778                                         fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
779                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
780                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
781                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
782                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
783                                                 let mut res = [0; WORD_COUNT_4096 * 2];
784                                                 res[WORD_COUNT_4096..].copy_from_slice(
785                                                         &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]));
786                                                 res
787                                         }
788                                         fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
789                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
790                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
791                                                 let mut res = [0; WORD_COUNT_4096 * 2];
792                                                 res[WORD_COUNT_4096..].copy_from_slice(
793                                                         &sqr_32(&a[WORD_COUNT_4096 / 2..]));
794                                                 res
795                                         }
796                                         fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
797                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
798                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
799                                                 debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
800                                                 debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
801                                                 let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]);
802                                                 let mut res = [0; WORD_COUNT_4096 * 2];
803                                                 res[WORD_COUNT_4096..].copy_from_slice(&add);
804                                                 (res, overflow)
805                                         }
806                                         fn sub_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
807                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
808                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
809                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
810                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
811                                                 let (sub, underflow) = sub_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]);
812                                                 let mut res = [0; WORD_COUNT_4096];
813                                                 res[WORD_COUNT_4096 / 2..].copy_from_slice(&sub);
814                                                 (res, underflow)
815                                         }
816                                         (32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
817                                 }
818                         } else {
819                                 (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
820                         };
821
822                 // r is always the even value with one bit set above the word count we're using.
823                 let mut r = [0; WORD_COUNT_4096 * 2];
824                 r[WORD_COUNT_4096 * 2 - word_count - 1] = 1;
825
826                 let mut m_inv_pos = [0; WORD_COUNT_4096];
827                 m_inv_pos[WORD_COUNT_4096 - 1] = 1;
828                 let mut two = [0; WORD_COUNT_4096];
829                 two[WORD_COUNT_4096 - 1] = 2;
830                 for _ in 0..log_bits {
831                         let mut m_m_inv = mul(&m_inv_pos, &m.0);
832                         m_m_inv[..WORD_COUNT_4096 * 2 - word_count].fill(0);
833                         let m_inv = mul(&sub(&two, &m_m_inv[WORD_COUNT_4096..]).0, &m_inv_pos);
834                         m_inv_pos[WORD_COUNT_4096 - word_count..].copy_from_slice(&m_inv[WORD_COUNT_4096 * 2 - word_count..]);
835                 }
836                 m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0);
837
838                 // `m_inv` is the negative modular inverse of m mod R, so subtract m_inv from R.
839                 let mut m_inv = m_inv_pos;
840                 negate!(m_inv);
841                 m_inv[..WORD_COUNT_4096 - word_count].fill(0);
842                 debug_assert_eq!(&mul(&m_inv, &m.0)[WORD_COUNT_4096 * 2 - word_count..],
843                         // R - 1 == -1 % R
844                         &[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]);
845
846                 let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] {
847                         debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2],
848                                 &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
849                         // Do a montgomery reduction of `mu`
850
851                         // The definition of REDC (with some names changed):
852                         // v = ((mu % R) * N') mod R
853                         // t = (mu + v*N) / R
854                         // if t >= N { t - N } else { t }
855
856                         // mu % R is just the bottom word_count bytes of mu
857                         let mut mu_mod_r = [0; WORD_COUNT_4096];
858                         mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
859
860                         // v = ((mu % R) * negative_modulus_inverse) % R
861                         let mut v = mul(&mu_mod_r, &m_inv);
862                         v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
863
864                         // t_on_r = (mu + v*modulus) / R
865                         let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
866                         let (t1, t1_extra_bit) = add_double(&t0, &mu);
867
868                         // Note that dividing t1 by R is simply a matter of shifting right by word_count bytes
869                         // We only need to maintain word_count bytes (plus `t1_extra_bit` which is implicitly
870                         // an extra bit) because t_on_r is guarantee to be, at max, 2*m - 1.
871                         let mut t1_on_r = [0; WORD_COUNT_4096];
872                         debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..],
873                                 "t1 should be divisible by r");
874                         t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]);
875
876                         // The modulus has only word_count bytes, so if t1_extra_bit is set we are definitely
877                         // larger than the modulus.
878                         if t1_extra_bit || t1_on_r >= m.0 {
879                                 let underflow;
880                                 (t1_on_r, underflow) = sub(&t1_on_r, &m.0);
881                                 debug_assert_eq!(t1_extra_bit, underflow,
882                                         "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
883                         }
884                         t1_on_r
885                 };
886
887                 // Calculate R^2 mod m as ((2^DOUBLES * R) mod m)^(log_bits - LOG2_DOUBLES) mod R
888                 let mut r_minus_one = [0xffff_ffff_ffff_ffffu64; WORD_COUNT_4096];
889                 r_minus_one[..WORD_COUNT_4096 - word_count].fill(0);
890                 // While we do a full div here, in general R should be less than 2x m (assuming the RSA
891                 // modulus used its full bit range and is 1024, 2048, or 4096 bits), so it should be cheap.
892                 // In cases with a nonstandard RSA modulus we may end up being pretty slow here, but we'll
893                 // survive.
894                 // If we ever find a problem with this we should reduce R to be tigher on m, as we're
895                 // wasting extra bits of calculation if R is too far from m.
896                 let (_, mut r_mod_m) = debug_unwrap!(div_rem_64(&r_minus_one, &m.0));
897                 let r_mod_m_overflow = add_u64!(r_mod_m, 1);
898                 if r_mod_m_overflow || r_mod_m >= m.0 {
899                         (r_mod_m, _) = sub_64(&r_mod_m, &m.0);
900                 }
901
902                 let mut r2_mod_m: [u64; 64] = r_mod_m;
903                 const DOUBLES: usize = 32;
904                 const LOG2_DOUBLES: usize = 5;
905
906                 for _ in 0..DOUBLES {
907                         let overflow = double!(r2_mod_m);
908                         if overflow || r2_mod_m > m.0 {
909                                 (r2_mod_m, _) = sub_64(&r2_mod_m, &m.0);
910                         }
911                 }
912                 for _ in 0..log_bits - LOG2_DOUBLES {
913                         r2_mod_m = mont_reduction(sqr(&r2_mod_m));
914                 }
915                 // Clear excess high bits
916                 for (m_limb, r2_limb) in m.0.iter().zip(r2_mod_m.iter_mut()) {
917                         let clear_bits = m_limb.leading_zeros();
918                         if clear_bits == 0 { break; }
919                         *r2_limb &= !(0xffff_ffff_ffff_ffffu64 << (64 - clear_bits));
920                         if *m_limb != 0 { break; }
921                 }
922
923                 debug_assert!(r2_mod_m < m.0);
924                 #[cfg(debug_assertions)] {
925                         debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
926                 }
927
928                 // Finally, actually do the exponentiation...
929
930                 // Calculate t * R and a * R as mont multiplications by R^2 mod m
931                 // (i.e. t * R^2 / R and 1 * R^2 / R)
932                 let mut tr = mont_reduction(mul(&r2_mod_m, &t));
933                 let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
934
935                 #[cfg(debug_assertions)] {
936                         // Check that tr/ar match naive multiplication
937                         debug_assert_eq!(&tr, &U4096(t).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
938                         debug_assert_eq!(&ar, &self.mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
939                 }
940
941                 while exp != 1 {
942                         if exp % 2 == 1 {
943                                 tr = mont_reduction(mul(&tr, &ar));
944                                 exp -= 1;
945                         }
946                         ar = mont_reduction(sqr(&ar));
947                         exp /= 2;
948                 }
949                 ar = mont_reduction(mul(&ar, &tr));
950                 let mut resr = [0; WORD_COUNT_4096 * 2];
951                 resr[WORD_COUNT_4096..].copy_from_slice(&ar);
952                 Ok(U4096(mont_reduction(resr)))
953         }
954 }
955
956 // In a const context we can't subslice a slice, so instead we pick the eight bytes we want and
957 // pass them here to build u64s from arrays.
958 const fn eight_bytes_to_u64_be(a: u8, b: u8, c: u8, d: u8, e: u8, f: u8, g: u8, h: u8) -> u64 {
959         let b = [a, b, c, d, e, f, g, h];
960         u64::from_be_bytes(b)
961 }
962
963 impl U256 {
964         /// Constructs a new [`U256`] from a variable number of big-endian bytes.
965         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U256, ()> {
966                 if bytes.len() > 256/8 { return Err(()); }
967                 let u64s = (bytes.len() + 7) / 8;
968                 let mut res = [0; WORD_COUNT_256];
969                 for i in 0..u64s {
970                         let mut b = [0; 8];
971                         let pos = (u64s - i) * 8;
972                         let start = bytes.len().saturating_sub(pos);
973                         let end = bytes.len() + 8 - pos;
974                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
975                         res[i + WORD_COUNT_256 - u64s] = u64::from_be_bytes(b);
976                 }
977                 Ok(U256(res))
978         }
979
980         /// Constructs a new [`U256`] from a fixed number of big-endian bytes.
981         pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
982                 let res = [
983                         eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3],
984                                               bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]),
985                         eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3],
986                                               bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]),
987                         eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3],
988                                               bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]),
989                         eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3],
990                                               bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]),
991                 ];
992                 U256(res)
993         }
994
995         pub(super) const fn zero() -> U256 { U256([0, 0, 0, 0]) }
996         pub(super) const fn one() -> U256 { U256([0, 0, 0, 1]) }
997         pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
998 }
999
1000 // Values modulus M::PRIME.0, stored in montgomery form.
1001 impl<M: PrimeModulus<U256>> U256Mod<M> {
1002         const fn mont_reduction(mu: [u64; 8]) -> Self {
1003                 #[cfg(debug_assertions)] {
1004                         // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
1005                         // should be able to do it at compile time alone.
1006                         let minus_one_mod_r = mul_4(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1007                         assert!(slice_equal(const_subslice(&minus_one_mod_r, 4, 8), &[0xffff_ffff_ffff_ffff; 4]));
1008                 }
1009
1010                 #[cfg(debug_assertions)] {
1011                         // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
1012                         // should be able to do it at compile time alone.
1013                         let r_minus_one = [0xffff_ffff_ffff_ffff; 4];
1014                         let (mut r_mod_prime, _) = sub_4(&r_minus_one, &M::PRIME.0);
1015                         let r_mod_prime_overflow = add_u64!(r_mod_prime, 1);
1016                         assert!(!r_mod_prime_overflow);
1017                         let r_squared = sqr_4(&r_mod_prime);
1018                         let mut prime_extended = [0; 8];
1019                         let prime = M::PRIME.0;
1020                         copy_from_slice!(prime_extended, 4, 8, prime);
1021                         let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_8(&r_squared, &prime_extended) { v } else { panic!() };
1022                         assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
1023                         assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
1024                 }
1025
1026                 // The definition of REDC (with some names changed):
1027                 // v = ((mu % R) * N') mod R
1028                 // t = (mu + v*N) / R
1029                 // if t >= N { t - N } else { t }
1030
1031                 // mu % R is just the bottom 4 bytes of mu
1032                 let mu_mod_r = const_subslice(&mu, 4, 8);
1033                 // v = ((mu % R) * negative_modulus_inverse) % R
1034                 let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1035                 const ZEROS: &[u64; 4] = &[0; 4];
1036                 copy_from_slice!(v, 0, 4, ZEROS); // mod R
1037
1038                 // t_on_r = (mu + v*modulus) / R
1039                 let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
1040                 let (t1, t1_extra_bit) = add_8(&t0, &mu);
1041
1042                 // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
1043                 // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
1044                 // because t_on_r is guarantee to be, at max, 2*m - 1.
1045                 let t1_on_r = const_subslice(&t1, 0, 4);
1046
1047                 let mut res = [0; 4];
1048                 // The modulus is only 4 bytes, so t1_extra_bit implies we're definitely larger than the
1049                 // modulus.
1050                 if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
1051                         let underflow;
1052                         (res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
1053                         debug_assert!(t1_extra_bit == underflow,
1054                                 "The number (t1_extra_bit, t1_on_r) is at most 2m-1, so underflowing t1_on_r - m should happen iff t1_extra_bit is set.");
1055                 } else {
1056                         copy_from_slice!(res, 0, 4, t1_on_r);
1057                 }
1058                 Self(U256(res), PhantomData)
1059         }
1060
1061         pub(super) const fn from_u256_panicking(v: U256) -> Self {
1062                 assert!(v.0[0] <= M::PRIME.0[0]);
1063                 if v.0[0] == M::PRIME.0[0] {
1064                         assert!(v.0[1] <= M::PRIME.0[1]);
1065                         if v.0[1] == M::PRIME.0[1] {
1066                                 assert!(v.0[2] <= M::PRIME.0[2]);
1067                                 if v.0[2] == M::PRIME.0[2] {
1068                                         assert!(v.0[3] < M::PRIME.0[3]);
1069                                 }
1070                         }
1071                 }
1072                 assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0 || M::PRIME.0[3] != 0);
1073                 Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1074         }
1075
1076         pub(super) fn from_u256(mut v: U256) -> Self {
1077                 debug_assert!(M::PRIME.0 != [0; 4]);
1078                 debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
1079                 while v >= M::PRIME {
1080                         let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
1081                         debug_assert!(!spurious_underflow, "v was > M::PRIME.0");
1082                         v = U256(new_v);
1083                 }
1084                 Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1085         }
1086
1087         pub(super) fn from_modinv_of(v: U256) -> Result<Self, ()> {
1088                 Ok(Self::from_u256(U256(mod_inv_4(&v.0, &M::PRIME.0)?)))
1089         }
1090
1091         /// Multiplies `self` * `b` mod `m`.
1092         ///
1093         /// Panics if `self`'s modulus is not equal to `b`'s
1094         pub(super) fn mul(&self, b: &Self) -> Self {
1095                 Self::mont_reduction(mul_4(&self.0.0, &b.0.0))
1096         }
1097
1098         /// Doubles `self` mod `m`.
1099         pub(super) fn double(&self) -> Self {
1100                 let mut res = self.0.0;
1101                 let overflow = double!(res);
1102                 if overflow || !slice_greater_than(&M::PRIME.0, &res) {
1103                         let underflow;
1104                         (res, underflow) = sub_4(&res, &M::PRIME.0);
1105                         debug_assert_eq!(overflow, underflow);
1106                 }
1107                 Self(U256(res), PhantomData)
1108         }
1109
1110         /// Multiplies `self` by 3 mod `m`.
1111         pub(super) fn times_three(&self) -> Self {
1112                 // TODO: Optimize this a lot
1113                 self.mul(&U256Mod::from_u256(U256::three()))
1114         }
1115
1116         /// Multiplies `self` by 4 mod `m`.
1117         pub(super) fn times_four(&self) -> Self {
1118                 // TODO: Optimize this somewhat?
1119                 self.double().double()
1120         }
1121
1122         /// Multiplies `self` by 8 mod `m`.
1123         pub(super) fn times_eight(&self) -> Self {
1124                 // TODO: Optimize this somewhat?
1125                 self.double().double().double()
1126         }
1127
1128         /// Multiplies `self` by 8 mod `m`.
1129         pub(super) fn square(&self) -> Self {
1130                 Self::mont_reduction(sqr_4(&self.0.0))
1131         }
1132
1133         /// Subtracts `b` from `self` % `m`.
1134         pub(super) fn sub(&self, b: &Self) -> Self {
1135                 let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
1136                 if underflow {
1137                         let overflow;
1138                         (val, overflow) = add_4(&val, &M::PRIME.0);
1139                         debug_assert_eq!(overflow, underflow);
1140                 }
1141                 Self(U256(val), PhantomData)
1142         }
1143
1144         /// Adds `b` to `self` % `m`.
1145         pub(super) fn add(&self, b: &Self) -> Self {
1146                 let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
1147                 if overflow || !slice_greater_than(&M::PRIME.0, &val) {
1148                         let underflow;
1149                         (val, underflow) = sub_4(&val, &M::PRIME.0);
1150                         debug_assert_eq!(overflow, underflow);
1151                 }
1152                 Self(U256(val), PhantomData)
1153         }
1154
1155         /// Returns the underlying [`U256`].
1156         pub(super) fn into_u256(self) -> U256 {
1157                 let mut expanded_self = [0; 8];
1158                 expanded_self[4..].copy_from_slice(&self.0.0);
1159                 Self::mont_reduction(expanded_self).0
1160         }
1161 }
1162
1163 // Values modulus M::PRIME.0, stored in montgomery form.
1164 impl U384 {
1165         /// Constructs a new [`U384`] from a variable number of big-endian bytes.
1166         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
1167                 if bytes.len() > 384/8 { return Err(()); }
1168                 let u64s = (bytes.len() + 7) / 8;
1169                 let mut res = [0; WORD_COUNT_384];
1170                 for i in 0..u64s {
1171                         let mut b = [0; 8];
1172                         let pos = (u64s - i) * 8;
1173                         let start = bytes.len().saturating_sub(pos);
1174                         let end = bytes.len() + 8 - pos;
1175                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
1176                         res[i + WORD_COUNT_384 - u64s] = u64::from_be_bytes(b);
1177                 }
1178                 Ok(U384(res))
1179         }
1180
1181         /// Constructs a new [`U384`] from a fixed number of big-endian bytes.
1182         pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
1183                 let res = [
1184                         eight_bytes_to_u64_be(bytes[0*8 + 0], bytes[0*8 + 1], bytes[0*8 + 2], bytes[0*8 + 3],
1185                                               bytes[0*8 + 4], bytes[0*8 + 5], bytes[0*8 + 6], bytes[0*8 + 7]),
1186                         eight_bytes_to_u64_be(bytes[1*8 + 0], bytes[1*8 + 1], bytes[1*8 + 2], bytes[1*8 + 3],
1187                                               bytes[1*8 + 4], bytes[1*8 + 5], bytes[1*8 + 6], bytes[1*8 + 7]),
1188                         eight_bytes_to_u64_be(bytes[2*8 + 0], bytes[2*8 + 1], bytes[2*8 + 2], bytes[2*8 + 3],
1189                                               bytes[2*8 + 4], bytes[2*8 + 5], bytes[2*8 + 6], bytes[2*8 + 7]),
1190                         eight_bytes_to_u64_be(bytes[3*8 + 0], bytes[3*8 + 1], bytes[3*8 + 2], bytes[3*8 + 3],
1191                                               bytes[3*8 + 4], bytes[3*8 + 5], bytes[3*8 + 6], bytes[3*8 + 7]),
1192                         eight_bytes_to_u64_be(bytes[4*8 + 0], bytes[4*8 + 1], bytes[4*8 + 2], bytes[4*8 + 3],
1193                                               bytes[4*8 + 4], bytes[4*8 + 5], bytes[4*8 + 6], bytes[4*8 + 7]),
1194                         eight_bytes_to_u64_be(bytes[5*8 + 0], bytes[5*8 + 1], bytes[5*8 + 2], bytes[5*8 + 3],
1195                                               bytes[5*8 + 4], bytes[5*8 + 5], bytes[5*8 + 6], bytes[5*8 + 7]),
1196                 ];
1197                 U384(res)
1198         }
1199
1200         pub(super) const fn zero() -> U384 { U384([0, 0, 0, 0, 0, 0]) }
1201         pub(super) const fn one() -> U384 { U384([0, 0, 0, 0, 0, 1]) }
1202         pub(super) const fn three() -> U384 { U384([0, 0, 0, 0, 0, 3]) }
1203 }
1204
1205 impl<M: PrimeModulus<U384>> U384Mod<M> {
1206         const fn mont_reduction(mu: [u64; 12]) -> Self {
1207                 #[cfg(debug_assertions)] {
1208                         // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
1209                         // should be able to do it at compile time alone.
1210                         let minus_one_mod_r = mul_6(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1211                         assert!(slice_equal(const_subslice(&minus_one_mod_r, 6, 12), &[0xffff_ffff_ffff_ffff; 6]));
1212                 }
1213
1214                 #[cfg(debug_assertions)] {
1215                         // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
1216                         // should be able to do it at compile time alone.
1217                         let r_minus_one = [0xffff_ffff_ffff_ffff; 6];
1218                         let (mut r_mod_prime, _) = sub_6(&r_minus_one, &M::PRIME.0);
1219                         let r_mod_prime_overflow = add_u64!(r_mod_prime, 1);
1220                         assert!(!r_mod_prime_overflow);
1221                         let r_squared = sqr_6(&r_mod_prime);
1222                         let mut prime_extended = [0; 12];
1223                         let prime = M::PRIME.0;
1224                         copy_from_slice!(prime_extended, 6, 12, prime);
1225                         let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_12(&r_squared, &prime_extended) { v } else { panic!() };
1226                         assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
1227                         assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
1228                 }
1229
1230                 // The definition of REDC (with some names changed):
1231                 // v = ((mu % R) * N') mod R
1232                 // t = (mu + v*N) / R
1233                 // if t >= N { t - N } else { t }
1234
1235                 // mu % R is just the bottom 4 bytes of mu
1236                 let mu_mod_r = const_subslice(&mu, 6, 12);
1237                 // v = ((mu % R) * negative_modulus_inverse) % R
1238                 let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1239                 const ZEROS: &[u64; 6] = &[0; 6];
1240                 copy_from_slice!(v, 0, 6, ZEROS); // mod R
1241
1242                 // t_on_r = (mu + v*modulus) / R
1243                 let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
1244                 let (t1, t1_extra_bit) = add_12(&t0, &mu);
1245
1246                 // Note that dividing t1 by R is simply a matter of shifting right by 4 bytes.
1247                 // We only need to maintain 4 bytes (plus `t1_extra_bit` which is implicitly an extra bit)
1248                 // because t_on_r is guarantee to be, at max, 2*m - 1.
1249                 let t1_on_r = const_subslice(&t1, 0, 6);
1250
1251                 let mut res = [0; 6];
1252                 // The modulus is only 4 bytes, so t1_extra_bit implies we're definitely larger than the
1253                 // modulus.
1254                 if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
1255                         let underflow;
1256                         (res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);
1257                         debug_assert!(t1_extra_bit == underflow);
1258                 } else {
1259                         copy_from_slice!(res, 0, 6, t1_on_r);
1260                 }
1261                 Self(U384(res), PhantomData)
1262         }
1263
1264         pub(super) const fn from_u384_panicking(v: U384) -> Self {
1265                 assert!(v.0[0] <= M::PRIME.0[0]);
1266                 if v.0[0] == M::PRIME.0[0] {
1267                         assert!(v.0[1] <= M::PRIME.0[1]);
1268                         if v.0[1] == M::PRIME.0[1] {
1269                                 assert!(v.0[2] <= M::PRIME.0[2]);
1270                                 if v.0[2] == M::PRIME.0[2] {
1271                                         assert!(v.0[3] <= M::PRIME.0[3]);
1272                                         if v.0[3] == M::PRIME.0[3] {
1273                                                 assert!(v.0[4] <= M::PRIME.0[4]);
1274                                                 if v.0[4] == M::PRIME.0[4] {
1275                                                         assert!(v.0[5] < M::PRIME.0[5]);
1276                                                 }
1277                                         }
1278                                 }
1279                         }
1280                 }
1281                 assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0
1282                         || M::PRIME.0[3] != 0|| M::PRIME.0[4] != 0|| M::PRIME.0[5] != 0);
1283                 Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1284         }
1285
1286         pub(super) fn from_u384(mut v: U384) -> Self {
1287                 debug_assert!(M::PRIME.0 != [0; 6]);
1288                 debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
1289                 while v >= M::PRIME {
1290                         let (new_v, spurious_underflow) = sub_6(&v.0, &M::PRIME.0);
1291                         debug_assert!(!spurious_underflow);
1292                         v = U384(new_v);
1293                 }
1294                 Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1295         }
1296
1297         pub(super) fn from_modinv_of(v: U384) -> Result<Self, ()> {
1298                 Ok(Self::from_u384(U384(mod_inv_6(&v.0, &M::PRIME.0)?)))
1299         }
1300
1301         /// Multiplies `self` * `b` mod `m`.
1302         ///
1303         /// Panics if `self`'s modulus is not equal to `b`'s
1304         pub(super) fn mul(&self, b: &Self) -> Self {
1305                 Self::mont_reduction(mul_6(&self.0.0, &b.0.0))
1306         }
1307
1308         /// Doubles `self` mod `m`.
1309         pub(super) fn double(&self) -> Self {
1310                 let mut res = self.0.0;
1311                 let overflow = double!(res);
1312                 if overflow || !slice_greater_than(&M::PRIME.0, &res) {
1313                         let underflow;
1314                         (res, underflow) = sub_6(&res, &M::PRIME.0);
1315                         debug_assert_eq!(overflow, underflow);
1316                 }
1317                 Self(U384(res), PhantomData)
1318         }
1319
1320         /// Multiplies `self` by 3 mod `m`.
1321         pub(super) fn times_three(&self) -> Self {
1322                 // TODO: Optimize this a lot
1323                 self.mul(&U384Mod::from_u384(U384::three()))
1324         }
1325
1326         /// Multiplies `self` by 4 mod `m`.
1327         pub(super) fn times_four(&self) -> Self {
1328                 // TODO: Optimize this somewhat?
1329                 self.double().double()
1330         }
1331
1332         /// Multiplies `self` by 8 mod `m`.
1333         pub(super) fn times_eight(&self) -> Self {
1334                 // TODO: Optimize this somewhat?
1335                 self.double().double().double()
1336         }
1337
1338         /// Multiplies `self` by 8 mod `m`.
1339         pub(super) fn square(&self) -> Self {
1340                 Self::mont_reduction(sqr_6(&self.0.0))
1341         }
1342
1343         /// Subtracts `b` from `self` % `m`.
1344         pub(super) fn sub(&self, b: &Self) -> Self {
1345                 let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
1346                 if underflow {
1347                         let overflow;
1348                         (val, overflow) = add_6(&val, &M::PRIME.0);
1349                         debug_assert_eq!(overflow, underflow);
1350                 }
1351                 Self(U384(val), PhantomData)
1352         }
1353
1354         /// Adds `b` to `self` % `m`.
1355         pub(super) fn add(&self, b: &Self) -> Self {
1356                 let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
1357                 if overflow || !slice_greater_than(&M::PRIME.0, &val) {
1358                         let underflow;
1359                         (val, underflow) = sub_6(&val, &M::PRIME.0);
1360                         debug_assert_eq!(overflow, underflow);
1361                 }
1362                 Self(U384(val), PhantomData)
1363         }
1364
1365         /// Returns the underlying [`U384`].
1366         pub(super) fn into_u384(self) -> U384 {
1367                 let mut expanded_self = [0; 12];
1368                 expanded_self[6..].copy_from_slice(&self.0.0);
1369                 Self::mont_reduction(expanded_self).0
1370         }
1371 }
1372
1373 #[cfg(fuzzing)]
1374 mod fuzz_moduli {
1375         use super::*;
1376
1377         pub struct P256();
1378         impl PrimeModulus<U256> for P256 {
1379                 const PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
1380                         "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"));
1381                 const R_SQUARED_MOD_PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
1382                         "00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003"));
1383                 const NEGATIVE_PRIME_INV_MOD_R: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
1384                         "ffffffff00000002000000000000000000000001000000000000000000000001"));
1385         }
1386
1387         pub struct P384();
1388         impl PrimeModulus<U384> for P384 {
1389                 const PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
1390                         "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"));
1391                 const R_SQUARED_MOD_PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
1392                         "000000000000000000000000000000010000000200000000fffffffe000000000000000200000000fffffffe00000001"));
1393                 const NEGATIVE_PRIME_INV_MOD_R: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
1394                         "00000014000000140000000c00000002fffffffcfffffffafffffffbfffffffe00000000000000010000000100000001"));
1395         }
1396 }
1397
1398 #[cfg(fuzzing)]
1399 extern crate ibig;
1400 #[cfg(fuzzing)]
1401 /// Read some bytes and use them to test bigint math by comparing results against the `ibig` crate.
1402 pub fn fuzz_math(input: &[u8]) {
1403         if input.len() < 32 || input.len() % 16 != 0 { return; }
1404         let split = core::cmp::min(input.len() / 2, 512);
1405         let (a, b) = input.split_at(core::cmp::min(input.len() / 2, 512));
1406         let b = &b[..split];
1407
1408         let ai = ibig::UBig::from_be_bytes(&a);
1409         let bi = ibig::UBig::from_be_bytes(&b);
1410
1411         let mut a_u64s = Vec::with_capacity(split / 8);
1412         for chunk in a.chunks(8) {
1413                 a_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
1414         }
1415         let mut b_u64s = Vec::with_capacity(split / 8);
1416         for chunk in b.chunks(8) {
1417                 b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
1418         }
1419
1420         macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
1421                 let res = $mul(&a_u64s, &b_u64s);
1422                 let mut res_bytes = Vec::with_capacity(input.len() / 2);
1423                 for i in res {
1424                         res_bytes.extend_from_slice(&i.to_be_bytes());
1425                 }
1426                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone());
1427
1428                 debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
1429                 debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
1430
1431                 let (res, carry) = $add(&a_u64s, &b_u64s);
1432                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
1433                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
1434                 for i in res {
1435                         res_bytes.extend_from_slice(&i.to_be_bytes());
1436                 }
1437                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + bi.clone());
1438
1439                 let mut add_u64s = a_u64s.clone();
1440                 let carry = add_u64!(add_u64s, 1);
1441                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
1442                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
1443                 for i in &add_u64s {
1444                         res_bytes.extend_from_slice(&i.to_be_bytes());
1445                 }
1446                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + 1);
1447
1448                 let mut double_u64s = b_u64s.clone();
1449                 let carry = double!(double_u64s);
1450                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
1451                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
1452                 for i in &double_u64s {
1453                         res_bytes.extend_from_slice(&i.to_be_bytes());
1454                 }
1455                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), bi.clone() * 2);
1456
1457                 let (quot, rem) = if let Ok(res) =
1458                         $div_rem(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
1459                                 res
1460                         } else { return };
1461                 let mut quot_bytes = Vec::with_capacity(input.len() / 2);
1462                 for i in quot {
1463                         quot_bytes.extend_from_slice(&i.to_be_bytes());
1464                 }
1465                 let mut rem_bytes = Vec::with_capacity(input.len() / 2);
1466                 for i in rem {
1467                         rem_bytes.extend_from_slice(&i.to_be_bytes());
1468                 }
1469                 let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
1470                 assert_eq!(ibig::UBig::from_be_bytes(&quot_bytes), quoti);
1471                 assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
1472
1473                 if ai != ibig::UBig::from(0u32) { // ibig provides a spurious modular inverse for 0
1474                         let ring = ibig::modular::ModuloRing::new(&bi);
1475                         let ar = ring.from(ai.clone());
1476                         let invi = ar.inverse().map(|i| i.residue());
1477
1478                         if let Ok(modinv) = $mod_inv(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
1479                                 let mut modinv_bytes = Vec::with_capacity(input.len() / 2);
1480                                 for i in modinv {
1481                                         modinv_bytes.extend_from_slice(&i.to_be_bytes());
1482                                 }
1483                                 assert_eq!(invi.unwrap(), ibig::UBig::from_be_bytes(&modinv_bytes));
1484                         } else {
1485                                 assert!(invi.is_none());
1486                         }
1487                 }
1488         } }
1489
1490         macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
1491                 // Test the U256/U384Mod wrapper, which operates in Montgomery representation
1492                 let mut p_extended = [0; $len * 2];
1493                 p_extended[$len..].copy_from_slice(&$PRIME);
1494
1495                 let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
1496                 assert_eq!(&amodp_squared[..$len], &[0; $len]);
1497                 assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
1498
1499                 let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
1500                 assert_eq!(&abmodp[..$len], &[0; $len]);
1501                 assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
1502
1503                 let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
1504                 let mut aplusb_extended = [0; $len * 2];
1505                 aplusb_extended[$len..].copy_from_slice(&aplusb);
1506                 if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
1507                 let aplusbmodp = $div_rem_double(&aplusb_extended, &p_extended).unwrap().1;
1508                 assert_eq!(&aplusbmodp[..$len], &[0; $len]);
1509                 assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
1510
1511                 let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
1512                 if aminusb_underflow {
1513                         let mut overflow;
1514                         (aminusb, overflow) = $add(&aminusb, &$PRIME);
1515                         if !overflow {
1516                                 (aminusb, overflow) = $add(&aminusb, &$PRIME);
1517                         }
1518                         assert!(overflow);
1519                 }
1520                 let aminusbmodp = $div_rem(&aminusb, &$PRIME).unwrap().1;
1521                 assert_eq!(&$amodp.sub(&$bmodp).$into().0, &aminusbmodp);
1522         } }
1523
1524         if a_u64s.len() == 2 {
1525                 test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
1526         } else if a_u64s.len() == 4 {
1527                 test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
1528                 let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
1529                 let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
1530                 test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
1531         } else if a_u64s.len() == 6 {
1532                 test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
1533                 let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
1534                 let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
1535                 test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
1536         } else if a_u64s.len() == 8 {
1537                 test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
1538         } else if input.len() == 512*2 + 4 {
1539                 let mut e_bytes = [0; 4];
1540                 e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
1541                 let e = u32::from_le_bytes(e_bytes);
1542                 let a = U4096::from_be_bytes(&a).unwrap();
1543                 let b = U4096::from_be_bytes(&b).unwrap();
1544
1545                 let res = if let Ok(r) = a.expmod_odd_mod(e, &b) { r } else { return };
1546                 let mut res_bytes = Vec::with_capacity(512);
1547                 for i in res.0 {
1548                         res_bytes.extend_from_slice(&i.to_be_bytes());
1549                 }
1550
1551                 let ring = ibig::modular::ModuloRing::new(&bi);
1552                 let ar = ring.from(ai.clone());
1553                 assert_eq!(ar.pow(&e.into()).residue(), ibig::UBig::from_be_bytes(&res_bytes));
1554         }
1555 }
1556
1557 #[cfg(test)]
1558 mod tests {
1559         use super::*;
1560
1561         fn u64s_to_u128(v: [u64; 2]) -> u128 {
1562                 let mut r = 0;
1563                 r |= v[1] as u128;
1564                 r |= (v[0] as u128) << 64;
1565                 r
1566         }
1567
1568         fn u64s_to_i128(v: [u64; 2]) -> i128 {
1569                 let mut r = 0;
1570                 r |= v[1] as i128;
1571                 r |= (v[0] as i128) << 64;
1572                 r
1573         }
1574
1575         #[test]
1576         fn test_negate() {
1577                 let mut zero = [0u64; 2];
1578                 negate!(zero);
1579                 assert_eq!(zero, [0; 2]);
1580
1581                 let mut one = [0u64, 1u64];
1582                 negate!(one);
1583                 assert_eq!(u64s_to_i128(one), -1);
1584
1585                 let mut minus_one: [u64; 2] = [u64::MAX, u64::MAX];
1586                 negate!(minus_one);
1587                 assert_eq!(minus_one, [0, 1]);
1588         }
1589
1590         #[test]
1591         fn test_double() {
1592                 let mut zero = [0u64; 2];
1593                 assert!(!double!(zero));
1594                 assert_eq!(zero, [0; 2]);
1595
1596                 let mut one = [0u64, 1u64];
1597                 assert!(!double!(one));
1598                 assert_eq!(one, [0, 2]);
1599
1600                 let mut u64_max = [0, u64::MAX];
1601                 assert!(!double!(u64_max));
1602                 assert_eq!(u64_max, [1, u64::MAX - 1]);
1603
1604                 let mut u64_carry_overflow = [0x7fff_ffff_ffff_ffffu64, 0x8000_0000_0000_0000];
1605                 assert!(!double!(u64_carry_overflow));
1606                 assert_eq!(u64_carry_overflow, [u64::MAX, 0]);
1607
1608                 let mut max = [u64::MAX; 4];
1609                 assert!(double!(max));
1610                 assert_eq!(max, [u64::MAX, u64::MAX, u64::MAX, u64::MAX - 1]);
1611         }
1612
1613         #[test]
1614         fn mul_min_simple_tests() {
1615                 let a = [1, 2];
1616                 let b = [3, 4];
1617                 let res = mul_2(&a, &b);
1618                 assert_eq!(res, [0, 3, 10, 8]);
1619
1620                 let a = [0x1bad_cafe_dead_beef, 2424];
1621                 let b = [0x2bad_beef_dead_cafe, 4242];
1622                 let res = mul_2(&a, &b);
1623                 assert_eq!(res, [340296855556511776, 15015369169016130186, 4248480538569992542, 10282608]);
1624
1625                 let a = [0xf6d9_f8eb_8b60_7a6d, 0x4b93_833e_2194_fc2e];
1626                 let b = [0xfdab_0000_6952_8ab4, 0xd302_0000_8282_0000];
1627                 let res = mul_2(&a, &b);
1628                 assert_eq!(res, [17625486516939878681, 18390748118453258282, 2695286104209847530, 1510594524414214144]);
1629
1630                 let a = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
1631                 let b = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
1632                 let res = mul_2(&a, &b);
1633                 assert_eq!(res, [5481115605507762349, 8230042173354675923, 16737530186064798, 15714555036048702841]);
1634
1635                 let a = [0x0000_0000_0000_0020, 0x002d_362c_005b_7753];
1636                 let b = [0x0900_0000_0030_0003, 0xb708_00fe_0000_00cd];
1637                 let res = mul_2(&a, &b);
1638                 assert_eq!(res, [1, 2306290405521702946, 17647397529888728169, 10271802099389861239]);
1639
1640                 let a = [0x0000_0000_7fff_ffff, 0xffff_ffff_0000_0000];
1641                 let b = [0x0000_0800_0000_0000, 0x0000_1000_0000_00e1];
1642                 let res = mul_2(&a, &b);
1643                 assert_eq!(res, [1024, 0, 483183816703, 18446743107341910016]);
1644
1645                 let a = [0xf6d9_f8eb_ebeb_eb6d, 0x4b93_83a0_bb35_0680];
1646                 let b = [0xfd02_b9b9_b9b9_b9b9, 0xb9b9_b9b9_b9b9_b9b9];
1647                 let res = mul_2(&a, &b);
1648                 assert_eq!(res, [17579814114991930107, 15033987447865175985, 488855932380801351, 5453318140933190272]);
1649
1650                 let a = [u64::MAX; 2];
1651                 let b = [u64::MAX; 2];
1652                 let res = mul_2(&a, &b);
1653                 assert_eq!(res, [18446744073709551615, 18446744073709551614, 0, 1]);
1654         }
1655
1656         #[test]
1657         fn test_add_sub() {
1658                 fn test(a: [u64; 2], b: [u64; 2]) {
1659                         let a_int = u64s_to_u128(a);
1660                         let b_int = u64s_to_u128(b);
1661
1662                         let res = add_2(&a, &b);
1663                         assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int));
1664
1665                         let res = sub_2(&a, &b);
1666                         assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_sub(b_int));
1667                 }
1668
1669                 test([0; 2], [0; 2]);
1670                 test([0x1bad_cafe_dead_beef, 2424], [0x2bad_cafe_dead_cafe, 4242]);
1671                 test([u64::MAX; 2], [u64::MAX; 2]);
1672                 test([u64::MAX, 0x8000_0000_0000_0000], [0, 0x7fff_ffff_ffff_ffff]);
1673                 test([0, 0x7fff_ffff_ffff_ffff], [u64::MAX, 0x8000_0000_0000_0000]);
1674                 test([u64::MAX, 0], [0, u64::MAX]);
1675                 test([0, u64::MAX], [u64::MAX, 0]);
1676                 test([u64::MAX; 2], [0; 2]);
1677                 test([0; 2], [u64::MAX; 2]);
1678         }
1679
1680         #[test]
1681         fn mul_4_simple_tests() {
1682                 let a = [1; 4];
1683                 let b = [2; 4];
1684                 assert_eq!(mul_4(&a, &b),
1685                         [0, 2, 4, 6, 8, 6, 4, 2]);
1686
1687                 let a = [0x1bad_cafe_dead_beef, 2424, 0x1bad_cafe_dead_beef, 2424];
1688                 let b = [0x2bad_beef_dead_cafe, 4242, 0x2bad_beef_dead_cafe, 4242];
1689                 assert_eq!(mul_4(&a, &b),
1690                         [340296855556511776, 15015369169016130186, 4929074249683016095, 11583994264332991364,
1691                          8837257932696496860, 15015369169036695402, 4248480538569992542, 10282608]);
1692
1693                 let a = [u64::MAX; 4];
1694                 let b = [u64::MAX; 4];
1695                 assert_eq!(mul_4(&a, &b),
1696                         [18446744073709551615, 18446744073709551615, 18446744073709551615,
1697                          18446744073709551614, 0, 0, 0, 1]);
1698         }
1699
1700         #[test]
1701         fn double_simple_tests() {
1702                 let mut a = [0xfff5_b32d_01ff_0000, 0x00e7_e7e7_e7e7_e7e7];
1703                 assert!(double!(a));
1704                 assert_eq!(a, [18440945635998695424, 130551405668716494]);
1705
1706                 let mut a = [u64::MAX, u64::MAX];
1707                 assert!(double!(a));
1708                 assert_eq!(a, [18446744073709551615, 18446744073709551614]);
1709         }
1710 }