99625fdcd922d432bb85fac7c5bf2458e0a8fb18
[dnssec-prover] / src / crypto / ec.rs
1 //! Simple verification of ECDSA signatures over SECP Random curves
2
3 use super::bigint::*;
4
5 pub(super) trait IntMod: Clone + Eq + Sized {
6         type I: Int;
7         fn from_i(v: Self::I) -> Self;
8         fn from_modinv_of(v: Self::I) -> Result<Self, ()>;
9
10         const ZERO: Self;
11         const ONE: Self;
12
13         fn mul(&self, o: &Self) -> Self;
14         fn square(&self) -> Self;
15         fn add(&self, o: &Self) -> Self;
16         fn sub(&self, o: &Self) -> Self;
17         fn double(&self) -> Self;
18         fn times_three(&self) -> Self;
19         fn times_four(&self) -> Self;
20         fn times_eight(&self) -> Self;
21
22         fn into_i(self) -> Self::I;
23 }
24 impl<M: PrimeModulus<U256> + Clone + Eq> IntMod for U256Mod<M> {
25         type I = U256;
26         fn from_i(v: Self::I) -> Self { U256Mod::from_u256(v) }
27         fn from_modinv_of(v: Self::I) -> Result<Self, ()> { U256Mod::from_modinv_of(v) }
28
29         const ZERO: Self = U256Mod::<M>::from_u256_panicking(U256::zero());
30         const ONE: Self = U256Mod::<M>::from_u256_panicking(U256::one());
31
32         fn mul(&self, o: &Self) -> Self { self.mul(o) }
33         fn square(&self) -> Self { self.square() }
34         fn add(&self, o: &Self) -> Self { self.add(o) }
35         fn sub(&self, o: &Self) -> Self { self.sub(o) }
36         fn double(&self) -> Self { self.double() }
37         fn times_three(&self) -> Self { self.times_three() }
38         fn times_four(&self) -> Self { self.times_four() }
39         fn times_eight(&self) -> Self { self.times_eight() }
40
41         fn into_i(self) -> Self::I { self.into_u256() }
42 }
43 impl<M: PrimeModulus<U384> + Clone + Eq> IntMod for U384Mod<M> {
44         type I = U384;
45         fn from_i(v: Self::I) -> Self { U384Mod::from_u384(v) }
46         fn from_modinv_of(v: Self::I) -> Result<Self, ()> { U384Mod::from_modinv_of(v) }
47
48         const ZERO: Self = U384Mod::<M>::from_u384_panicking(U384::zero());
49         const ONE: Self = U384Mod::<M>::from_u384_panicking(U384::one());
50
51         fn mul(&self, o: &Self) -> Self { self.mul(o) }
52         fn square(&self) -> Self { self.square() }
53         fn add(&self, o: &Self) -> Self { self.add(o) }
54         fn sub(&self, o: &Self) -> Self { self.sub(o) }
55         fn double(&self) -> Self { self.double() }
56         fn times_three(&self) -> Self { self.times_three() }
57         fn times_four(&self) -> Self { self.times_four() }
58         fn times_eight(&self) -> Self { self.times_eight() }
59
60         fn into_i(self) -> Self::I { self.into_u384() }
61 }
62
63 pub(super) trait Curve : Copy {
64         type Int: Int;
65
66         // With const generics, both CurveField and ScalarField can be replaced with a single IntMod.
67         type CurveField: IntMod<I = Self::Int>;
68         type ScalarField: IntMod<I = Self::Int>;
69
70         type CurveModulus: PrimeModulus<Self::Int>;
71         type ScalarModulus: PrimeModulus<Self::Int>;
72
73         // Curve parameters y^2 = x^3 + ax + b
74         const A: Self::CurveField;
75         const B: Self::CurveField;
76
77         const G: Point<Self>;
78 }
79
80 #[derive(Clone, PartialEq, Eq)]
81 /// A Point, stored in Jacobian coordinates
82 pub(super) struct Point<C: Curve + ?Sized> {
83         x: C::CurveField,
84         y: C::CurveField,
85         z: C::CurveField,
86 }
87
88 impl<C: Curve + ?Sized> Point<C> {
89         fn check_curve_conditions() {
90                 debug_assert!(C::ScalarModulus::PRIME < C::CurveModulus::PRIME, "N is < P");
91         }
92
93         fn on_curve(x: &C::CurveField, y: &C::CurveField) -> Result<(), ()> {
94                 let x_2 = x.square();
95                 let x_3 = x_2.mul(&x);
96                 let v = x_3.add(&C::A.mul(&x)).add(&C::B);
97
98                 let y_2 = y.square();
99                 if y_2 != v {
100                         Err(())
101                 } else {
102                         Ok(())
103                 }
104         }
105
106         #[cfg(debug_assertions)]
107         fn on_curve_z(x: &C::CurveField, y: &C::CurveField, z: &C::CurveField) -> Result<(), ()> {
108                 // m = 1 / z
109                 // x_norm = x * m^2
110                 // y_norm = y * m^3
111
112                 let m = C::CurveField::from_modinv_of(z.clone().into_i())?;
113                 let m_2 = m.square();
114                 let m_3 = m_2.mul(&m);
115                 let x_norm = x.mul(&m_2);
116                 let y_norm = y.mul(&m_3);
117                 Self::on_curve(&x_norm, &y_norm)
118         }
119
120         #[cfg(test)]
121         fn normalize_x(&self) -> Result<C::CurveField, ()> {
122                 let m = C::CurveField::from_modinv_of(self.z.clone().into_i())?;
123                 Ok(self.x.mul(&m.square()))
124         }
125
126         fn from_xy(x: C::Int, y: C::Int) -> Result<Self, ()> {
127                 Self::check_curve_conditions();
128
129                 let x = C::CurveField::from_i(x);
130                 let y = C::CurveField::from_i(y);
131                 Self::on_curve(&x, &y)?;
132                 Ok(Point { x, y, z: C::CurveField::ONE })
133         }
134
135         pub(super) const fn from_xy_assuming_on_curve(x: C::CurveField, y: C::CurveField) -> Self {
136                 Point { x, y, z: C::CurveField::ONE }
137         }
138
139         /// Checks that `expected_x` is equal to our X affine coordinate (without modular inversion).
140         fn eq_x(&self, expected_x: &C::ScalarField) -> Result<(), ()> {
141                 // If x is between N and P the below calculations will fail and we'll spuriously reject a
142                 // signature and the wycheproof tests will fail. We should in theory accept such
143                 // signatures, but the probability of this happening at random is roughly 1/2^128, i.e. we
144                 // really don't need to handle it in practice. Thus, we only bother to do this in tests.
145                 debug_assert!(expected_x.clone().into_i() < C::CurveModulus::PRIME, "N is < P");
146                 debug_assert!(C::ScalarModulus::PRIME < C::CurveModulus::PRIME, "N is < P");
147                 #[cfg(debug_assertions)] {
148                         // Check the above assertion - ensure the difference between the modulus of the scalar
149                         // and curve fields is less than half the bit length of our integers, which are at
150                         // least 256 bit long.
151                         let scalar_mod_on_curve = C::CurveField::from_i(C::ScalarModulus::PRIME);
152                         let diff = C::CurveField::ZERO.sub(&scalar_mod_on_curve);
153                         assert!(C::Int::BYTES * 8 / 2 >= 128, "We assume 256-bit ints and longer");
154                         assert!(C::CurveModulus::PRIME.limbs()[0] > (1 << 63), "PRIME should have the top bit set");
155                         assert!(C::ScalarModulus::PRIME.limbs()[0] > (1 << 63), "PRIME should have the top bit set");
156                         let mut half_bitlen = C::CurveField::ONE;
157                         for _ in 0..C::Int::BYTES * 8 / 2 {
158                                 half_bitlen = half_bitlen.double();
159                         }
160                         assert!(diff.into_i() < half_bitlen.into_i());
161                 }
162
163                 #[allow(unused_mut, unused_assignments)]
164                 let mut slow_check = None;
165                 #[cfg(test)] {
166                         slow_check = Some(C::ScalarField::from_i(self.normalize_x()?.into_i()) == *expected_x);
167                 }
168
169                 let e: C::CurveField = C::CurveField::from_i(expected_x.clone().into_i());
170                 if self.z == C::CurveField::ZERO { return Err(()); }
171                 let ezz = e.mul(&self.z).mul(&self.z);
172                 if self.x == ezz || slow_check == Some(true) { Ok(()) } else { Err(()) }
173         }
174
175         fn double(&self) -> Result<Self, ()> {
176                 if self.y == C::CurveField::ZERO { return Err(()); }
177                 if self.z == C::CurveField::ZERO { return Err(()); }
178
179                 // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2001-b
180                 // delta = Z1^2
181                 // gamma = Y1^2
182                 // beta = X1*gamma
183                 // alpha = 3*(X1-delta)*(X1+delta)
184                 // X3 = alpha^2-8*beta
185                 // Z3 = (Y1+Z1)^2-gamma-delta
186                 // Y3 = alpha*(4*beta-X3)-8*gamma^2
187
188                 let delta = self.z.square();
189                 let gamma = self.y.square();
190                 let beta = self.x.mul(&gamma);
191                 let alpha = self.x.sub(&delta).times_three().mul(&self.x.add(&delta));
192                 let x = alpha.square().sub(&beta.times_eight());
193                 let y = alpha.mul(&beta.times_four().sub(&x)).sub(&gamma.square().times_eight());
194                 let z = self.y.add(&self.z).square().sub(&gamma).sub(&delta);
195
196                 #[cfg(debug_assertions)] { assert!(Self::on_curve_z(&x, &y, &z).is_ok()); }
197                 Ok(Point { x, y, z })
198         }
199
200         fn add(&self, o: &Self) -> Result<Self, ()> {
201                 // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl
202                 // Z1Z1 = Z1^2
203                 // Z2Z2 = Z2^2
204                 // U1 = X1*Z2Z2
205                 // U2 = X2*Z1Z1
206                 // S1 = Y1*Z2*Z2Z2
207                 // S2 = Y2*Z1*Z1Z1
208                 // H = U2-U1
209                 // I = (2*H)^2
210                 // J = H*I
211                 // r = 2*(S2-S1)
212                 // V = U1*I
213                 // X3 = r^2-J-2*V
214                 // Y3 = r*(V-X3)-2*S1*J
215                 // Z3 = ((Z1+Z2)^2-Z1Z1-Z2Z2)*H
216
217                 let o_z_2 = o.z.square();
218                 let self_z_2 = self.z.square();
219
220                 let u1 = self.x.mul(&o_z_2);
221                 let u2 = o.x.mul(&self_z_2);
222                 let s1 = self.y.mul(&o.z.mul(&o_z_2));
223                 let s2 = o.y.mul(&self.z.mul(&self_z_2));
224                 if u1 == u2 {
225                         if s1 != s2 { /* Point at Infinity */ return Err(()); }
226                         return self.double();
227                 }
228                 let h = u2.sub(&u1);
229                 let i = h.double().square();
230                 let j = h.mul(&i);
231                 let r = s2.sub(&s1).double();
232                 let v = u1.mul(&i);
233                 let x = r.square().sub(&j).sub(&v.double());
234                 let y = r.mul(&v.sub(&x)).sub(&s1.double().mul(&j));
235                 let z = self.z.add(&o.z).square().sub(&self_z_2).sub(&o_z_2).mul(&h);
236
237                 #[cfg(debug_assertions)] { assert!(Self::on_curve_z(&x, &y, &z).is_ok()); }
238                 Ok(Point { x, y, z})
239         }
240 }
241
242 /// Calculates i * I + j * J
243 #[allow(non_snake_case)]
244 fn add_two_mul<C: Curve>(i: C::ScalarField, I: &Point<C>, j: C::ScalarField, J: &Point<C>) -> Result<Point<C>, ()> {
245         let i = i.into_i();
246         let j = j.into_i();
247
248         if i == C::Int::ZERO { /* Infinity */ return Err(()); }
249         if j == C::Int::ZERO { /* Infinity */ return Err(()); }
250
251         let mut res_opt: Result<Point<C>, ()> = Err(());
252         let i_limbs = i.limbs();
253         let j_limbs = j.limbs();
254         let mut skip_limb = 0;
255         let mut limbs_skip_iter = i_limbs.iter().zip(j_limbs.iter());
256         while limbs_skip_iter.next() == Some((&0, &0)) {
257                 skip_limb += 1;
258         }
259         for (idx, (il, jl)) in i_limbs.iter().zip(j_limbs.iter()).skip(skip_limb).enumerate() {
260                 let start_bit = if idx == 0 {
261                         core::cmp::min(il.leading_zeros(), jl.leading_zeros())
262                 } else { 0 };
263                 for b in start_bit..64 {
264                         let i_bit = (*il & (1 << (63 - b))) != 0;
265                         let j_bit = (*jl & (1 << (63 - b))) != 0;
266                         if let Ok(res) = res_opt.as_mut() {
267                                 *res = res.double()?;
268                         }
269                         if i_bit {
270                                 if let Ok(res) = res_opt.as_mut() {
271                                         // The wycheproof tests expect to see signatures pass even if we hit Point at
272                                         // Infinity (PAI) on an intermediate result. While that's fine, I'm too lazy to
273                                         // go figure out if all our PAI definitions are right and the probability of
274                                         // this happening at random is, basically, the probability of guessing a private
275                                         // key anyway, so its not really worth actually handling outside of tests.
276                                         #[cfg(test)] {
277                                                 res_opt = res.add(I);
278                                         }
279                                         #[cfg(not(test))] {
280                                                 *res = res.add(I)?;
281                                         }
282                                 } else {
283                                         res_opt = Ok(I.clone());
284                                 }
285                         }
286                         if j_bit {
287                                 if let Ok(res) = res_opt.as_mut() {
288                                         // The wycheproof tests expect to see signatures pass even if we hit Point at
289                                         // Infinity (PAI) on an intermediate result. While that's fine, I'm too lazy to
290                                         // go figure out if all our PAI definitions are right and the probability of
291                                         // this happening at random is, basically, the probability of guessing a private
292                                         // key anyway, so its not really worth actually handling outside of tests.
293                                         #[cfg(test)] {
294                                                 res_opt = res.add(J);
295                                         }
296                                         #[cfg(not(test))] {
297                                                 *res = res.add(J)?;
298                                         }
299                                 } else {
300                                         res_opt = Ok(J.clone());
301                                 }
302                         }
303                 }
304         }
305         res_opt
306 }
307
308 /// Validates the given signature against the given public key and message digest.
309 pub(super) fn validate_ecdsa<C: Curve>(pk: &[u8], sig: &[u8], hash_input: &[u8]) -> Result<(), ()> {
310         #![allow(non_snake_case)]
311
312         if pk.len() != C::Int::BYTES * 2 { return Err(()); }
313         if sig.len() != C::Int::BYTES * 2 { return Err(()); }
314
315         let (r_bytes, s_bytes) = sig.split_at(C::Int::BYTES);
316         let (pk_x_bytes, pk_y_bytes) = pk.split_at(C::Int::BYTES);
317
318         let pk_x = C::Int::from_be_bytes(pk_x_bytes)?;
319         let pk_y = C::Int::from_be_bytes(pk_y_bytes)?;
320         let PK = Point::from_xy(pk_x, pk_y)?;
321
322         // from_i and from_modinv_of both will simply mod if the value is out of range. While its
323         // perfectly safe to do so, the wycheproof tests expect such signatures to be rejected, so we
324         // do so here.
325         let r_u256 = C::Int::from_be_bytes(r_bytes)?;
326         if r_u256 > C::ScalarModulus::PRIME { return Err(()); }
327         let s_u256 = C::Int::from_be_bytes(s_bytes)?;
328         if s_u256 > C::ScalarModulus::PRIME { return Err(()); }
329
330         let r = C::ScalarField::from_i(r_u256);
331         let s_inv = C::ScalarField::from_modinv_of(s_u256)?;
332
333         let z = C::ScalarField::from_i(C::Int::from_be_bytes(hash_input)?);
334
335         let u_a = z.mul(&s_inv);
336         let u_b = r.mul(&s_inv);
337
338         let V = add_two_mul(u_a, &C::G, u_b, &PK)?;
339         V.eq_x(&r)
340 }