Add the wycheproof test cases for our crypto implementation
[dnssec-prover] / src / test / crypto.rs
1 use crate::crypto::secp256r1::validate_ecdsa as validate_256r1;
2 use crate::crypto::secp384r1::validate_ecdsa as validate_384r1;
3 use crate::crypto::rsa::validate_rsa;
4 use crate::crypto::hash::{Hasher, HashResult};
5
6 use hex_conservative::FromHex;
7 use serde_json::Value;
8 use std::fs::File;
9
10 fn open_file(name: &str) -> File {
11         if let Ok(f) = File::open(name) { return f; }
12         if let Ok(f) = File::open("../".to_owned() + name) { return f; }
13         if let Ok(f) = File::open("src/test/".to_owned() + name) { return f; }
14         if let Ok(f) = File::open("../src/test/".to_owned() + name) { return f; }
15         if let Ok(f) = File::open("../../src/test/".to_owned() + name) { return f; }
16         panic!("Failed to find file {}", name);
17 }
18
19 fn decode_asn(sig: &str, int_len: usize) -> Result<Vec<u8>, ()> {
20         // Signature is in ASN, so decode the garbage excess headers
21         // Note that some tests are specifically for the ASN parser, so we have to carefully
22         // reject invalid crap here.
23         if sig.len() < 12 { return Err(()); }
24
25         if &sig[..2] != "30" { return Err(()); }
26         let total_len = (<[u8; 2]>::from_hex(&sig[..4]).unwrap())[1] as usize;
27         if total_len + 2 != sig.len() / 2 { return Err(()); }
28
29         if &sig[4..6] != "02" { return Err(()); }
30         let r_len = (<[u8; 2]>::from_hex(&sig[4..8]).unwrap())[1] as usize;
31         if sig.len() < r_len * 2 + 8 { return Err(()); }
32         if r_len == 0 { return Err(()); }
33         let r = Vec::from_hex(&sig[8..r_len * 2 + 8]).unwrap();
34         if r.len() > int_len {
35                 // If the MSB is 1, an extra byte is required to avoid the sign flag
36                 if r.len() > int_len + 1 { return Err(()); }
37                 if r[0] != 0 { return Err(()); }
38                 if r[1] & 0b1000_0000 == 0 { return Err(()); }
39         } else if r[0] & 0b1000_0000 != 0 { return Err(()); }
40
41         if sig.len() < r_len * 2 + 12 { return Err(()); }
42         if &sig[r_len * 2 + 8..r_len * 2 + 10] != "02" { return Err(()); }
43         let s_len = (<[u8; 2]>::from_hex(&sig[r_len * 2 + 8..r_len * 2 + 12]).unwrap())[1] as usize;
44         if sig.len() != r_len * 2 + s_len * 2 + 12 { return Err(()); }
45         if s_len == 0 { return Err(()); }
46         let s = Vec::from_hex(&sig[r_len * 2 + 12..]).unwrap();
47         if s.len() > int_len {
48                 // If the MSB is 1, an extra byte is required to avoid the sign flag
49                 if s.len() > int_len + 1 { return Err(()); }
50                 if s[0] != 0 { return Err(()); }
51                 if s[1] & 0b1000_0000 == 0 { return Err(()); }
52         } else if s[0] & 0b1000_0000 != 0 { return Err(()); }
53
54         let mut sig_bytes = vec![0; int_len * 2];
55         sig_bytes[int_len.saturating_sub(r.len())..int_len]
56                 .copy_from_slice(&r[r.len().saturating_sub(int_len)..]);
57         sig_bytes[int_len + int_len.saturating_sub(s.len())..int_len * 2]
58                 .copy_from_slice(&s[s.len().saturating_sub(int_len)..]);
59
60         Ok(sig_bytes)
61 }
62
63 fn test_ecdsa<
64         Validate: Fn(&[u8], &[u8], &[u8]) -> Result<(), ()>,
65         Hash: Fn(&[u8]) -> HashResult,
66 >(v: Value, int_len: usize, validate_fn: Validate, hash_fn: Hash) {
67         for (group_idx, group) in v["testGroups"].as_array().unwrap().into_iter().enumerate() {
68                 let pk_str = group["publicKey"]["uncompressed"].as_str().unwrap();
69                 assert_eq!(&pk_str[..2], "04"); // OpenSSL uncompressed encoding flag
70                 let pk = Vec::from_hex(&pk_str[2..]).unwrap();
71                 for test in group["tests"].as_array().unwrap() {
72                         let msg = Vec::from_hex(test["msg"].as_str().unwrap()).unwrap();
73
74                         let result = match test["result"].as_str().unwrap() {
75                                 "valid" => Ok(()),
76                                 "invalid" => Err(()),
77                                 r => panic!("Unknown result type {}", r),
78                         };
79
80                         let sig = decode_asn(test["sig"].as_str().unwrap(), int_len);
81                         if sig.is_err() {
82                                 assert_eq!(result, Err(()));
83                                 continue;
84                         }
85
86                         let hash = hash_fn(&msg);
87                         assert_eq!(result, validate_fn(&pk, &sig.unwrap(), hash.as_ref()),
88                                 "Failed test case group {}, test id {}, comment {}", group_idx, test["tcId"], test["comment"]);
89                 }
90         }
91 }
92
93 #[test]
94 fn test_ecdsa_256r1() {
95         let f = open_file("ecdsa_secp256r1_sha256_test.json");
96         let v: Value = serde_json::from_reader(f).unwrap();
97         test_ecdsa(v, 32, validate_256r1, |msg| {
98                 let mut hasher = Hasher::sha256();
99                 hasher.update(msg);
100                 hasher.finish()
101         });
102 }
103
104 #[test]
105 fn test_ecdsa_384r1_sha256() {
106         let f = open_file("ecdsa_secp384r1_sha256_test.json");
107         let v: Value = serde_json::from_reader(f).unwrap();
108         test_ecdsa(v, 48, validate_384r1, |msg| {
109                 let mut hasher = Hasher::sha256();
110                 hasher.update(msg);
111                 hasher.finish()
112         });
113 }
114
115 #[test]
116 fn test_ecdsa_384r1_sha384() {
117         let f = open_file("ecdsa_secp384r1_sha384_test.json");
118         let v: Value = serde_json::from_reader(f).unwrap();
119         test_ecdsa(v, 48, validate_384r1, |msg| {
120                 let mut hasher = Hasher::sha384();
121                 hasher.update(msg);
122                 hasher.finish()
123         });
124 }
125
126 fn test_rsa<Hash: Fn(&[u8]) -> HashResult>(v: Value, pk_len: usize, hash_fn: Hash) {
127         for (group_idx, group) in v["testGroups"].as_array().unwrap().into_iter().enumerate() {
128                 let pk_str = group["publicKey"]["modulus"].as_str().unwrap();
129                 assert_eq!(&pk_str[..2], "00"); // No idea why this is here
130                 let pk = Vec::from_hex(&pk_str[2..]).unwrap();
131                 assert_eq!(pk.len(), pk_len);
132                 let exp_vec = Vec::from_hex(group["publicKey"]["publicExponent"].as_str().unwrap()).unwrap();
133                 if exp_vec.len() > 4 { panic!(); }
134                 let mut exp_bytes = [0; 4];
135                 exp_bytes[4 - exp_vec.len()..].copy_from_slice(&exp_vec);
136                 let exp = u32::from_be_bytes(exp_bytes);
137
138                 let mut pk_dns_encoded = Vec::new();
139                 pk_dns_encoded.push(4);
140                 pk_dns_encoded.extend_from_slice(&exp.to_be_bytes());
141                 pk_dns_encoded.extend_from_slice(&pk);
142
143                 for test in group["tests"].as_array().unwrap() {
144                         let msg = Vec::from_hex(test["msg"].as_str().unwrap()).unwrap();
145
146                         let result = match test["result"].as_str().unwrap() {
147                                 "valid" => Ok(()),
148                                 "invalid" => Err(()),
149                                 "acceptable" => continue, // Why bother testing if the tests don't care?
150                                 r => panic!("Unknown result type {}", r),
151                         };
152
153                         let sig = Vec::from_hex(test["sig"].as_str().unwrap()).unwrap();
154
155                         let hash = hash_fn(&msg);
156                         assert_eq!(result, validate_rsa(&pk_dns_encoded, &sig, hash.as_ref()),
157                                 "Failed test case group {}, test id {}, comment {}", group_idx, test["tcId"], test["comment"]);
158                 }
159         }
160 }
161
162 #[test]
163 fn test_rsa2048_sha256() {
164         let f = open_file("rsa_signature_2048_sha256_test.json");
165         let v: Value = serde_json::from_reader(f).unwrap();
166         test_rsa(v, 256, |msg| {
167                 let mut hasher = Hasher::sha256();
168                 hasher.update(msg);
169                 hasher.finish()
170         });
171 }
172
173 #[test]
174 fn test_rsa2048_sha512() {
175         let f = open_file("rsa_signature_2048_sha512_test.json");
176         let v: Value = serde_json::from_reader(f).unwrap();
177         test_rsa(v, 256, |msg| {
178                 let mut hasher = Hasher::sha512();
179                 hasher.update(msg);
180                 hasher.finish()
181         });
182 }
183
184 #[test]
185 fn test_rsa3072_sha256() {
186         let f = open_file("rsa_signature_3072_sha256_test.json");
187         let v: Value = serde_json::from_reader(f).unwrap();
188         test_rsa(v, 384, |msg| {
189                 let mut hasher = Hasher::sha256();
190                 hasher.update(msg);
191                 hasher.finish()
192         });
193 }
194
195 #[test]
196 fn test_rsa3072_sha512() {
197         let f = open_file("rsa_signature_3072_sha512_test.json");
198         let v: Value = serde_json::from_reader(f).unwrap();
199         test_rsa(v, 384, |msg| {
200                 let mut hasher = Hasher::sha512();
201                 hasher.update(msg);
202                 hasher.finish()
203         });
204 }
205
206 #[test]
207 fn test_rsa4096_sha256() {
208         let f = open_file("rsa_signature_4096_sha256_test.json");
209         let v: Value = serde_json::from_reader(f).unwrap();
210         test_rsa(v, 512, |msg| {
211                 let mut hasher = Hasher::sha256();
212                 hasher.update(msg);
213                 hasher.finish()
214         });
215 }
216
217 #[test]
218 fn test_rsa4096_sha512() {
219         let f = open_file("rsa_signature_4096_sha512_test.json");
220         let v: Value = serde_json::from_reader(f).unwrap();
221         test_rsa(v, 512, |msg| {
222                 let mut hasher = Hasher::sha512();
223                 hasher.update(msg);
224                 hasher.finish()
225         });
226 }