Clean up + test add/sub/negate, fixing a debug assert in negate
[dnssec-prover] / src / crypto / bigint.rs
1 //! Simple variable-time big integer implementation
2
3 use alloc::vec::Vec;
4 use core::marker::PhantomData;
5
6 const WORD_COUNT_4096: usize = 4096 / 64;
7 const WORD_COUNT_256: usize = 256 / 64;
8 const WORD_COUNT_384: usize = 384 / 64;
9
10 // RFC 5702 indicates RSA keys can be up to 4096 bits
11 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
12 pub(super) struct U4096([u64; WORD_COUNT_4096]);
13
14 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
15 pub(super) struct U256([u64; WORD_COUNT_256]);
16
17 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
18 pub(super) struct U384([u64; WORD_COUNT_384]);
19
20 pub(super) trait Int: Clone + Ord + Sized {
21         const ZERO: Self;
22         const BYTES: usize;
23         fn from_be_bytes(b: &[u8]) -> Result<Self, ()>;
24         fn limbs(&self) -> &[u64];
25 }
26 impl Int for U256 {
27         const ZERO: U256 = U256([0; 4]);
28         const BYTES: usize = 32;
29         fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
30         fn limbs(&self) -> &[u64] { &self.0 }
31 }
32 impl Int for U384 {
33         const ZERO: U384 = U384([0; 6]);
34         const BYTES: usize = 48;
35         fn from_be_bytes(b: &[u8]) -> Result<Self, ()> { Self::from_be_bytes(b) }
36         fn limbs(&self) -> &[u64] { &self.0 }
37 }
38
39 /// Defines a *PRIME* Modulus
40 pub(super) trait PrimeModulus<I: Int> {
41         const PRIME: I;
42         const R_SQUARED_MOD_PRIME: I;
43         const NEGATIVE_PRIME_INV_MOD_R: I;
44 }
45
46 #[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
47 pub(super) struct U256Mod<M: PrimeModulus<U256>>(U256, PhantomData<M>);
48
49 #[derive(Clone, Debug, PartialEq, Eq)] // Ord doesn't make sense cause we have an R factor
50 pub(super) struct U384Mod<M: PrimeModulus<U384>>(U384, PhantomData<M>);
51
52 macro_rules! debug_unwrap { ($v: expr) => { {
53         let v = $v;
54         debug_assert!(v.is_ok());
55         match v {
56                 Ok(r) => r,
57                 Err(e) => return Err(e),
58         }
59 } } }
60
61 // Various const versions of existing slice utilities
62 /// Const version of `&a[start..end]`
63 const fn const_subslice<'a, T>(a: &'a [T], start: usize, end: usize) -> &'a [T] {
64         assert!(start <= a.len());
65         assert!(end <= a.len());
66         assert!(end >= start);
67         let mut startptr = a.as_ptr();
68         startptr = unsafe { startptr.add(start) };
69         let len = end - start;
70         // The docs for from_raw_parts do not mention any requirements that the pointer be valid if the
71         // length is zero, aside from requiring proper alignment (which is met here). Thus,
72         // one-past-the-end should be an acceptable pointer for a 0-length slice.
73         unsafe { alloc::slice::from_raw_parts(startptr, len) }
74 }
75
76 /// Const version of `dest[dest_start..dest_end].copy_from_slice(source)`
77 ///
78 /// Once `const_mut_refs` is stable we can convert this to a function
79 macro_rules! copy_from_slice {
80         ($dest: ident, $dest_start: expr, $dest_end: expr, $source: ident) => { {
81                 let dest_start = $dest_start;
82                 let dest_end = $dest_end;
83                 assert!(dest_start <= $dest.len());
84                 assert!(dest_end <= $dest.len());
85                 assert!(dest_end >= dest_start);
86                 assert!(dest_end - dest_start == $source.len());
87                 let mut i = 0;
88                 while i < $source.len() {
89                         $dest[i + dest_start] = $source[i];
90                         i += 1;
91                 }
92         } }
93 }
94
95 /// Const version of a > b
96 const fn slice_greater_than(a: &[u64], b: &[u64]) -> bool {
97         debug_assert!(a.len() == b.len());
98         let len = if a.len() <= b.len() { a.len() } else { b.len() };
99         let mut i = 0;
100         while i < len {
101                 if a[i] > b[i] { return true; }
102                 else if a[i] < b[i] { return false; }
103                 i += 1;
104         }
105         false // Equal
106 }
107
108 /// Const version of a == b
109 const fn slice_equal(a: &[u64], b: &[u64]) -> bool {
110         debug_assert!(a.len() == b.len());
111         let len = if a.len() <= b.len() { a.len() } else { b.len() };
112         let mut i = 0;
113         while i < len {
114                 if a[i] != b[i] { return false; }
115                 i += 1;
116         }
117         true
118 }
119
120 /// Adds a single u64 valuein-place, returning an overflow flag, in which case one out-of-bounds
121 /// high bit is implicitly included in the result.
122 ///
123 /// Once `const_mut_refs` is stable we can convert this to a function
124 macro_rules! add_u64 { ($a: ident, $b: expr) => { {
125         let len = $a.len();
126         let mut i = len - 1;
127         let mut add = $b;
128         loop {
129                 let (v, carry) = $a[i].overflowing_add(add);
130                 $a[i] = v;
131                 add = carry as u64;
132                 if add == 0 { break; }
133
134                 if i == 0 { break; }
135                 i -= 1;
136         }
137         add != 0
138 } } }
139
140 /// Negates the given u64 slice.
141 ///
142 /// Once `const_mut_refs` is stable we can convert this to a function
143 macro_rules! negate { ($v: ident) => { {
144         let mut i = 0;
145         while i < $v.len() {
146                 $v[i] ^= 0xffff_ffff_ffff_ffff;
147                 i += 1;
148         }
149         add_u64!($v, 1);
150 } } }
151
152 /// Doubles in-place, returning an overflow flag, in which case one out-of-bounds high bit is
153 /// implicitly included in the result.
154 ///
155 /// Once `const_mut_refs` is stable we can convert this to a function
156 macro_rules! double { ($a: ident) => { {
157         { let _: &[u64] = &$a; } // Force type resolution
158         let len = $a.len();
159         let mut carry = false;
160         let mut i = len - 1;
161         loop {
162                 let next_carry = ($a[i] & (1 << 63)) != 0;
163                 let (v, _next_carry_2) = ($a[i] << 1).overflowing_add(carry as u64);
164                 if !next_carry {
165                         debug_assert!(!_next_carry_2, "Adding one to 0x7ffff..*2 is only 0xffff..");
166                 }
167                 // Note that we can ignore _next_carry_2 here as we never need it - it cannot be set if
168                 // next_carry is not set and at max 0xffff..*2 + 1 is only 0x1ffff.. (i.e. we can not need
169                 // a double-carry).
170                 $a[i] = v;
171                 carry = next_carry;
172
173                 if i == 0 { break; }
174                 i -= 1;
175         }
176         carry
177 } } }
178
179 macro_rules! define_add { ($name: ident, $len: expr) => {
180         /// Adds two $len-64-bit integers together, returning a new $len-64-bit integer and an overflow
181         /// bit, with the same semantics as the std [`u64::overflowing_add`] method.
182         const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
183                 debug_assert!(a.len() == $len);
184                 debug_assert!(b.len() == $len);
185                 let mut r = [0; $len];
186                 let mut carry = false;
187                 let mut i = $len - 1;
188                 loop {
189                         let (v, mut new_carry) = a[i].overflowing_add(b[i]);
190                         let (v2, new_new_carry) = v.overflowing_add(carry as u64);
191                         new_carry |= new_new_carry;
192                         r[i] = v2;
193                         carry = new_carry;
194
195                         if i == 0 { break; }
196                         i -= 1;
197                 }
198                 (r, carry)
199         }
200 } }
201
202 define_add!(add_2, 2);
203 define_add!(add_3, 3);
204 define_add!(add_4, 4);
205 define_add!(add_6, 6);
206 define_add!(add_8, 8);
207 define_add!(add_12, 12);
208 define_add!(add_16, 16);
209 define_add!(add_32, 32);
210 define_add!(add_64, 64);
211 define_add!(add_128, 128);
212
213 macro_rules! define_sub { ($name: ident, $name_abs: ident, $len: expr) => {
214         /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
215         /// $len-64-bit integer and an overflow bit, with the same semantics as the std
216         /// [`u64::overflowing_sub`] method.
217         const fn $name(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
218                 debug_assert!(a.len() == $len);
219                 debug_assert!(b.len() == $len);
220                 let mut r = [0; $len];
221                 let mut carry = false;
222                 let mut i = $len - 1;
223                 loop {
224                         let (v, mut new_carry) = a[i].overflowing_sub(b[i]);
225                         let (v2, new_new_carry) = v.overflowing_sub(carry as u64);
226                         new_carry |= new_new_carry;
227                         r[i] = v2;
228                         carry = new_carry;
229
230                         if i == 0 { break; }
231                         i -= 1;
232                 }
233                 (r, carry)
234         }
235
236         /// Subtracts the `b` $len-64-bit integer from the `a` $len-64-bit integer, returning a new
237         /// $len-64-bit integer representing the absolute value of the result, as well as a sign bit.
238         #[allow(unused)]
239         const fn $name_abs(a: &[u64], b: &[u64]) -> ([u64; $len], bool) {
240                 let (mut res, neg) = $name(a, b);
241                 if neg {
242                         negate!(res);
243                 }
244                 (res, neg)
245         }
246 } }
247
248 define_sub!(sub_2, sub_abs_2, 2);
249 define_sub!(sub_3, sub_abs_3, 3);
250 define_sub!(sub_4, sub_abs_4, 4);
251 define_sub!(sub_6, sub_abs_6, 6);
252 define_sub!(sub_8, sub_abs_8, 8);
253 define_sub!(sub_12, sub_abs_12, 12);
254 define_sub!(sub_16, sub_abs_16, 16);
255 define_sub!(sub_32, sub_abs_32, 32);
256 define_sub!(sub_64, sub_abs_64, 64);
257 define_sub!(sub_128, sub_abs_128, 128);
258
259 /// Multiplies two 128-bit integers together, returning a new 256-bit integer.
260 ///
261 /// This is the base case for our multiplication, taking advantage of Rust's native 128-bit int
262 /// types to do multiplication (potentially) natively.
263 const fn mul_2(a: &[u64], b: &[u64]) -> [u64; 4] {
264         debug_assert!(a.len() == 2);
265         debug_assert!(b.len() == 2);
266
267         // Gradeschool multiplication is way faster here.
268         let (a0, a1) = (a[0] as u128, a[1] as u128);
269         let (b0, b1) = (b[0] as u128, b[1] as u128);
270         let z2 = a0 * b0;
271         let z1i = a0 * b1;
272         let z1j = b0 * a1;
273         let (z1, i_carry) = z1i.overflowing_add(z1j);
274         let z0 = a1 * b1;
275
276         let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
277         let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
278         let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
279         let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
280         let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
281         let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
282
283         let l = z0b;
284         let (k, j_carry) = z0a.overflowing_add(z1b);
285         let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
286
287         let new_i_carry;
288         (j, new_i_carry) = j.overflowing_add(j_carry as u64);
289         debug_assert!(!second_i_carry || !new_i_carry);
290         second_i_carry |= new_i_carry;
291
292         let mut i = z2a;
293         let mut spurious_overflow;
294         (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
295         debug_assert!(!spurious_overflow);
296         (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
297         debug_assert!(!spurious_overflow);
298
299         [i, j, k, l]
300 }
301
302 const fn mul_3(a: &[u64], b: &[u64]) -> [u64; 6] {
303         debug_assert!(a.len() == 3);
304         debug_assert!(b.len() == 3);
305
306         let (a0, a1, a2) = (a[0] as u128, a[1] as u128, a[2] as u128);
307         let (b0, b1, b2) = (b[0] as u128, b[1] as u128, b[2] as u128);
308
309         let m4 = a2 * b2;
310         let m3a = a2 * b1;
311         let m3b = a1 * b2;
312         let m2a = a2 * b0;
313         let m2b = a1 * b1;
314         let m2c = a0 * b2;
315         let m1a = a1 * b0;
316         let m1b = a0 * b1;
317         let m0 = a0 * b0;
318
319         let r5 = ((m4 >> 0) & 0xffff_ffff_ffff_ffff) as u64;
320
321         let r4a = ((m4 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
322         let r4b = ((m3a >> 0) & 0xffff_ffff_ffff_ffff) as u64;
323         let r4c = ((m3b >> 0) & 0xffff_ffff_ffff_ffff) as u64;
324
325         let r3a = ((m3a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
326         let r3b = ((m3b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
327         let r3c = ((m2a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
328         let r3d = ((m2b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
329         let r3e = ((m2c >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
330
331         let r2a = ((m2a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
332         let r2b = ((m2b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
333         let r2c = ((m2c >> 64) & 0xffff_ffff_ffff_ffff) as u64;
334         let r2d = ((m1a >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
335         let r2e = ((m1b >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
336
337         let r1a = ((m1a >> 64) & 0xffff_ffff_ffff_ffff) as u64;
338         let r1b = ((m1b >> 64) & 0xffff_ffff_ffff_ffff) as u64;
339         let r1c = ((m0  >> 0 ) & 0xffff_ffff_ffff_ffff) as u64;
340
341         let r0a = ((m0  >> 64) & 0xffff_ffff_ffff_ffff) as u64;
342
343         let (r4, r3_ca) = r4a.overflowing_add(r4b);
344         let (r4, r3_cb) = r4.overflowing_add(r4c);
345         let r3_c = r3_ca as u64 + r3_cb as u64;
346
347         let (r3, r2_ca) = r3a.overflowing_add(r3b);
348         let (r3, r2_cb) = r3.overflowing_add(r3c);
349         let (r3, r2_cc) = r3.overflowing_add(r3d);
350         let (r3, r2_cd) = r3.overflowing_add(r3e);
351         let (r3, r2_ce) = r3.overflowing_add(r3_c);
352         let r2_c = r2_ca as u64 + r2_cb as u64 + r2_cc as u64 + r2_cd as u64 + r2_ce as u64;
353
354         let (r2, r1_ca) = r2a.overflowing_add(r2b);
355         let (r2, r1_cb) = r2.overflowing_add(r2c);
356         let (r2, r1_cc) = r2.overflowing_add(r2d);
357         let (r2, r1_cd) = r2.overflowing_add(r2e);
358         let (r2, r1_ce) = r2.overflowing_add(r2_c);
359         let r1_c = r1_ca as u64 + r1_cb as u64 + r1_cc as u64 + r1_cd as u64 + r1_ce as u64;
360
361         let (r1, r0_ca) = r1a.overflowing_add(r1b);
362         let (r1, r0_cb) = r1.overflowing_add(r1c);
363         let (r1, r0_cc) = r1.overflowing_add(r1_c);
364         let r0_c = r0_ca as u64 + r0_cb as u64 + r0_cc as u64;
365
366         let (r0, must_not_overflow) = r0a.overflowing_add(r0_c);
367         debug_assert!(!must_not_overflow);
368
369         [r0, r1, r2, r3, r4, r5]
370 }
371
372 macro_rules! define_mul { ($name: ident, $len: expr, $submul: ident, $add: ident, $subadd: ident, $sub: ident, $subsub: ident) => {
373         /// Multiplies two $len-64-bit integers together, returning a new $len*2-64-bit integer.
374         const fn $name(a: &[u64], b: &[u64]) -> [u64; $len * 2] {
375                 // We could probably get a bit faster doing gradeschool multiplication for some smaller
376                 // sizes, but its easier to just have one variable-length multiplication, so we do
377                 // Karatsuba always here.
378                 debug_assert!(a.len() == $len);
379                 debug_assert!(b.len() == $len);
380
381                 let a0 = const_subslice(a, 0, $len / 2);
382                 let a1 = const_subslice(a, $len / 2, $len);
383                 let b0 = const_subslice(b, 0, $len / 2);
384                 let b1 = const_subslice(b, $len / 2, $len);
385
386                 let z2 = $submul(a0, b0);
387                 let z0 = $submul(a1, b1);
388
389                 let (z1a_max, z1a_min, z1a_sign) =
390                         if slice_greater_than(&a1, &a0) { (a1, a0, true) } else { (a0, a1, false) };
391                 let (z1b_max, z1b_min, z1b_sign) =
392                         if slice_greater_than(&b1, &b0) { (b1, b0, true) } else { (b0, b1, false) };
393
394                 let z1a = $subsub(z1a_max, z1a_min);
395                 debug_assert!(!z1a.1);
396                 let z1b = $subsub(z1b_max, z1b_min);
397                 debug_assert!(!z1b.1);
398                 let z1m_sign = z1a_sign == z1b_sign;
399
400                 let z1m = $submul(&z1a.0, &z1b.0);
401                 let z1n = $add(&z0, &z2);
402                 let mut z1_carry = z1n.1;
403                 let z1 = if z1m_sign {
404                         let r = $sub(&z1n.0, &z1m);
405                         if r.1 { z1_carry ^= true; }
406                         r.0
407                 } else {
408                         let r = $add(&z1n.0, &z1m);
409                         if r.1 { z1_carry = true; }
410                         r.0
411                 };
412
413                 let l = const_subslice(&z0, $len / 2, $len);
414                 let (k, j_carry) = $subadd(const_subslice(&z0, 0, $len / 2), const_subslice(&z1, $len / 2, $len));
415                 let (mut j, mut i_carry) = $subadd(const_subslice(&z1, 0, $len / 2), const_subslice(&z2, $len / 2, $len));
416                 if j_carry {
417                         let new_i_carry = add_u64!(j, 1);
418                         debug_assert!(!i_carry || !new_i_carry);
419                         i_carry |= new_i_carry;
420                 }
421                 let mut i = [0; $len / 2];
422                 let i_source = const_subslice(&z2, 0, $len / 2);
423                 copy_from_slice!(i, 0, $len / 2, i_source);
424                 if i_carry {
425                         let spurious_carry = add_u64!(i, 1);
426                         debug_assert!(!spurious_carry);
427                 }
428                 if z1_carry {
429                         let spurious_carry = add_u64!(i, 1);
430                         debug_assert!(!spurious_carry);
431                 }
432
433                 let mut res = [0; $len * 2];
434                 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
435                 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
436                 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
437                 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
438                 res
439         }
440 } }
441
442 define_mul!(mul_4, 4, mul_2, add_4, add_2, sub_4, sub_2);
443 define_mul!(mul_6, 6, mul_3, add_6, add_3, sub_6, sub_3);
444 define_mul!(mul_8, 8, mul_4, add_8, add_4, sub_8, sub_4);
445 define_mul!(mul_16, 16, mul_8, add_16, add_8, sub_16, sub_8);
446 define_mul!(mul_32, 32, mul_16, add_32, add_16, sub_32, sub_16);
447 define_mul!(mul_64, 64, mul_32, add_64, add_32, sub_64, sub_32);
448
449
450 /// Squares a 128-bit integer, returning a new 256-bit integer.
451 ///
452 /// This is the base case for our squaring, taking advantage of Rust's native 128-bit int
453 /// types to do multiplication (potentially) natively.
454 const fn sqr_2(a: &[u64]) -> [u64; 4] {
455         debug_assert!(a.len() == 2);
456
457         let (a0, a1) = (a[0] as u128, a[1] as u128);
458         let z2 = a0 * a0;
459         let mut z1 = a0 * a1;
460         let i_carry = z1 & (1u128 << 127) != 0;
461         z1 <<= 1;
462         let z0 = a1 * a1;
463
464         let z2a = ((z2 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
465         let z1a = ((z1 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
466         let z0a = ((z0 >> 64) & 0xffff_ffff_ffff_ffff) as u64;
467         let z2b = (z2 & 0xffff_ffff_ffff_ffff) as u64;
468         let z1b = (z1 & 0xffff_ffff_ffff_ffff) as u64;
469         let z0b = (z0 & 0xffff_ffff_ffff_ffff) as u64;
470
471         let l = z0b;
472         let (k, j_carry) = z0a.overflowing_add(z1b);
473         let (mut j, mut second_i_carry) = z1a.overflowing_add(z2b);
474
475         let new_i_carry;
476         (j, new_i_carry) = j.overflowing_add(j_carry as u64);
477         debug_assert!(!second_i_carry || !new_i_carry);
478         second_i_carry |= new_i_carry;
479
480         let mut i = z2a;
481         let mut spurious_overflow;
482         (i, spurious_overflow) = i.overflowing_add(i_carry as u64);
483         debug_assert!(!spurious_overflow);
484         (i, spurious_overflow) = i.overflowing_add(second_i_carry as u64);
485         debug_assert!(!spurious_overflow);
486
487         [i, j, k, l]
488 }
489
490 macro_rules! define_sqr { ($name: ident, $len: expr, $submul: ident, $subsqr: ident, $subadd: ident) => {
491         /// Squares a $len-64-bit integers, returning a new $len*2-64-bit integer.
492         const fn $name(a: &[u64]) -> [u64; $len * 2] {
493                 debug_assert!(a.len() == $len);
494
495                 let hi = const_subslice(a, 0, $len / 2);
496                 let lo = const_subslice(a, $len / 2, $len);
497
498                 let v0 = $subsqr(lo);
499                 let mut v1 = $submul(hi, lo);
500                 let i_carry  = double!(v1);
501                 let v2 = $subsqr(hi);
502
503                 let l = const_subslice(&v0, $len / 2, $len);
504                 let (k, j_carry) = $subadd(const_subslice(&v0, 0, $len / 2), const_subslice(&v1, $len / 2, $len));
505                 let (mut j, mut i_carry_2) = $subadd(const_subslice(&v1, 0, $len / 2), const_subslice(&v2, $len / 2, $len));
506
507                 let mut i = [0; $len / 2];
508                 let i_source = const_subslice(&v2, 0, $len / 2);
509                 copy_from_slice!(i, 0, $len / 2, i_source);
510
511                 if j_carry {
512                         let new_i_carry = add_u64!(j, 1);
513                         debug_assert!(!i_carry_2 || !new_i_carry);
514                         i_carry_2 |= new_i_carry;
515                 }
516                 if i_carry {
517                         let spurious_carry = add_u64!(i, 1);
518                         debug_assert!(!spurious_carry);
519                 }
520                 if i_carry_2 {
521                         let spurious_carry = add_u64!(i, 1);
522                         debug_assert!(!spurious_carry);
523                 }
524
525                 let mut res = [0; $len * 2];
526                 copy_from_slice!(res, $len * 2 * 0 / 4, $len * 2 * 1 / 4, i);
527                 copy_from_slice!(res, $len * 2 * 1 / 4, $len * 2 * 2 / 4, j);
528                 copy_from_slice!(res, $len * 2 * 2 / 4, $len * 2 * 3 / 4, k);
529                 copy_from_slice!(res, $len * 2 * 3 / 4, $len * 2 * 4 / 4, l);
530                 res
531         }
532 } }
533
534 // TODO: Write an optimized sqr_3 (though secp384r1 is barely used)
535 const fn sqr_3(a: &[u64]) -> [u64; 6] { mul_3(a, a) }
536
537 define_sqr!(sqr_4, 4, mul_2, sqr_2, add_2);
538 define_sqr!(sqr_6, 6, mul_3, sqr_3, add_3);
539 define_sqr!(sqr_8, 8, mul_4, sqr_4, add_4);
540 define_sqr!(sqr_16, 16, mul_8, sqr_8, add_8);
541 define_sqr!(sqr_32, 32, mul_16, sqr_16, add_16);
542 define_sqr!(sqr_64, 64, mul_32, sqr_32, add_32);
543
544 macro_rules! dummy_pre_push { ($name: ident, $len: expr) => {} }
545 macro_rules! vec_pre_push { ($name: ident, $len: expr) => { $name.push([0; $len]); } }
546
547 macro_rules! define_div_rem { ($name: ident, $len: expr, $sub: ident, $heap_init: expr, $pre_push: ident $(, $const_opt: tt)?) => {
548         /// Divides two $len-64-bit integers, `a` by `b`, returning the quotient and remainder
549         ///
550         /// Fails iff `b` is zero.
551         $($const_opt)? fn $name(a: &[u64; $len], b: &[u64; $len]) -> Result<([u64; $len], [u64; $len]), ()> {
552                 if slice_equal(b, &[0; $len]) { return Err(()); }
553
554                 let mut b_pow = *b;
555                 let mut pow2s = $heap_init;
556                 let mut pow2s_count = 0;
557                 while slice_greater_than(a, &b_pow) {
558                         $pre_push!(pow2s, $len);
559                         pow2s[pow2s_count] = b_pow;
560                         pow2s_count += 1;
561                         let double_overflow = double!(b_pow);
562                         if double_overflow { break; }
563                 }
564                 let mut quot = [0; $len];
565                 let mut rem = *a;
566                 let mut pow2 = pow2s_count as isize - 1;
567                 while pow2 >= 0 {
568                         let b_pow = pow2s[pow2 as usize];
569                         let overflow = double!(quot);
570                         debug_assert!(!overflow);
571                         if slice_greater_than(&rem, &b_pow) {
572                                 let (r, carry) = $sub(&rem, &b_pow);
573                                 debug_assert!(!carry);
574                                 rem = r;
575                                 quot[$len - 1] |= 1;
576                         }
577                         pow2 -= 1;
578                 }
579                 if slice_equal(&rem, b) {
580                         let overflow = add_u64!(quot, 1);
581                         debug_assert!(!overflow);
582                         Ok((quot, [0; $len]))
583                 } else {
584                         Ok((quot, rem))
585                 }
586         }
587 } }
588
589 #[cfg(fuzzing)]
590 define_div_rem!(div_rem_2, 2, sub_2, [[0; 2]; 2 * 64], dummy_pre_push, const);
591 define_div_rem!(div_rem_4, 4, sub_4, [[0; 4]; 4 * 64], dummy_pre_push, const); // Uses 8 KiB of stack
592 define_div_rem!(div_rem_6, 6, sub_6, [[0; 6]; 6 * 64], dummy_pre_push, const); // Uses 18 KiB of stack!
593 #[cfg(debug_assertions)]
594 define_div_rem!(div_rem_8, 8, sub_8, [[0; 8]; 8 * 64], dummy_pre_push, const); // Uses 32 KiB of stack!
595 #[cfg(debug_assertions)]
596 define_div_rem!(div_rem_12, 12, sub_12, [[0; 12]; 12 * 64], dummy_pre_push, const); // Uses 72 KiB of stack!
597 define_div_rem!(div_rem_64, 64, sub_64, Vec::new(), vec_pre_push); // Uses up to 2 MiB of heap
598 #[cfg(debug_assertions)]
599 define_div_rem!(div_rem_128, 128, sub_128, Vec::new(), vec_pre_push); // Uses up to 8 MiB of heap
600
601 macro_rules! define_mod_inv { ($name: ident, $len: expr, $div: ident, $add: ident, $sub_abs: ident, $mul: ident) => {
602         /// Calculates the modular inverse of a $len-64-bit number with respect to the given modulus,
603         /// if one exists.
604         const fn $name(a: &[u64; $len], m: &[u64; $len]) -> Result<[u64; $len], ()> {
605                 if slice_equal(a, &[0; $len]) || slice_equal(m, &[0; $len]) { return Err(()); }
606
607                 let (mut s, mut old_s) = ([0; $len], [0; $len]);
608                 old_s[$len - 1] = 1;
609                 let mut r = *m;
610                 let mut old_r = *a;
611
612                 let (mut old_s_neg, mut s_neg) = (false, false);
613
614                 while !slice_equal(&r, &[0; $len]) {
615                         let (quot, new_r) = debug_unwrap!($div(&old_r, &r));
616
617                         let new_sa = $mul(&quot, &s);
618                         debug_assert!(slice_equal(const_subslice(&new_sa, 0, $len), &[0; $len]), "S overflowed");
619                         let (new_s, new_s_neg) = match (old_s_neg, s_neg) {
620                                 (true, true) => {
621                                         let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
622                                         debug_assert!(!overflow);
623                                         (new_s, true)
624                                 }
625                                 (false, true) => {
626                                         let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
627                                         debug_assert!(!overflow);
628                                         (new_s, false)
629                                 },
630                                 (true, false) => {
631                                         let (new_s, overflow) = $add(&old_s, const_subslice(&new_sa, $len, new_sa.len()));
632                                         debug_assert!(!overflow);
633                                         (new_s, true)
634                                 },
635                                 (false, false) => $sub_abs(&old_s, const_subslice(&new_sa, $len, new_sa.len())),
636                         };
637
638                         old_r = r;
639                         r = new_r;
640
641                         old_s = s;
642                         old_s_neg = s_neg;
643                         s = new_s;
644                         s_neg = new_s_neg;
645                 }
646
647                 // At this point old_r contains our GCD and old_s our first Bézout's identity coefficient.
648                 if !slice_equal(const_subslice(&old_r, 0, $len - 1), &[0; $len - 1]) || old_r[$len - 1] != 1 {
649                         Err(())
650                 } else {
651                         debug_assert!(slice_greater_than(m, &old_s));
652                         if old_s_neg {
653                                 let (modinv, underflow) = $sub_abs(m, &old_s);
654                                 debug_assert!(!underflow);
655                                 debug_assert!(slice_greater_than(m, &modinv));
656                                 Ok(modinv)
657                         } else {
658                                 Ok(old_s)
659                         }
660                 }
661         }
662 } }
663 #[cfg(fuzzing)]
664 define_mod_inv!(mod_inv_2, 2, div_rem_2, add_2, sub_abs_2, mul_2);
665 define_mod_inv!(mod_inv_4, 4, div_rem_4, add_4, sub_abs_4, mul_4);
666 define_mod_inv!(mod_inv_6, 6, div_rem_6, add_6, sub_abs_6, mul_6);
667 #[cfg(fuzzing)]
668 define_mod_inv!(mod_inv_8, 8, div_rem_8, add_8, sub_abs_8, mul_8);
669
670 impl U4096 {
671         /// Constructs a new [`U4096`] from a variable number of big-endian bytes.
672         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U4096, ()> {
673                 if bytes.len() > 4096/8 { return Err(()); }
674                 let u64s = (bytes.len() + 7) / 8;
675                 let mut res = [0; WORD_COUNT_4096];
676                 for i in 0..u64s {
677                         let mut b = [0; 8];
678                         let pos = (u64s - i) * 8;
679                         let start = bytes.len().saturating_sub(pos);
680                         let end = bytes.len() + 8 - pos;
681                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
682                         res[i + WORD_COUNT_4096 - u64s] = u64::from_be_bytes(b);
683                 }
684                 Ok(U4096(res))
685         }
686
687         /// Naively multiplies `self` * `b` mod `m`, returning a new [`U4096`].
688         ///
689         /// Fails iff m is 0 or self or b are greater than m.
690         #[cfg(debug_assertions)]
691         fn mulmod_naive(&self, b: &U4096, m: &U4096) -> Result<U4096, ()> {
692                 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
693                 if self > m || b > m { return Err(()); }
694
695                 let mul = mul_64(&self.0, &b.0);
696
697                 let mut m_zeros = [0; 128];
698                 m_zeros[WORD_COUNT_4096..].copy_from_slice(&m.0);
699                 let (_, rem) = div_rem_128(&mul, &m_zeros)?;
700                 let mut res = [0; WORD_COUNT_4096];
701                 debug_assert_eq!(&rem[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
702                 res.copy_from_slice(&rem[WORD_COUNT_4096..]);
703                 Ok(U4096(res))
704         }
705
706         /// Calculates `self` ^ `exp` mod `m`, returning a new [`U4096`].
707         ///
708         /// Fails iff m is 0, even, or self or b are greater than m.
709         pub(super) fn expmod_odd_mod(&self, mut exp: u32, m: &U4096) -> Result<U4096, ()> {
710                 #![allow(non_camel_case_types)]
711
712                 if m.0 == [0; WORD_COUNT_4096] { return Err(()); }
713                 if m.0[WORD_COUNT_4096 - 1] & 1 == 0 { return Err(()); }
714                 if self > m { return Err(()); }
715
716                 let mut t = [0; WORD_COUNT_4096];
717                 if &m.0[..WORD_COUNT_4096 - 1] == &[0; WORD_COUNT_4096 - 1] && m.0[WORD_COUNT_4096 - 1] == 1 {
718                         return Ok(U4096(t));
719                 }
720                 t[WORD_COUNT_4096 - 1] = 1;
721                 if exp == 0 { return Ok(U4096(t)); }
722
723                 // Because m is not even, using 2^4096 as the Montgomery R value is always safe - it is
724                 // guaranteed to be co-prime with any non-even integer.
725
726                 type mul_ty = fn(&[u64], &[u64]) -> [u64; WORD_COUNT_4096 * 2];
727                 type sqr_ty = fn(&[u64]) -> [u64; WORD_COUNT_4096 * 2];
728                 type add_double_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool);
729                 type sub_ty = fn(&[u64], &[u64]) -> ([u64; WORD_COUNT_4096], bool);
730                 let (word_count, log_bits, mul, sqr, add_double, sub) =
731                         if m.0[..WORD_COUNT_4096 / 2] == [0; WORD_COUNT_4096 / 2] {
732                                 if m.0[..WORD_COUNT_4096 * 3 / 4] == [0; WORD_COUNT_4096 * 3 / 4] {
733                                         fn mul_16_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
734                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
735                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
736                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
737                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
738                                                 let mut res = [0; WORD_COUNT_4096 * 2];
739                                                 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
740                                                         &mul_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]));
741                                                 res
742                                         }
743                                         fn sqr_16_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
744                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
745                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
746                                                 let mut res = [0; WORD_COUNT_4096 * 2];
747                                                 res[WORD_COUNT_4096 + WORD_COUNT_4096 / 2..].copy_from_slice(
748                                                         &sqr_16(&a[WORD_COUNT_4096 * 3 / 4..]));
749                                                 res
750                                         }
751                                         fn add_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
752                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
753                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
754                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
755                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 2], &[0; WORD_COUNT_4096 * 3 / 2]);
756                                                 let (add, overflow) = add_32(&a[WORD_COUNT_4096 * 3 / 2..], &b[WORD_COUNT_4096 * 3 / 2..]);
757                                                 let mut res = [0; WORD_COUNT_4096 * 2];
758                                                 res[WORD_COUNT_4096 * 3 / 2..].copy_from_slice(&add);
759                                                 (res, overflow)
760                                         }
761                                         fn sub_16_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
762                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
763                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
764                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
765                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 * 3 / 4], &[0; WORD_COUNT_4096 * 3 / 4]);
766                                                 let (sub, underflow) = sub_16(&a[WORD_COUNT_4096 * 3 / 4..], &b[WORD_COUNT_4096 * 3 / 4..]);
767                                                 let mut res = [0; WORD_COUNT_4096];
768                                                 res[WORD_COUNT_4096 * 3 / 4..].copy_from_slice(&sub);
769                                                 (res, underflow)
770                                         }
771                                         (16, 10, mul_16_subarr as mul_ty, sqr_16_subarr as sqr_ty, add_32_subarr as add_double_ty, sub_16_subarr as sub_ty)
772                                 } else {
773                                         fn mul_32_subarr(a: &[u64], b: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
774                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
775                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
776                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
777                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
778                                                 let mut res = [0; WORD_COUNT_4096 * 2];
779                                                 res[WORD_COUNT_4096..].copy_from_slice(
780                                                         &mul_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]));
781                                                 res
782                                         }
783                                         fn sqr_32_subarr(a: &[u64]) -> [u64; WORD_COUNT_4096 * 2] {
784                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
785                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
786                                                 let mut res = [0; WORD_COUNT_4096 * 2];
787                                                 res[WORD_COUNT_4096..].copy_from_slice(
788                                                         &sqr_32(&a[WORD_COUNT_4096 / 2..]));
789                                                 res
790                                         }
791                                         fn add_64_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096 * 2], bool) {
792                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096 * 2);
793                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096 * 2);
794                                                 debug_assert_eq!(&a[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
795                                                 debug_assert_eq!(&b[..WORD_COUNT_4096], &[0; WORD_COUNT_4096]);
796                                                 let (add, overflow) = add_64(&a[WORD_COUNT_4096..], &b[WORD_COUNT_4096..]);
797                                                 let mut res = [0; WORD_COUNT_4096 * 2];
798                                                 res[WORD_COUNT_4096..].copy_from_slice(&add);
799                                                 (res, overflow)
800                                         }
801                                         fn sub_32_subarr(a: &[u64], b: &[u64]) -> ([u64; WORD_COUNT_4096], bool) {
802                                                 debug_assert_eq!(a.len(), WORD_COUNT_4096);
803                                                 debug_assert_eq!(b.len(), WORD_COUNT_4096);
804                                                 debug_assert_eq!(&a[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
805                                                 debug_assert_eq!(&b[..WORD_COUNT_4096 / 2], &[0; WORD_COUNT_4096 / 2]);
806                                                 let (sub, underflow) = sub_32(&a[WORD_COUNT_4096 / 2..], &b[WORD_COUNT_4096 / 2..]);
807                                                 let mut res = [0; WORD_COUNT_4096];
808                                                 res[WORD_COUNT_4096 / 2..].copy_from_slice(&sub);
809                                                 (res, underflow)
810                                         }
811                                         (32, 11, mul_32_subarr as mul_ty, sqr_32_subarr as sqr_ty, add_64_subarr as add_double_ty, sub_32_subarr as sub_ty)
812                                 }
813                         } else {
814                                 (64, 12, mul_64 as mul_ty, sqr_64 as sqr_ty, add_128 as add_double_ty, sub_64 as sub_ty)
815                         };
816
817                 let mut r = [0; WORD_COUNT_4096 * 2];
818                 r[WORD_COUNT_4096 * 2 - word_count - 1] = 1;
819
820                 let mut m_inv_pos = [0; WORD_COUNT_4096];
821                 m_inv_pos[WORD_COUNT_4096 - 1] = 1;
822                 let mut two = [0; WORD_COUNT_4096];
823                 two[WORD_COUNT_4096 - 1] = 2;
824                 for _ in 0..log_bits {
825                         let mut m_m_inv = mul(&m_inv_pos, &m.0);
826                         m_m_inv[..WORD_COUNT_4096 * 2 - word_count].fill(0);
827                         let m_inv = mul(&sub(&two, &m_m_inv[WORD_COUNT_4096..]).0, &m_inv_pos);
828                         m_inv_pos[WORD_COUNT_4096 - word_count..].copy_from_slice(&m_inv[WORD_COUNT_4096 * 2 - word_count..]);
829                 }
830                 m_inv_pos[..WORD_COUNT_4096 - word_count].fill(0);
831
832                 // We want the negative modular inverse of m mod R, so subtract m_inv from R.
833                 let mut m_inv = m_inv_pos;
834                 negate!(m_inv);
835                 m_inv[..WORD_COUNT_4096 - word_count].fill(0);
836                 debug_assert_eq!(&mul(&m_inv, &m.0)[WORD_COUNT_4096 * 2 - word_count..],
837                         // R - 1 == -1 % R
838                         &[0xffff_ffff_ffff_ffff; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..]);
839
840                 debug_assert_eq!(&m_inv[..WORD_COUNT_4096 - word_count], &[0; WORD_COUNT_4096][..WORD_COUNT_4096 - word_count]);
841
842                 let mont_reduction = |mu: [u64; WORD_COUNT_4096 * 2]| -> [u64; WORD_COUNT_4096] {
843                         debug_assert_eq!(&mu[..WORD_COUNT_4096 * 2 - word_count * 2],
844                                 &[0; WORD_COUNT_4096 * 2][..WORD_COUNT_4096 * 2 - word_count * 2]);
845                         let mut mu_mod_r = [0; WORD_COUNT_4096];
846                         mu_mod_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&mu[WORD_COUNT_4096 * 2 - word_count..]);
847                         let mut v = mul(&mu_mod_r, &m_inv);
848                         v[..WORD_COUNT_4096 * 2 - word_count].fill(0); // mod R
849                         let t0 = mul(&v[WORD_COUNT_4096..], &m.0);
850                         let (t1, t1_extra_bit) = add_double(&t0, &mu);
851                         let mut t1_on_r = [0; WORD_COUNT_4096];
852                         debug_assert_eq!(&t1[WORD_COUNT_4096 * 2 - word_count..], &[0; WORD_COUNT_4096][WORD_COUNT_4096 - word_count..],
853                                 "t1 should be divisible by r");
854                         t1_on_r[WORD_COUNT_4096 - word_count..].copy_from_slice(&t1[WORD_COUNT_4096 * 2 - word_count * 2..WORD_COUNT_4096 * 2 - word_count]);
855                         if t1_extra_bit || t1_on_r >= m.0 {
856                                 let underflow;
857                                 (t1_on_r, underflow) = sub(&t1_on_r, &m.0);
858                                 debug_assert_eq!(t1_extra_bit, underflow);
859                         }
860                         t1_on_r
861                 };
862
863                 // Calculate R^2 mod m as ((2^DOUBLES * R) mod m)^(log_bits - LOG2_DOUBLES) mod R
864                 let mut r_minus_one = [0xffff_ffff_ffff_ffffu64; WORD_COUNT_4096];
865                 r_minus_one[..WORD_COUNT_4096 - word_count].fill(0);
866                 // While we do a full div here, in general R should be less than 2x m (assuming the RSA
867                 // modulus used its full bit range and is 1024, 2048, or 4096 bits), so it should be cheap.
868                 // In cases with a nonstandard RSA modulus we may end up being pretty slow here, but we'll
869                 // survive.
870                 // If we ever find a problem with this we should reduce R to be tigher on m, as we're
871                 // wasting extra bits of calculation if R is too far from m.
872                 let (_, mut r_mod_m) = debug_unwrap!(div_rem_64(&r_minus_one, &m.0));
873                 let r_mod_m_overflow = add_u64!(r_mod_m, 1);
874                 if r_mod_m_overflow || r_mod_m >= m.0 {
875                         (r_mod_m, _) = sub_64(&r_mod_m, &m.0);
876                 }
877
878                 let mut r2_mod_m: [u64; 64] = r_mod_m;
879                 const DOUBLES: usize = 32;
880                 const LOG2_DOUBLES: usize = 5;
881
882                 for _ in 0..DOUBLES {
883                         let overflow = double!(r2_mod_m);
884                         if overflow || r2_mod_m > m.0 {
885                                 (r2_mod_m, _) = sub_64(&r2_mod_m, &m.0);
886                         }
887                 }
888                 for _ in 0..log_bits - LOG2_DOUBLES {
889                         r2_mod_m = mont_reduction(sqr(&r2_mod_m));
890                 }
891                 // Clear excess high bits
892                 for (m_limb, r2_limb) in m.0.iter().zip(r2_mod_m.iter_mut()) {
893                         let clear_bits = m_limb.leading_zeros();
894                         if clear_bits == 0 { break; }
895                         *r2_limb &= !(0xffff_ffff_ffff_ffffu64 << (64 - clear_bits));
896                         if *m_limb != 0 { break; }
897                 }
898                 debug_assert!(r2_mod_m < m.0);
899
900                 // Calculate t * R and a * R as mont multiplications by R^2 mod m
901                 let mut tr = mont_reduction(mul(&r2_mod_m, &t));
902                 let mut ar = mont_reduction(mul(&r2_mod_m, &self.0));
903
904                 #[cfg(debug_assertions)] {
905                         debug_assert_eq!(r2_mod_m, U4096(r_mod_m).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
906                         debug_assert_eq!(&tr, &U4096(t).mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
907                         debug_assert_eq!(&ar, &self.mulmod_naive(&U4096(r_mod_m), &m).unwrap().0);
908                 }
909
910                 while exp != 1 {
911                         if exp % 2 == 1 {
912                                 tr = mont_reduction(mul(&tr, &ar));
913                                 exp -= 1;
914                         }
915                         ar = mont_reduction(sqr(&ar));
916                         exp /= 2;
917                 }
918                 ar = mont_reduction(mul(&ar, &tr));
919                 let mut resr = [0; WORD_COUNT_4096 * 2];
920                 resr[WORD_COUNT_4096..].copy_from_slice(&ar);
921                 Ok(U4096(mont_reduction(resr)))
922         }
923 }
924
925 const fn u64_from_bytes_a_panicking(b: &[u8]) -> u64 {
926         match b {
927                 [a, b, c, d, e, f, g, h, ..] => {
928                         ((*a as u64) << 8*7) |
929                         ((*b as u64) << 8*6) |
930                         ((*c as u64) << 8*5) |
931                         ((*d as u64) << 8*4) |
932                         ((*e as u64) << 8*3) |
933                         ((*f as u64) << 8*2) |
934                         ((*g as u64) << 8*1) |
935                         ((*h as u64) << 8*0)
936                 },
937                 _ => panic!(),
938         }
939 }
940
941 const fn u64_from_bytes_b_panicking(b: &[u8]) -> u64 {
942         match b {
943                 [_, _, _, _, _, _, _, _,
944                  a, b, c, d, e, f, g, h, ..] => {
945                         ((*a as u64) << 8*7) |
946                         ((*b as u64) << 8*6) |
947                         ((*c as u64) << 8*5) |
948                         ((*d as u64) << 8*4) |
949                         ((*e as u64) << 8*3) |
950                         ((*f as u64) << 8*2) |
951                         ((*g as u64) << 8*1) |
952                         ((*h as u64) << 8*0)
953                 },
954                 _ => panic!(),
955         }
956 }
957
958 const fn u64_from_bytes_c_panicking(b: &[u8]) -> u64 {
959         match b {
960                 [_, _, _, _, _, _, _, _,
961                  _, _, _, _, _, _, _, _,
962                  a, b, c, d, e, f, g, h, ..] => {
963                         ((*a as u64) << 8*7) |
964                         ((*b as u64) << 8*6) |
965                         ((*c as u64) << 8*5) |
966                         ((*d as u64) << 8*4) |
967                         ((*e as u64) << 8*3) |
968                         ((*f as u64) << 8*2) |
969                         ((*g as u64) << 8*1) |
970                         ((*h as u64) << 8*0)
971                 },
972                 _ => panic!(),
973         }
974 }
975
976 const fn u64_from_bytes_d_panicking(b: &[u8]) -> u64 {
977         match b {
978                 [_, _, _, _, _, _, _, _,
979                  _, _, _, _, _, _, _, _,
980                  _, _, _, _, _, _, _, _,
981                  a, b, c, d, e, f, g, h, ..] => {
982                         ((*a as u64) << 8*7) |
983                         ((*b as u64) << 8*6) |
984                         ((*c as u64) << 8*5) |
985                         ((*d as u64) << 8*4) |
986                         ((*e as u64) << 8*3) |
987                         ((*f as u64) << 8*2) |
988                         ((*g as u64) << 8*1) |
989                         ((*h as u64) << 8*0)
990                 },
991                 _ => panic!(),
992         }
993 }
994
995 const fn u64_from_bytes_e_panicking(b: &[u8]) -> u64 {
996         match b {
997                 [_, _, _, _, _, _, _, _,
998                  _, _, _, _, _, _, _, _,
999                  _, _, _, _, _, _, _, _,
1000                  _, _, _, _, _, _, _, _,
1001                  a, b, c, d, e, f, g, h, ..] => {
1002                         ((*a as u64) << 8*7) |
1003                         ((*b as u64) << 8*6) |
1004                         ((*c as u64) << 8*5) |
1005                         ((*d as u64) << 8*4) |
1006                         ((*e as u64) << 8*3) |
1007                         ((*f as u64) << 8*2) |
1008                         ((*g as u64) << 8*1) |
1009                         ((*h as u64) << 8*0)
1010                 },
1011                 _ => panic!(),
1012         }
1013 }
1014
1015 const fn u64_from_bytes_f_panicking(b: &[u8]) -> u64 {
1016         match b {
1017                 [_, _, _, _, _, _, _, _,
1018                  _, _, _, _, _, _, _, _,
1019                  _, _, _, _, _, _, _, _,
1020                  _, _, _, _, _, _, _, _,
1021                  _, _, _, _, _, _, _, _,
1022                  a, b, c, d, e, f, g, h, ..] => {
1023                         ((*a as u64) << 8*7) |
1024                         ((*b as u64) << 8*6) |
1025                         ((*c as u64) << 8*5) |
1026                         ((*d as u64) << 8*4) |
1027                         ((*e as u64) << 8*3) |
1028                         ((*f as u64) << 8*2) |
1029                         ((*g as u64) << 8*1) |
1030                         ((*h as u64) << 8*0)
1031                 },
1032                 _ => panic!(),
1033         }
1034 }
1035
1036 impl U256 {
1037         /// Constructs a new [`U256`] from a variable number of big-endian bytes.
1038         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U256, ()> {
1039                 if bytes.len() > 256/8 { return Err(()); }
1040                 let u64s = (bytes.len() + 7) / 8;
1041                 let mut res = [0; WORD_COUNT_256];
1042                 for i in 0..u64s {
1043                         let mut b = [0; 8];
1044                         let pos = (u64s - i) * 8;
1045                         let start = bytes.len().saturating_sub(pos);
1046                         let end = bytes.len() + 8 - pos;
1047                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
1048                         res[i + WORD_COUNT_256 - u64s] = u64::from_be_bytes(b);
1049                 }
1050                 Ok(U256(res))
1051         }
1052
1053         /// Constructs a new [`U256`] from a fixed number of big-endian bytes.
1054         pub(super) const fn from_32_be_bytes_panicking(bytes: &[u8; 32]) -> U256 {
1055                 let res = [
1056                         u64_from_bytes_a_panicking(bytes),
1057                         u64_from_bytes_b_panicking(bytes),
1058                         u64_from_bytes_c_panicking(bytes),
1059                         u64_from_bytes_d_panicking(bytes),
1060                 ];
1061                 U256(res)
1062         }
1063
1064         pub(super) const fn zero() -> U256 { U256([0, 0, 0, 0]) }
1065         pub(super) const fn one() -> U256 { U256([0, 0, 0, 1]) }
1066         pub(super) const fn three() -> U256 { U256([0, 0, 0, 3]) }
1067 }
1068
1069 impl<M: PrimeModulus<U256>> U256Mod<M> {
1070         const fn mont_reduction(mu: [u64; 8]) -> Self {
1071                 #[cfg(debug_assertions)] {
1072                         // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
1073                         // should be able to do it at compile time alone.
1074                         let minus_one_mod_r = mul_4(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1075                         assert!(slice_equal(const_subslice(&minus_one_mod_r, 4, 8), &[0xffff_ffff_ffff_ffff; 4]));
1076                 }
1077
1078                 #[cfg(debug_assertions)] {
1079                         // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
1080                         // should be able to do it at compile time alone.
1081                         let r_minus_one = [0xffff_ffff_ffff_ffff; 4];
1082                         let (mut r_mod_prime, _) = sub_4(&r_minus_one, &M::PRIME.0);
1083                         add_u64!(r_mod_prime, 1);
1084                         let r_squared = sqr_4(&r_mod_prime);
1085                         let mut prime_extended = [0; 8];
1086                         let prime = M::PRIME.0;
1087                         copy_from_slice!(prime_extended, 4, 8, prime);
1088                         let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_8(&r_squared, &prime_extended) { v } else { panic!() };
1089                         assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
1090                         assert!(slice_equal(const_subslice(&r_squared_mod_prime, 4, 8), &M::R_SQUARED_MOD_PRIME.0));
1091                 }
1092
1093                 let mu_mod_r = const_subslice(&mu, 4, 8);
1094                 let mut v = mul_4(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1095                 const ZEROS: &[u64; 4] = &[0; 4];
1096                 copy_from_slice!(v, 0, 4, ZEROS); // mod R
1097                 let t0 = mul_4(const_subslice(&v, 4, 8), &M::PRIME.0);
1098                 let (t1, t1_extra_bit) = add_8(&t0, &mu);
1099                 let t1_on_r = const_subslice(&t1, 0, 4);
1100                 let mut res = [0; 4];
1101                 if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
1102                         let underflow;
1103                         (res, underflow) = sub_4(&t1_on_r, &M::PRIME.0);
1104                         debug_assert!(t1_extra_bit == underflow);
1105                 } else {
1106                         copy_from_slice!(res, 0, 4, t1_on_r);
1107                 }
1108                 Self(U256(res), PhantomData)
1109         }
1110
1111         pub(super) const fn from_u256_panicking(v: U256) -> Self {
1112                 assert!(v.0[0] <= M::PRIME.0[0]);
1113                 if v.0[0] == M::PRIME.0[0] {
1114                         assert!(v.0[1] <= M::PRIME.0[1]);
1115                         if v.0[1] == M::PRIME.0[1] {
1116                                 assert!(v.0[2] <= M::PRIME.0[2]);
1117                                 if v.0[2] == M::PRIME.0[2] {
1118                                         assert!(v.0[3] < M::PRIME.0[3]);
1119                                 }
1120                         }
1121                 }
1122                 assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0 || M::PRIME.0[3] != 0);
1123                 Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1124         }
1125
1126         pub(super) fn from_u256(mut v: U256) -> Self {
1127                 debug_assert!(M::PRIME.0 != [0; 4]);
1128                 debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
1129                 while v >= M::PRIME {
1130                         let (new_v, spurious_underflow) = sub_4(&v.0, &M::PRIME.0);
1131                         debug_assert!(!spurious_underflow);
1132                         v = U256(new_v);
1133                 }
1134                 Self::mont_reduction(mul_4(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1135         }
1136
1137         pub(super) fn from_modinv_of(v: U256) -> Result<Self, ()> {
1138                 Ok(Self::from_u256(U256(mod_inv_4(&v.0, &M::PRIME.0)?)))
1139         }
1140
1141         /// Multiplies `self` * `b` mod `m`.
1142         ///
1143         /// Panics if `self`'s modulus is not equal to `b`'s
1144         pub(super) fn mul(&self, b: &Self) -> Self {
1145                 Self::mont_reduction(mul_4(&self.0.0, &b.0.0))
1146         }
1147
1148         /// Doubles `self` mod `m`.
1149         pub(super) fn double(&self) -> Self {
1150                 let mut res = self.0.0;
1151                 let overflow = double!(res);
1152                 if overflow || !slice_greater_than(&M::PRIME.0, &res) {
1153                         let underflow;
1154                         (res, underflow) = sub_4(&res, &M::PRIME.0);
1155                         debug_assert_eq!(overflow, underflow);
1156                 }
1157                 Self(U256(res), PhantomData)
1158         }
1159
1160         /// Multiplies `self` by 3 mod `m`.
1161         pub(super) fn times_three(&self) -> Self {
1162                 // TODO: Optimize this a lot
1163                 self.mul(&U256Mod::from_u256(U256::three()))
1164         }
1165
1166         /// Multiplies `self` by 4 mod `m`.
1167         pub(super) fn times_four(&self) -> Self {
1168                 // TODO: Optimize this somewhat?
1169                 self.double().double()
1170         }
1171
1172         /// Multiplies `self` by 8 mod `m`.
1173         pub(super) fn times_eight(&self) -> Self {
1174                 // TODO: Optimize this somewhat?
1175                 self.double().double().double()
1176         }
1177
1178         /// Multiplies `self` by 8 mod `m`.
1179         pub(super) fn square(&self) -> Self {
1180                 Self::mont_reduction(sqr_4(&self.0.0))
1181         }
1182
1183         /// Subtracts `b` from `self` % `m`.
1184         pub(super) fn sub(&self, b: &Self) -> Self {
1185                 let (mut val, underflow) = sub_4(&self.0.0, &b.0.0);
1186                 if underflow {
1187                         let overflow;
1188                         (val, overflow) = add_4(&val, &M::PRIME.0);
1189                         debug_assert_eq!(overflow, underflow);
1190                 }
1191                 Self(U256(val), PhantomData)
1192         }
1193
1194         /// Adds `b` to `self` % `m`.
1195         pub(super) fn add(&self, b: &Self) -> Self {
1196                 let (mut val, overflow) = add_4(&self.0.0, &b.0.0);
1197                 if overflow || !slice_greater_than(&M::PRIME.0, &val) {
1198                         let underflow;
1199                         (val, underflow) = sub_4(&val, &M::PRIME.0);
1200                         debug_assert_eq!(overflow, underflow);
1201                 }
1202                 Self(U256(val), PhantomData)
1203         }
1204
1205         /// Returns the underlying [`U256`].
1206         pub(super) fn into_u256(self) -> U256 {
1207                 let mut expanded_self = [0; 8];
1208                 expanded_self[4..].copy_from_slice(&self.0.0);
1209                 Self::mont_reduction(expanded_self).0
1210         }
1211 }
1212
1213 impl U384 {
1214         /// Constructs a new [`U384`] from a variable number of big-endian bytes.
1215         pub(super) fn from_be_bytes(bytes: &[u8]) -> Result<U384, ()> {
1216                 if bytes.len() > 384/8 { return Err(()); }
1217                 let u64s = (bytes.len() + 7) / 8;
1218                 let mut res = [0; WORD_COUNT_384];
1219                 for i in 0..u64s {
1220                         let mut b = [0; 8];
1221                         let pos = (u64s - i) * 8;
1222                         let start = bytes.len().saturating_sub(pos);
1223                         let end = bytes.len() + 8 - pos;
1224                         b[8 + start - end..].copy_from_slice(&bytes[start..end]);
1225                         res[i + WORD_COUNT_384 - u64s] = u64::from_be_bytes(b);
1226                 }
1227                 Ok(U384(res))
1228         }
1229
1230         /// Constructs a new [`U384`] from a fixed number of big-endian bytes.
1231         pub(super) const fn from_48_be_bytes_panicking(bytes: &[u8; 48]) -> U384 {
1232                 let res = [
1233                         u64_from_bytes_a_panicking(bytes),
1234                         u64_from_bytes_b_panicking(bytes),
1235                         u64_from_bytes_c_panicking(bytes),
1236                         u64_from_bytes_d_panicking(bytes),
1237                         u64_from_bytes_e_panicking(bytes),
1238                         u64_from_bytes_f_panicking(bytes),
1239                 ];
1240                 U384(res)
1241         }
1242
1243         pub(super) const fn zero() -> U384 { U384([0, 0, 0, 0, 0, 0]) }
1244         pub(super) const fn one() -> U384 { U384([0, 0, 0, 0, 0, 1]) }
1245         pub(super) const fn three() -> U384 { U384([0, 0, 0, 0, 0, 3]) }
1246 }
1247
1248 impl<M: PrimeModulus<U384>> U384Mod<M> {
1249         const fn mont_reduction(mu: [u64; 12]) -> Self {
1250                 #[cfg(debug_assertions)] {
1251                         // Check NEGATIVE_PRIME_INV_MOD_R is correct. Since this is all const, the compiler
1252                         // should be able to do it at compile time alone.
1253                         let minus_one_mod_r = mul_6(&M::PRIME.0, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1254                         assert!(slice_equal(const_subslice(&minus_one_mod_r, 6, 12), &[0xffff_ffff_ffff_ffff; 6]));
1255                 }
1256
1257                 #[cfg(debug_assertions)] {
1258                         // Check R_SQUARED_MOD_PRIME is correct. Since this is all const, the compiler
1259                         // should be able to do it at compile time alone.
1260                         let r_minus_one = [0xffff_ffff_ffff_ffff; 6];
1261                         let (mut r_mod_prime, _) = sub_6(&r_minus_one, &M::PRIME.0);
1262                         add_u64!(r_mod_prime, 1);
1263                         let r_squared = sqr_6(&r_mod_prime);
1264                         let mut prime_extended = [0; 12];
1265                         let prime = M::PRIME.0;
1266                         copy_from_slice!(prime_extended, 6, 12, prime);
1267                         let (_, r_squared_mod_prime) = if let Ok(v) = div_rem_12(&r_squared, &prime_extended) { v } else { panic!() };
1268                         assert!(slice_greater_than(&prime_extended, &r_squared_mod_prime));
1269                         assert!(slice_equal(const_subslice(&r_squared_mod_prime, 6, 12), &M::R_SQUARED_MOD_PRIME.0));
1270                 }
1271
1272                 let mu_mod_r = const_subslice(&mu, 6, 12);
1273                 let mut v = mul_6(&mu_mod_r, &M::NEGATIVE_PRIME_INV_MOD_R.0);
1274                 const ZEROS: &[u64; 6] = &[0; 6];
1275                 copy_from_slice!(v, 0, 6, ZEROS); // mod R
1276                 let t0 = mul_6(const_subslice(&v, 6, 12), &M::PRIME.0);
1277                 let (t1, t1_extra_bit) = add_12(&t0, &mu);
1278                 let t1_on_r = const_subslice(&t1, 0, 6);
1279                 let mut res = [0; 6];
1280                 if t1_extra_bit || slice_greater_than(&t1_on_r, &M::PRIME.0) {
1281                         let underflow;
1282                         (res, underflow) = sub_6(&t1_on_r, &M::PRIME.0);
1283                         debug_assert!(t1_extra_bit == underflow);
1284                 } else {
1285                         copy_from_slice!(res, 0, 6, t1_on_r);
1286                 }
1287                 Self(U384(res), PhantomData)
1288         }
1289
1290         pub(super) const fn from_u384_panicking(v: U384) -> Self {
1291                 assert!(v.0[0] <= M::PRIME.0[0]);
1292                 if v.0[0] == M::PRIME.0[0] {
1293                         assert!(v.0[1] <= M::PRIME.0[1]);
1294                         if v.0[1] == M::PRIME.0[1] {
1295                                 assert!(v.0[2] <= M::PRIME.0[2]);
1296                                 if v.0[2] == M::PRIME.0[2] {
1297                                         assert!(v.0[3] <= M::PRIME.0[3]);
1298                                         if v.0[3] == M::PRIME.0[3] {
1299                                                 assert!(v.0[4] <= M::PRIME.0[4]);
1300                                                 if v.0[4] == M::PRIME.0[4] {
1301                                                         assert!(v.0[5] < M::PRIME.0[5]);
1302                                                 }
1303                                         }
1304                                 }
1305                         }
1306                 }
1307                 assert!(M::PRIME.0[0] != 0 || M::PRIME.0[1] != 0 || M::PRIME.0[2] != 0
1308                         || M::PRIME.0[3] != 0|| M::PRIME.0[4] != 0|| M::PRIME.0[5] != 0);
1309                 Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1310         }
1311
1312         pub(super) fn from_u384(mut v: U384) -> Self {
1313                 debug_assert!(M::PRIME.0 != [0; 6]);
1314                 debug_assert!(M::PRIME.0[0] > (1 << 63), "PRIME should have the top bit set");
1315                 while v >= M::PRIME {
1316                         let (new_v, spurious_underflow) = sub_6(&v.0, &M::PRIME.0);
1317                         debug_assert!(!spurious_underflow);
1318                         v = U384(new_v);
1319                 }
1320                 Self::mont_reduction(mul_6(&M::R_SQUARED_MOD_PRIME.0, &v.0))
1321         }
1322
1323         pub(super) fn from_modinv_of(v: U384) -> Result<Self, ()> {
1324                 Ok(Self::from_u384(U384(mod_inv_6(&v.0, &M::PRIME.0)?)))
1325         }
1326
1327         /// Multiplies `self` * `b` mod `m`.
1328         ///
1329         /// Panics if `self`'s modulus is not equal to `b`'s
1330         pub(super) fn mul(&self, b: &Self) -> Self {
1331                 Self::mont_reduction(mul_6(&self.0.0, &b.0.0))
1332         }
1333
1334         /// Doubles `self` mod `m`.
1335         pub(super) fn double(&self) -> Self {
1336                 let mut res = self.0.0;
1337                 let overflow = double!(res);
1338                 if overflow || !slice_greater_than(&M::PRIME.0, &res) {
1339                         let underflow;
1340                         (res, underflow) = sub_6(&res, &M::PRIME.0);
1341                         debug_assert_eq!(overflow, underflow);
1342                 }
1343                 Self(U384(res), PhantomData)
1344         }
1345
1346         /// Multiplies `self` by 3 mod `m`.
1347         pub(super) fn times_three(&self) -> Self {
1348                 // TODO: Optimize this a lot
1349                 self.mul(&U384Mod::from_u384(U384::three()))
1350         }
1351
1352         /// Multiplies `self` by 4 mod `m`.
1353         pub(super) fn times_four(&self) -> Self {
1354                 // TODO: Optimize this somewhat?
1355                 self.double().double()
1356         }
1357
1358         /// Multiplies `self` by 8 mod `m`.
1359         pub(super) fn times_eight(&self) -> Self {
1360                 // TODO: Optimize this somewhat?
1361                 self.double().double().double()
1362         }
1363
1364         /// Multiplies `self` by 8 mod `m`.
1365         pub(super) fn square(&self) -> Self {
1366                 Self::mont_reduction(sqr_6(&self.0.0))
1367         }
1368
1369         /// Subtracts `b` from `self` % `m`.
1370         pub(super) fn sub(&self, b: &Self) -> Self {
1371                 let (mut val, underflow) = sub_6(&self.0.0, &b.0.0);
1372                 if underflow {
1373                         let overflow;
1374                         (val, overflow) = add_6(&val, &M::PRIME.0);
1375                         debug_assert_eq!(overflow, underflow);
1376                 }
1377                 Self(U384(val), PhantomData)
1378         }
1379
1380         /// Adds `b` to `self` % `m`.
1381         pub(super) fn add(&self, b: &Self) -> Self {
1382                 let (mut val, overflow) = add_6(&self.0.0, &b.0.0);
1383                 if overflow || !slice_greater_than(&M::PRIME.0, &val) {
1384                         let underflow;
1385                         (val, underflow) = sub_6(&val, &M::PRIME.0);
1386                         debug_assert_eq!(overflow, underflow);
1387                 }
1388                 Self(U384(val), PhantomData)
1389         }
1390
1391         /// Returns the underlying [`U384`].
1392         pub(super) fn into_u384(self) -> U384 {
1393                 let mut expanded_self = [0; 12];
1394                 expanded_self[6..].copy_from_slice(&self.0.0);
1395                 Self::mont_reduction(expanded_self).0
1396         }
1397 }
1398
1399 #[cfg(fuzzing)]
1400 mod fuzz_moduli {
1401         use super::*;
1402
1403         pub struct P256();
1404         impl PrimeModulus<U256> for P256 {
1405                 const PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
1406                         "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"));
1407                 const R_SQUARED_MOD_PRIME: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
1408                         "00000004fffffffdfffffffffffffffefffffffbffffffff0000000000000003"));
1409                 const NEGATIVE_PRIME_INV_MOD_R: U256 = U256::from_32_be_bytes_panicking(&hex_lit::hex!(
1410                         "ffffffff00000002000000000000000000000001000000000000000000000001"));
1411         }
1412
1413         pub struct P384();
1414         impl PrimeModulus<U384> for P384 {
1415                 const PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
1416                         "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff"));
1417                 const R_SQUARED_MOD_PRIME: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
1418                         "000000000000000000000000000000010000000200000000fffffffe000000000000000200000000fffffffe00000001"));
1419                 const NEGATIVE_PRIME_INV_MOD_R: U384 = U384::from_48_be_bytes_panicking(&hex_lit::hex!(
1420                         "00000014000000140000000c00000002fffffffcfffffffafffffffbfffffffe00000000000000010000000100000001"));
1421         }
1422 }
1423
1424 #[cfg(fuzzing)]
1425 extern crate ibig;
1426 #[cfg(fuzzing)]
1427 /// Read some bytes and use them to test bigint math by comparing results against the `ibig` crate.
1428 pub fn fuzz_math(input: &[u8]) {
1429         if input.len() < 32 || input.len() % 16 != 0 { return; }
1430         let split = core::cmp::min(input.len() / 2, 512);
1431         let (a, b) = input.split_at(core::cmp::min(input.len() / 2, 512));
1432         let b = &b[..split];
1433
1434         let ai = ibig::UBig::from_be_bytes(&a);
1435         let bi = ibig::UBig::from_be_bytes(&b);
1436
1437         let mut a_u64s = Vec::with_capacity(split / 8);
1438         for chunk in a.chunks(8) {
1439                 a_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
1440         }
1441         let mut b_u64s = Vec::with_capacity(split / 8);
1442         for chunk in b.chunks(8) {
1443                 b_u64s.push(u64::from_be_bytes(chunk.try_into().unwrap()));
1444         }
1445
1446         macro_rules! test { ($mul: ident, $sqr: ident, $add: ident, $sub: ident, $div_rem: ident, $mod_inv: ident) => {
1447                 let res = $mul(&a_u64s, &b_u64s);
1448                 let mut res_bytes = Vec::with_capacity(input.len() / 2);
1449                 for i in res {
1450                         res_bytes.extend_from_slice(&i.to_be_bytes());
1451                 }
1452                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() * bi.clone());
1453
1454                 debug_assert_eq!($mul(&a_u64s, &a_u64s), $sqr(&a_u64s));
1455                 debug_assert_eq!($mul(&b_u64s, &b_u64s), $sqr(&b_u64s));
1456
1457                 let (res, carry) = $add(&a_u64s, &b_u64s);
1458                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
1459                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
1460                 for i in res {
1461                         res_bytes.extend_from_slice(&i.to_be_bytes());
1462                 }
1463                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + bi.clone());
1464
1465                 let mut add_u64s = a_u64s.clone();
1466                 let carry = add_u64!(add_u64s, 1);
1467                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
1468                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
1469                 for i in &add_u64s {
1470                         res_bytes.extend_from_slice(&i.to_be_bytes());
1471                 }
1472                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), ai.clone() + 1);
1473
1474                 let mut double_u64s = b_u64s.clone();
1475                 let carry = double!(double_u64s);
1476                 let mut res_bytes = Vec::with_capacity(input.len() / 2 + 1);
1477                 if carry { res_bytes.push(1); } else { res_bytes.push(0); }
1478                 for i in &double_u64s {
1479                         res_bytes.extend_from_slice(&i.to_be_bytes());
1480                 }
1481                 assert_eq!(ibig::UBig::from_be_bytes(&res_bytes), bi.clone() * 2);
1482
1483                 let (quot, rem) = if let Ok(res) =
1484                         $div_rem(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
1485                                 res
1486                         } else { return };
1487                 let mut quot_bytes = Vec::with_capacity(input.len() / 2);
1488                 for i in quot {
1489                         quot_bytes.extend_from_slice(&i.to_be_bytes());
1490                 }
1491                 let mut rem_bytes = Vec::with_capacity(input.len() / 2);
1492                 for i in rem {
1493                         rem_bytes.extend_from_slice(&i.to_be_bytes());
1494                 }
1495                 let (quoti, remi) = ibig::ops::DivRem::div_rem(ai.clone(), &bi);
1496                 assert_eq!(ibig::UBig::from_be_bytes(&quot_bytes), quoti);
1497                 assert_eq!(ibig::UBig::from_be_bytes(&rem_bytes), remi);
1498
1499                 if ai != ibig::UBig::from(0u32) { // ibig provides a spurious modular inverse for 0
1500                         let ring = ibig::modular::ModuloRing::new(&bi);
1501                         let ar = ring.from(ai.clone());
1502                         let invi = ar.inverse().map(|i| i.residue());
1503
1504                         if let Ok(modinv) = $mod_inv(&a_u64s[..].try_into().unwrap(), &b_u64s[..].try_into().unwrap()) {
1505                                 let mut modinv_bytes = Vec::with_capacity(input.len() / 2);
1506                                 for i in modinv {
1507                                         modinv_bytes.extend_from_slice(&i.to_be_bytes());
1508                                 }
1509                                 assert_eq!(invi.unwrap(), ibig::UBig::from_be_bytes(&modinv_bytes));
1510                         } else {
1511                                 assert!(invi.is_none());
1512                         }
1513                 }
1514         } }
1515
1516         macro_rules! test_mod { ($amodp: expr, $bmodp: expr, $PRIME: expr, $len: expr, $into: ident, $div_rem_double: ident, $div_rem: ident, $mul: ident, $add: ident, $sub: ident) => {
1517                 // Test the U256/U384Mod wrapper, which operates in Montgomery representation
1518                 let mut p_extended = [0; $len * 2];
1519                 p_extended[$len..].copy_from_slice(&$PRIME);
1520
1521                 let amodp_squared = $div_rem_double(&$mul(&a_u64s, &a_u64s), &p_extended).unwrap().1;
1522                 assert_eq!(&amodp_squared[..$len], &[0; $len]);
1523                 assert_eq!(&$amodp.square().$into().0, &amodp_squared[$len..]);
1524
1525                 let abmodp = $div_rem_double(&$mul(&a_u64s, &b_u64s), &p_extended).unwrap().1;
1526                 assert_eq!(&abmodp[..$len], &[0; $len]);
1527                 assert_eq!(&$amodp.mul(&$bmodp).$into().0, &abmodp[$len..]);
1528
1529                 let (aplusb, aplusb_overflow) = $add(&a_u64s, &b_u64s);
1530                 let mut aplusb_extended = [0; $len * 2];
1531                 aplusb_extended[$len..].copy_from_slice(&aplusb);
1532                 if aplusb_overflow { aplusb_extended[$len - 1] = 1; }
1533                 let aplusbmodp = $div_rem_double(&aplusb_extended, &p_extended).unwrap().1;
1534                 assert_eq!(&aplusbmodp[..$len], &[0; $len]);
1535                 assert_eq!(&$amodp.add(&$bmodp).$into().0, &aplusbmodp[$len..]);
1536
1537                 let (mut aminusb, aminusb_underflow) = $sub(&a_u64s, &b_u64s);
1538                 if aminusb_underflow {
1539                         let mut overflow;
1540                         (aminusb, overflow) = $add(&aminusb, &$PRIME);
1541                         if !overflow {
1542                                 (aminusb, overflow) = $add(&aminusb, &$PRIME);
1543                         }
1544                         assert!(overflow);
1545                 }
1546                 let aminusbmodp = $div_rem(&aminusb, &$PRIME).unwrap().1;
1547                 assert_eq!(&$amodp.sub(&$bmodp).$into().0, &aminusbmodp);
1548         } }
1549
1550         if a_u64s.len() == 2 {
1551                 test!(mul_2, sqr_2, add_2, sub_2, div_rem_2, mod_inv_2);
1552         } else if a_u64s.len() == 4 {
1553                 test!(mul_4, sqr_4, add_4, sub_4, div_rem_4, mod_inv_4);
1554                 let amodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(a_u64s[..].try_into().unwrap()));
1555                 let bmodp = U256Mod::<fuzz_moduli::P256>::from_u256(U256(b_u64s[..].try_into().unwrap()));
1556                 test_mod!(amodp, bmodp, fuzz_moduli::P256::PRIME.0, 4, into_u256, div_rem_8, div_rem_4, mul_4, add_4, sub_4);
1557         } else if a_u64s.len() == 6 {
1558                 test!(mul_6, sqr_6, add_6, sub_6, div_rem_6, mod_inv_6);
1559                 let amodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(a_u64s[..].try_into().unwrap()));
1560                 let bmodp = U384Mod::<fuzz_moduli::P384>::from_u384(U384(b_u64s[..].try_into().unwrap()));
1561                 test_mod!(amodp, bmodp, fuzz_moduli::P384::PRIME.0, 6, into_u384, div_rem_12, div_rem_6, mul_6, add_6, sub_6);
1562         } else if a_u64s.len() == 8 {
1563                 test!(mul_8, sqr_8, add_8, sub_8, div_rem_8, mod_inv_8);
1564         } else if input.len() == 512*2 + 4 {
1565                 let mut e_bytes = [0; 4];
1566                 e_bytes.copy_from_slice(&input[512 * 2..512 * 2 + 4]);
1567                 let e = u32::from_le_bytes(e_bytes);
1568                 let a = U4096::from_be_bytes(&a).unwrap();
1569                 let b = U4096::from_be_bytes(&b).unwrap();
1570
1571                 let res = if let Ok(r) = a.expmod_odd_mod(e, &b) { r } else { return };
1572                 let mut res_bytes = Vec::with_capacity(512);
1573                 for i in res.0 {
1574                         res_bytes.extend_from_slice(&i.to_be_bytes());
1575                 }
1576
1577                 let ring = ibig::modular::ModuloRing::new(&bi);
1578                 let ar = ring.from(ai.clone());
1579                 assert_eq!(ar.pow(&e.into()).residue(), ibig::UBig::from_be_bytes(&res_bytes));
1580         }
1581 }
1582
1583 #[cfg(test)]
1584 mod tests {
1585         use super::*;
1586
1587         fn u64s_to_u128(v: [u64; 2]) -> u128 {
1588                 let mut r = 0;
1589                 r |= v[1] as u128;
1590                 r |= (v[0] as u128) << 64;
1591                 r
1592         }
1593
1594         fn u64s_to_i128(v: [u64; 2]) -> i128 {
1595                 let mut r = 0;
1596                 r |= v[1] as i128;
1597                 r |= (v[0] as i128) << 64;
1598                 r
1599         }
1600
1601         #[test]
1602         fn test_negate() {
1603                 let mut zero = [0u64; 2];
1604                 negate!(zero);
1605                 assert_eq!(zero, [0; 2]);
1606
1607                 let mut one = [0u64, 1u64];
1608                 negate!(one);
1609                 assert_eq!(u64s_to_i128(one), -1);
1610
1611                 let mut minus_one: [u64; 2] = [u64::MAX, u64::MAX];
1612                 negate!(minus_one);
1613                 assert_eq!(minus_one, [0, 1]);
1614         }
1615
1616         #[test]
1617         fn test_double() {
1618                 let mut zero = [0u64; 2];
1619                 assert!(!double!(zero));
1620                 assert_eq!(zero, [0; 2]);
1621
1622                 let mut one = [0u64, 1u64];
1623                 assert!(!double!(one));
1624                 assert_eq!(one, [0, 2]);
1625
1626                 let mut u64_max = [0, u64::MAX];
1627                 assert!(!double!(u64_max));
1628                 assert_eq!(u64_max, [1, u64::MAX - 1]);
1629
1630                 let mut u64_carry_overflow = [0x7fff_ffff_ffff_ffffu64, 0x8000_0000_0000_0000];
1631                 assert!(!double!(u64_carry_overflow));
1632                 assert_eq!(u64_carry_overflow, [u64::MAX, 0]);
1633
1634                 let mut max = [u64::MAX; 4];
1635                 assert!(double!(max));
1636                 assert_eq!(max, [u64::MAX, u64::MAX, u64::MAX, u64::MAX - 1]);
1637         }
1638
1639         #[test]
1640         fn mul_min_simple_tests() {
1641                 let a = [1, 2];
1642                 let b = [3, 4];
1643                 let res = mul_2(&a, &b);
1644                 assert_eq!(res, [0, 3, 10, 8]);
1645
1646                 let a = [0x1bad_cafe_dead_beef, 2424];
1647                 let b = [0x2bad_beef_dead_cafe, 4242];
1648                 let res = mul_2(&a, &b);
1649                 assert_eq!(res, [340296855556511776, 15015369169016130186, 4248480538569992542, 10282608]);
1650
1651                 let a = [0xf6d9_f8eb_8b60_7a6d, 0x4b93_833e_2194_fc2e];
1652                 let b = [0xfdab_0000_6952_8ab4, 0xd302_0000_8282_0000];
1653                 let res = mul_2(&a, &b);
1654                 assert_eq!(res, [17625486516939878681, 18390748118453258282, 2695286104209847530, 1510594524414214144]);
1655
1656                 let a = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
1657                 let b = [0x8b8b_8b8b_8b8b_8b8b, 0x8b8b_8b8b_8b8b_8b8b];
1658                 let res = mul_2(&a, &b);
1659                 assert_eq!(res, [5481115605507762349, 8230042173354675923, 16737530186064798, 15714555036048702841]);
1660
1661                 let a = [0x0000_0000_0000_0020, 0x002d_362c_005b_7753];
1662                 let b = [0x0900_0000_0030_0003, 0xb708_00fe_0000_00cd];
1663                 let res = mul_2(&a, &b);
1664                 assert_eq!(res, [1, 2306290405521702946, 17647397529888728169, 10271802099389861239]);
1665
1666                 let a = [0x0000_0000_7fff_ffff, 0xffff_ffff_0000_0000];
1667                 let b = [0x0000_0800_0000_0000, 0x0000_1000_0000_00e1];
1668                 let res = mul_2(&a, &b);
1669                 assert_eq!(res, [1024, 0, 483183816703, 18446743107341910016]);
1670
1671                 let a = [0xf6d9_f8eb_ebeb_eb6d, 0x4b93_83a0_bb35_0680];
1672                 let b = [0xfd02_b9b9_b9b9_b9b9, 0xb9b9_b9b9_b9b9_b9b9];
1673                 let res = mul_2(&a, &b);
1674                 assert_eq!(res, [17579814114991930107, 15033987447865175985, 488855932380801351, 5453318140933190272]);
1675
1676                 let a = [u64::MAX; 2];
1677                 let b = [u64::MAX; 2];
1678                 let res = mul_2(&a, &b);
1679                 assert_eq!(res, [18446744073709551615, 18446744073709551614, 0, 1]);
1680         }
1681
1682         #[test]
1683         fn test_add_sub() {
1684                 fn test(a: [u64; 2], b: [u64; 2]) {
1685                         let a_int = u64s_to_u128(a);
1686                         let b_int = u64s_to_u128(b);
1687
1688                         let res = add_2(&a, &b);
1689                         assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_add(b_int));
1690
1691                         let res = sub_2(&a, &b);
1692                         assert_eq!((u64s_to_u128(res.0), res.1), a_int.overflowing_sub(b_int));
1693                 }
1694
1695                 test([0; 2], [0; 2]);
1696                 test([0x1bad_cafe_dead_beef, 2424], [0x2bad_cafe_dead_cafe, 4242]);
1697                 test([u64::MAX; 2], [u64::MAX; 2]);
1698                 test([u64::MAX, 0x8000_0000_0000_0000], [0, 0x7fff_ffff_ffff_ffff]);
1699                 test([0, 0x7fff_ffff_ffff_ffff], [u64::MAX, 0x8000_0000_0000_0000]);
1700                 test([u64::MAX, 0], [0, u64::MAX]);
1701                 test([0, u64::MAX], [u64::MAX, 0]);
1702                 test([u64::MAX; 2], [0; 2]);
1703                 test([0; 2], [u64::MAX; 2]);
1704         }
1705
1706         #[test]
1707         fn mul_4_simple_tests() {
1708                 let a = [1; 4];
1709                 let b = [2; 4];
1710                 assert_eq!(mul_4(&a, &b),
1711                         [0, 2, 4, 6, 8, 6, 4, 2]);
1712
1713                 let a = [0x1bad_cafe_dead_beef, 2424, 0x1bad_cafe_dead_beef, 2424];
1714                 let b = [0x2bad_beef_dead_cafe, 4242, 0x2bad_beef_dead_cafe, 4242];
1715                 assert_eq!(mul_4(&a, &b),
1716                         [340296855556511776, 15015369169016130186, 4929074249683016095, 11583994264332991364,
1717                          8837257932696496860, 15015369169036695402, 4248480538569992542, 10282608]);
1718
1719                 let a = [u64::MAX; 4];
1720                 let b = [u64::MAX; 4];
1721                 assert_eq!(mul_4(&a, &b),
1722                         [18446744073709551615, 18446744073709551615, 18446744073709551615,
1723                          18446744073709551614, 0, 0, 0, 1]);
1724         }
1725
1726         #[test]
1727         fn double_simple_tests() {
1728                 let mut a = [0xfff5_b32d_01ff_0000, 0x00e7_e7e7_e7e7_e7e7];
1729                 assert!(double!(a));
1730                 assert_eq!(a, [18440945635998695424, 130551405668716494]);
1731
1732                 let mut a = [u64::MAX, u64::MAX];
1733                 assert!(double!(a));
1734                 assert_eq!(a, [18446744073709551615, 18446744073709551614]);
1735         }
1736 }