Bump version to 0.6.5 for new less-code-size feature
[dnssec-prover] / src / test / crypto.rs
1 use crate::crypto::secp256r1::validate_ecdsa as validate_256r1;
2 use crate::crypto::secp384r1::validate_ecdsa as validate_384r1;
3 use crate::crypto::rsa::validate_rsa;
4 use crate::crypto::hash::{Hasher, HashResult};
5
6 use hex_conservative::FromHex;
7 use serde_json::Value;
8 use std::fs::File;
9
10 fn open_file(name: &str) -> File {
11         if let Ok(f) = File::open(name) { return f; }
12         if let Ok(f) = File::open("../".to_owned() + name) { return f; }
13         if let Ok(f) = File::open("src/test/".to_owned() + name) { return f; }
14         if let Ok(f) = File::open("../src/test/".to_owned() + name) { return f; }
15         if let Ok(f) = File::open("../../src/test/".to_owned() + name) { return f; }
16         panic!("Failed to find file {}", name);
17 }
18
19 fn decode_asn(sig: &str, int_len: usize) -> Result<Vec<u8>, ()> {
20         // Signature is in ASN, so decode the garbage excess headers
21         // Note that some tests are specifically for the ASN parser, so we have to carefully
22         // reject invalid crap here.
23         if sig.len() < 12 { return Err(()); }
24
25         if &sig[..2] != "30" { return Err(()); }
26         let total_len = (<[u8; 2]>::from_hex(&sig[..4]).unwrap())[1] as usize;
27         if total_len + 2 != sig.len() / 2 { return Err(()); }
28
29         if &sig[4..6] != "02" { return Err(()); }
30         let r_len = (<[u8; 2]>::from_hex(&sig[4..8]).unwrap())[1] as usize;
31         if sig.len() < r_len * 2 + 8 { return Err(()); }
32         if r_len == 0 { return Err(()); }
33         let r = Vec::from_hex(&sig[8..r_len * 2 + 8]).unwrap();
34         if r.len() > int_len {
35                 // If the MSB is 1, an extra byte is required to avoid the sign flag
36                 if r.len() > int_len + 1 { return Err(()); }
37                 if r[0] != 0 { return Err(()); }
38                 if r[1] & 0b1000_0000 == 0 { return Err(()); }
39         } else if r[0] & 0b1000_0000 != 0 { return Err(()); }
40
41         if sig.len() < r_len * 2 + 12 { return Err(()); }
42         if &sig[r_len * 2 + 8..r_len * 2 + 10] != "02" { return Err(()); }
43         let s_len = (<[u8; 2]>::from_hex(&sig[r_len * 2 + 8..r_len * 2 + 12]).unwrap())[1] as usize;
44         if sig.len() != r_len * 2 + s_len * 2 + 12 { return Err(()); }
45         if s_len == 0 { return Err(()); }
46         let s = Vec::from_hex(&sig[r_len * 2 + 12..]).unwrap();
47         if s.len() > int_len {
48                 // If the MSB is 1, an extra byte is required to avoid the sign flag
49                 if s.len() > int_len + 1 { return Err(()); }
50                 if s[0] != 0 { return Err(()); }
51                 if s[1] & 0b1000_0000 == 0 { return Err(()); }
52         } else if s[0] & 0b1000_0000 != 0 { return Err(()); }
53
54         let mut sig_bytes = vec![0; int_len * 2];
55         sig_bytes[int_len.saturating_sub(r.len())..int_len]
56                 .copy_from_slice(&r[r.len().saturating_sub(int_len)..]);
57         sig_bytes[int_len + int_len.saturating_sub(s.len())..int_len * 2]
58                 .copy_from_slice(&s[s.len().saturating_sub(int_len)..]);
59
60         Ok(sig_bytes)
61 }
62
63 fn test_ecdsa<
64         Validate: Fn(&[u8], &[u8], &[u8]) -> Result<(), ()> + Clone,
65         Hash: Fn(&[u8]) -> HashResult,
66 >(v: Value, int_len: usize, validate_fn: Validate, hash_fn: Hash) {
67         for (group_idx, group) in v["testGroups"].as_array().unwrap().into_iter().enumerate() {
68                 let pk_str = group["publicKey"]["uncompressed"].as_str().unwrap();
69                 assert_eq!(&pk_str[..2], "04"); // OpenSSL uncompressed encoding flag
70                 let pk = Vec::from_hex(&pk_str[2..]).unwrap();
71                 for test in group["tests"].as_array().unwrap() {
72                         let msg = Vec::from_hex(test["msg"].as_str().unwrap()).unwrap();
73
74                         let expected_result = match test["result"].as_str().unwrap() {
75                                 "valid" => Ok(()),
76                                 "invalid" => Err(()),
77                                 r => panic!("Unknown result type {}", r),
78                         };
79
80                         let sig = decode_asn(test["sig"].as_str().unwrap(), int_len);
81                         if sig.is_err() {
82                                 assert_eq!(expected_result, Err(()));
83                                 continue;
84                         }
85
86                         let hash = hash_fn(&msg);
87                         let pk = &pk[..];
88                         let validate_fn = validate_fn.clone();
89                         let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
90                                 let result = validate_fn(&pk, &sig.unwrap(), hash.as_ref());
91                                 assert_eq!(result, expected_result);
92                         }));
93                         if panicked.is_err() {
94                                 panic!("Test case group {}, test id {}, comment {} panicked!", group_idx, test["tcId"], test["comment"]);
95                         }
96                 }
97         }
98 }
99
100 #[test]
101 fn test_ecdsa_256r1() {
102         let f = open_file("ecdsa_secp256r1_sha256_test.json");
103         let v: Value = serde_json::from_reader(f).unwrap();
104         test_ecdsa(v, 32, validate_256r1, |msg| {
105                 let mut hasher = Hasher::sha256();
106                 hasher.update(msg);
107                 hasher.finish()
108         });
109 }
110
111 #[test]
112 fn test_ecdsa_384r1_sha256() {
113         let f = open_file("ecdsa_secp384r1_sha256_test.json");
114         let v: Value = serde_json::from_reader(f).unwrap();
115         test_ecdsa(v, 48, validate_384r1, |msg| {
116                 let mut hasher = Hasher::sha256();
117                 hasher.update(msg);
118                 hasher.finish()
119         });
120 }
121
122 #[test]
123 fn test_ecdsa_384r1_sha384() {
124         let f = open_file("ecdsa_secp384r1_sha384_test.json");
125         let v: Value = serde_json::from_reader(f).unwrap();
126         test_ecdsa(v, 48, validate_384r1, |msg| {
127                 let mut hasher = Hasher::sha384();
128                 hasher.update(msg);
129                 hasher.finish()
130         });
131 }
132
133 fn test_rsa<Hash: Fn(&[u8]) -> HashResult>(v: Value, pk_len: usize, hash_fn: Hash) {
134         for (group_idx, group) in v["testGroups"].as_array().unwrap().into_iter().enumerate() {
135                 let pk_str = group["publicKey"]["modulus"].as_str().unwrap();
136                 assert_eq!(&pk_str[..2], "00"); // No idea why this is here
137                 let pk = Vec::from_hex(&pk_str[2..]).unwrap();
138                 assert_eq!(pk.len(), pk_len);
139                 let exp_vec = Vec::from_hex(group["publicKey"]["publicExponent"].as_str().unwrap()).unwrap();
140                 if exp_vec.len() > 4 { panic!(); }
141                 let mut exp_bytes = [0; 4];
142                 exp_bytes[4 - exp_vec.len()..].copy_from_slice(&exp_vec);
143                 let exp = u32::from_be_bytes(exp_bytes);
144
145                 let mut pk_dns_encoded = Vec::new();
146                 pk_dns_encoded.push(4);
147                 pk_dns_encoded.extend_from_slice(&exp.to_be_bytes());
148                 pk_dns_encoded.extend_from_slice(&pk);
149
150                 for test in group["tests"].as_array().unwrap() {
151                         let msg = Vec::from_hex(test["msg"].as_str().unwrap()).unwrap();
152
153                         let result = match test["result"].as_str().unwrap() {
154                                 "valid" => Ok(()),
155                                 "invalid" => Err(()),
156                                 "acceptable" => continue, // Why bother testing if the tests don't care?
157                                 r => panic!("Unknown result type {}", r),
158                         };
159
160                         let sig = Vec::from_hex(test["sig"].as_str().unwrap()).unwrap();
161
162                         let hash = hash_fn(&msg);
163                         assert_eq!(result, validate_rsa(&pk_dns_encoded, &sig, hash.as_ref()),
164                                 "Failed test case group {}, test id {}, comment {}", group_idx, test["tcId"], test["comment"]);
165                 }
166         }
167 }
168
169 #[test]
170 fn test_rsa2048_sha256() {
171         let f = open_file("rsa_signature_2048_sha256_test.json");
172         let v: Value = serde_json::from_reader(f).unwrap();
173         test_rsa(v, 256, |msg| {
174                 let mut hasher = Hasher::sha256();
175                 hasher.update(msg);
176                 hasher.finish()
177         });
178 }
179
180 #[test]
181 fn test_rsa2048_sha512() {
182         let f = open_file("rsa_signature_2048_sha512_test.json");
183         let v: Value = serde_json::from_reader(f).unwrap();
184         test_rsa(v, 256, |msg| {
185                 let mut hasher = Hasher::sha512();
186                 hasher.update(msg);
187                 hasher.finish()
188         });
189 }
190
191 #[test]
192 fn test_rsa3072_sha256() {
193         let f = open_file("rsa_signature_3072_sha256_test.json");
194         let v: Value = serde_json::from_reader(f).unwrap();
195         test_rsa(v, 384, |msg| {
196                 let mut hasher = Hasher::sha256();
197                 hasher.update(msg);
198                 hasher.finish()
199         });
200 }
201
202 #[test]
203 fn test_rsa3072_sha512() {
204         let f = open_file("rsa_signature_3072_sha512_test.json");
205         let v: Value = serde_json::from_reader(f).unwrap();
206         test_rsa(v, 384, |msg| {
207                 let mut hasher = Hasher::sha512();
208                 hasher.update(msg);
209                 hasher.finish()
210         });
211 }
212
213 #[test]
214 fn test_rsa4096_sha256() {
215         let f = open_file("rsa_signature_4096_sha256_test.json");
216         let v: Value = serde_json::from_reader(f).unwrap();
217         test_rsa(v, 512, |msg| {
218                 let mut hasher = Hasher::sha256();
219                 hasher.update(msg);
220                 hasher.finish()
221         });
222 }
223
224 #[test]
225 fn test_rsa4096_sha512() {
226         let f = open_file("rsa_signature_4096_sha512_test.json");
227         let v: Value = serde_json::from_reader(f).unwrap();
228         test_rsa(v, 512, |msg| {
229                 let mut hasher = Hasher::sha512();
230                 hasher.update(msg);
231                 hasher.finish()
232         });
233 }