Swap `ring` for our own in-crate ECDSA validator
[dnssec-prover] / src / crypto / ec.rs
1 //! Simple verification of ECDSA signatures over SECP Random curves
2
3 use super::bigint::*;
4
5 pub(super) trait IntMod: Clone + Eq + Sized {
6         type I: Int;
7         fn from_i(v: Self::I) -> Self;
8         fn from_modinv_of(v: Self::I) -> Result<Self, ()>;
9
10         const ZERO: Self;
11         const ONE: Self;
12
13         fn mul(&self, o: &Self) -> Self;
14         fn square(&self) -> Self;
15         fn add(&self, o: &Self) -> Self;
16         fn sub(&self, o: &Self) -> Self;
17         fn double(&self) -> Self;
18         fn times_three(&self) -> Self;
19         fn times_four(&self) -> Self;
20         fn times_eight(&self) -> Self;
21
22         fn into_i(self) -> Self::I;
23 }
24 impl<M: PrimeModulus<U256> + Clone + Eq> IntMod for U256Mod<M> {
25         type I = U256;
26         fn from_i(v: Self::I) -> Self { U256Mod::from_u256(v) }
27         fn from_modinv_of(v: Self::I) -> Result<Self, ()> { U256Mod::from_modinv_of(v) }
28
29         const ZERO: Self = U256Mod::<M>::from_u256_panicking(U256::zero());
30         const ONE: Self = U256Mod::<M>::from_u256_panicking(U256::one());
31
32         fn mul(&self, o: &Self) -> Self { self.mul(o) }
33         fn square(&self) -> Self { self.square() }
34         fn add(&self, o: &Self) -> Self { self.add(o) }
35         fn sub(&self, o: &Self) -> Self { self.sub(o) }
36         fn double(&self) -> Self { self.double() }
37         fn times_three(&self) -> Self { self.times_three() }
38         fn times_four(&self) -> Self { self.times_four() }
39         fn times_eight(&self) -> Self { self.times_eight() }
40
41         fn into_i(self) -> Self::I { self.into_u256() }
42 }
43 impl<M: PrimeModulus<U384> + Clone + Eq> IntMod for U384Mod<M> {
44         type I = U384;
45         fn from_i(v: Self::I) -> Self { U384Mod::from_u384(v) }
46         fn from_modinv_of(v: Self::I) -> Result<Self, ()> { U384Mod::from_modinv_of(v) }
47
48         const ZERO: Self = U384Mod::<M>::from_u384_panicking(U384::zero());
49         const ONE: Self = U384Mod::<M>::from_u384_panicking(U384::one());
50
51         fn mul(&self, o: &Self) -> Self { self.mul(o) }
52         fn square(&self) -> Self { self.square() }
53         fn add(&self, o: &Self) -> Self { self.add(o) }
54         fn sub(&self, o: &Self) -> Self { self.sub(o) }
55         fn double(&self) -> Self { self.double() }
56         fn times_three(&self) -> Self { self.times_three() }
57         fn times_four(&self) -> Self { self.times_four() }
58         fn times_eight(&self) -> Self { self.times_eight() }
59
60         fn into_i(self) -> Self::I { self.into_u384() }
61 }
62
63 pub(super) trait Curve : Copy {
64         type Int: Int;
65
66         // With const generics, both IntModP and IntModN can be replaced with a single IntMod.
67         type IntModP: IntMod<I = Self::Int>;
68         type IntModN: IntMod<I = Self::Int>;
69
70         type P: PrimeModulus<Self::Int>;
71         type N: PrimeModulus<Self::Int>;
72
73         // Curve parameters y^2 = x^3 + ax + b
74         const A: Self::IntModP;
75         const B: Self::IntModP;
76
77         const G: Point<Self>;
78 }
79
80 #[derive(Clone, PartialEq, Eq)]
81 pub(super) struct Point<C: Curve + ?Sized> {
82         x: C::IntModP,
83         y: C::IntModP,
84         z: C::IntModP,
85 }
86
87 impl<C: Curve + ?Sized> Point<C> {
88         fn on_curve(x: &C::IntModP, y: &C::IntModP) -> Result<(), ()> {
89                 let x_2 = x.square();
90                 let x_3 = x_2.mul(&x);
91                 let v = x_3.add(&C::A.mul(&x)).add(&C::B);
92
93                 let y_2 = y.square();
94                 if y_2 != v {
95                         Err(())
96                 } else {
97                         Ok(())
98                 }
99         }
100
101         #[cfg(debug_assertions)]
102         fn on_curve_z(x: &C::IntModP, y: &C::IntModP, z: &C::IntModP) -> Result<(), ()> {
103                 let m = C::IntModP::from_modinv_of(z.clone().into_i())?;
104                 let m_2 = m.square();
105                 let m_3 = m_2.mul(&m);
106                 let x_norm = x.mul(&m_2);
107                 let y_norm = y.mul(&m_3);
108                 Self::on_curve(&x_norm, &y_norm)
109         }
110
111         #[cfg(test)]
112         fn normalize_x(&self) -> Result<C::IntModP, ()> {
113                 let m = C::IntModP::from_modinv_of(self.z.clone().into_i())?;
114                 Ok(self.x.mul(&m.square()))
115         }
116
117         fn from_xy(x: C::Int, y: C::Int) -> Result<Self, ()> {
118                 let x = C::IntModP::from_i(x);
119                 let y = C::IntModP::from_i(y);
120                 Self::on_curve(&x, &y)?;
121                 Ok(Point { x, y, z: C::IntModP::ONE })
122         }
123
124         pub(super) const fn from_xy_assuming_on_curve(x: C::IntModP, y: C::IntModP) -> Self {
125                 Point { x, y, z: C::IntModP::ONE }
126         }
127
128         /// Checks that `expected_x` is equal to our X affine coordinate (without modular inversion).
129         fn eq_x(&self, expected_x: &C::IntModN) -> Result<(), ()> {
130                 debug_assert!(expected_x.clone().into_i() < C::P::PRIME, "N is < P");
131
132                 // If x is between N and P the below calculations will fail and we'll spuriously reject a
133                 // signature and the wycheproof tests will fail. We should in theory accept such
134                 // signatures, but the probability of this happening at random is roughly 1/2^128, i.e. we
135                 // really don't need to handle it in practice. Thus, we only bother to do this in tests.
136                 #[allow(unused_mut, unused_assignments)]
137                 let mut slow_check = None;
138                 #[cfg(test)] {
139                         slow_check = Some(C::IntModN::from_i(self.normalize_x()?.into_i()) == *expected_x);
140                 }
141
142                 let e: C::IntModP = C::IntModP::from_i(expected_x.clone().into_i());
143                 if self.z == C::IntModP::ZERO { return Err(()); }
144                 let ezz = e.mul(&self.z).mul(&self.z);
145                 if self.x == ezz { Ok(()) } else {
146                         if slow_check == Some(true) { Ok(()) } else { Err(()) }
147                 }
148         }
149
150         fn double(&self) -> Result<Self, ()> {
151                 if self.y == C::IntModP::ZERO { return Err(()); }
152                 if self.z == C::IntModP::ZERO { return Err(()); }
153
154                 let s = self.x.times_four().mul(&self.y.square());
155                 let z_2 = self.z.square();
156                 let z_4 = z_2.square();
157                 let y_2 = self.y.square();
158                 let y_4 = y_2.square();
159                 let x_2 = self.x.square();
160                 let m = x_2.times_three().add(&C::A.mul(&z_4));
161                 let x = m.square().sub(&s.double());
162                 let y = m.mul(&s.sub(&x)).sub(&y_4.times_eight());
163                 let z = self.y.double().mul(&self.z);
164
165                 #[cfg(debug_assertions)] { assert!(Self::on_curve_z(&x, &y, &z).is_ok()); }
166                 Ok(Point { x, y, z })
167         }
168
169         fn add(&self, o: &Self) -> Result<Self, ()> {
170                 let o_z_2 = o.z.square();
171                 let self_z_2 = self.z.square();
172
173                 let u1 = self.x.mul(&o_z_2);
174                 let u2 = o.x.mul(&self_z_2);
175                 let s1 = self.y.mul(&o.z.mul(&o_z_2));
176                 let s2 = o.y.mul(&self.z.mul(&self_z_2));
177                 if u1 == u2 {
178                         if s1 != s2 { /* PAI */ return Err(()); }
179                         return self.double();
180                 }
181                 let h = u2.sub(&u1);
182                 let h_2 = h.square();
183                 let h_3 = h.mul(&h_2);
184                 let r = s2.sub(&s1);
185                 let x = r.square().sub(&h_3).sub(&u1.double().mul(&h_2));
186                 let y = r.mul(&u1.mul(&h_2).sub(&x)).sub(&s1.mul(&h_3));
187                 let z = h.mul(&self.z).mul(&o.z);
188
189                 #[cfg(debug_assertions)] { assert!(Self::on_curve_z(&x, &y, &z).is_ok()); }
190                 Ok(Point { x, y, z})
191         }
192 }
193
194 /// Calculates i * I + j * J
195 #[allow(non_snake_case)]
196 fn add_two_mul<C: Curve>(i: C::IntModN, I: &Point<C>, j: C::IntModN, J: &Point<C>) -> Result<Point<C>, ()> {
197         let i = i.into_i();
198         let j = j.into_i();
199
200         if i == C::Int::ZERO { /* Infinity */ return Err(()); }
201         if j == C::Int::ZERO { /* Infinity */ return Err(()); }
202
203         let mut res_opt: Result<Point<C>, ()> = Err(());
204         let i_limbs = i.limbs();
205         let j_limbs = j.limbs();
206         let mut skip_limb = 0;
207         let mut limbs_skip_iter = i_limbs.iter().zip(j_limbs.iter());
208         while limbs_skip_iter.next() == Some((&0, &0)) {
209                 skip_limb += 1;
210         }
211         for (idx, (il, jl)) in i_limbs.iter().zip(j_limbs.iter()).skip(skip_limb).enumerate() {
212                 let start_bit = if idx == 0 {
213                         core::cmp::min(il.leading_zeros(), jl.leading_zeros())
214                 } else { 0 };
215                 for b in start_bit..64 {
216                         let i_bit = (*il & (1 << (63 - b))) != 0;
217                         let j_bit = (*jl & (1 << (63 - b))) != 0;
218                         if let Ok(res) = res_opt.as_mut() {
219                                 *res = res.double()?;
220                         }
221                         if i_bit {
222                                 if let Ok(res) = res_opt.as_mut() {
223                                         // The wycheproof tests expect to see signatures pass even if we hit PAI on an
224                                         // intermediate result. While that's fine, I'm too lazy to go figure out if all
225                                         // our PAI definitions are right and the probability of this happening at
226                                         // random is, basically, the probability of guessing a private key anyway, so
227                                         // its not really worth actually handling outside of tests.
228                                         #[cfg(test)] {
229                                                 res_opt = res.add(I);
230                                         }
231                                         #[cfg(not(test))] {
232                                                 *res = res.add(I)?;
233                                         }
234                                 } else {
235                                         res_opt = Ok(I.clone());
236                                 }
237                         }
238                         if j_bit {
239                                 if let Ok(res) = res_opt.as_mut() {
240                                         // The wycheproof tests expect to see signatures pass even if we hit PAI on an
241                                         // intermediate result. While that's fine, I'm too lazy to go figure out if all
242                                         // our PAI definitions are right and the probability of this happening at
243                                         // random is, basically, the probability of guessing a private key anyway, so
244                                         // its not really worth actually handling outside of tests.
245                                         #[cfg(test)] {
246                                                 res_opt = res.add(J);
247                                         }
248                                         #[cfg(not(test))] {
249                                                 *res = res.add(J)?;
250                                         }
251                                 } else {
252                                         res_opt = Ok(J.clone());
253                                 }
254                         }
255                 }
256         }
257         res_opt
258 }
259
260 /// Validates the given signature against the given public key and message digest.
261 pub(super) fn validate_ecdsa<C: Curve>(pk: &[u8], sig: &[u8], hash_input: &[u8]) -> Result<(), ()> {
262         #![allow(non_snake_case)]
263
264         if pk.len() != C::Int::BYTES * 2 { return Err(()); }
265         if sig.len() != C::Int::BYTES * 2 { return Err(()); }
266
267         let (r_bytes, s_bytes) = sig.split_at(C::Int::BYTES);
268         let (pk_x_bytes, pk_y_bytes) = pk.split_at(C::Int::BYTES);
269
270         let pk_x = C::Int::from_be_bytes(pk_x_bytes)?;
271         let pk_y = C::Int::from_be_bytes(pk_y_bytes)?;
272         let PK = Point::from_xy(pk_x, pk_y)?;
273
274         // from_i and from_modinv_of both will simply mod if the value is out of range. While its
275         // perfectly safe to do so, the wycheproof tests expect such signatures to be rejected, so we
276         // do so here.
277         let r_u256 = C::Int::from_be_bytes(r_bytes)?;
278         if r_u256 > C::N::PRIME { return Err(()); }
279         let s_u256 = C::Int::from_be_bytes(s_bytes)?;
280         if s_u256 > C::N::PRIME { return Err(()); }
281
282         let r = C::IntModN::from_i(r_u256);
283         let s_inv = C::IntModN::from_modinv_of(s_u256)?;
284
285         let z = C::IntModN::from_i(C::Int::from_be_bytes(hash_input)?);
286
287         let u_a = z.mul(&s_inv);
288         let u_b = r.mul(&s_inv);
289
290         let V = add_two_mul(u_a, &C::G, u_b, &PK)?;
291         V.eq_x(&r)
292 }