Fix const build error in the previous commits
[dnssec-prover] / src / crypto / rsa.rs
1 //! A simple RSA implementation which handles DNSSEC RSA validation
2
3 use super::bigint::*;
4
5 fn bytes_to_rsa_mod_exp_modlen(pubkey: &[u8]) -> Result<(U4096, u32, usize), ()> {
6         if pubkey.len() <= 3 { return Err(()); }
7
8         let mut pos = 0;
9         let exponent_length;
10         if pubkey[0] == 0 {
11                 exponent_length = ((pubkey[1] as usize) << 8) | (pubkey[2] as usize);
12                 pos += 3;
13         } else {
14                 exponent_length = pubkey[0] as usize;
15                 pos += 1;
16         }
17
18         if pubkey.len() <= pos + exponent_length { return Err(()); }
19         if exponent_length > 4 { return Err(()); }
20         let mut exp_bytes = [0; 4];
21         exp_bytes[4 - exponent_length..].copy_from_slice(&pubkey[pos..pos + exponent_length]);
22         let exp = u32::from_be_bytes(exp_bytes);
23
24         let mod_bytes = &pubkey[pos + exponent_length..];
25         let modlen = pubkey.len() - pos - exponent_length;
26         let modulus = U4096::from_be_bytes(mod_bytes)?;
27         Ok((modulus, exp, modlen))
28 }
29
30 /// Validates the given RSA signature against the given RSA public key (up to 4096-bit, in
31 /// DNSSEC-encoded form) and given message digest.
32 pub fn validate_rsa(pk: &[u8], sig_bytes: &[u8], hash_input: &[u8]) -> Result<(), ()> {
33         let (modulus, exponent, modulus_byte_len) = bytes_to_rsa_mod_exp_modlen(pk)?;
34         if modulus_byte_len > 512 { /* implied by the U4096, but explicit here */ return Err(()); }
35         let sig = U4096::from_be_bytes(sig_bytes)?;
36
37         if sig > modulus { return Err(()); }
38
39         // From https://www.rfc-editor.org/rfc/rfc5702#section-3.1
40         const SHA256_PFX: [u8; 20] = hex_lit::hex!("003031300d060960864801650304020105000420");
41         const SHA512_PFX: [u8; 20] = hex_lit::hex!("003051300d060960864801650304020305000440");
42         let pfx = if hash_input.len() == 512 / 8 { &SHA512_PFX } else { &SHA256_PFX };
43
44         if 512 - 2 - SHA256_PFX.len() <= hash_input.len() { return Err(()); }
45         let mut hash_bytes = [0; 512];
46         let mut hash_write_pos = 512 - hash_input.len();
47         hash_bytes[hash_write_pos..].copy_from_slice(&hash_input);
48         hash_write_pos -= pfx.len();
49         hash_bytes[hash_write_pos..hash_write_pos + pfx.len()].copy_from_slice(pfx);
50         while 512 + 1 - hash_write_pos < modulus_byte_len {
51                 hash_write_pos -= 1;
52                 hash_bytes[hash_write_pos] = 0xff;
53         }
54         hash_bytes[hash_write_pos] = 1;
55         let hash = U4096::from_be_bytes(&hash_bytes)?;
56
57         if hash > modulus { return Err(()); }
58
59         // While modulus could be even, if it were we'd have already factored the modulus (one of the
60         // primes is two!), so we don't particularly care if we fail spuriously for such spurious keys.
61         let res = sig.expmod_odd_mod(exponent, &modulus)?;
62         if res == hash {
63                 Ok(())
64         } else {
65                 Err(())
66         }
67 }