Use `crate::prelude::*` rather than specific imports
[rust-lightning] / lightning / src / crypto / poly1305.rs
1 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
2 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
4 // You may not use this file except in accordance with one or both of these
5 // licenses.
6
7 // This is a port of Andrew Moons poly1305-donna
8 // https://github.com/floodyberry/poly1305-donna
9
10 use core::cmp::min;
11
12 use crate::prelude::*;
13
14 #[derive(Clone, Copy)]
15 pub struct Poly1305 {
16         r         : [u32; 5],
17         h         : [u32; 5],
18         pad       : [u32; 4],
19         leftover  : usize,
20         buffer    : [u8; 16],
21         finalized : bool,
22 }
23
24 impl Poly1305 {
25         pub fn new(key: &[u8]) -> Poly1305 {
26                 assert!(key.len() == 32);
27                 let mut poly = Poly1305{ r: [0u32; 5], h: [0u32; 5], pad: [0u32; 4], leftover: 0, buffer: [0u8; 16], finalized: false };
28
29                 // r &= 0xffffffc0ffffffc0ffffffc0fffffff
30                 poly.r[0] = (u32::from_le_bytes(key[ 0.. 4].try_into().expect("len is 4"))     ) & 0x3ffffff;
31                 poly.r[1] = (u32::from_le_bytes(key[ 3.. 7].try_into().expect("len is 4")) >> 2) & 0x3ffff03;
32                 poly.r[2] = (u32::from_le_bytes(key[ 6..10].try_into().expect("len is 4")) >> 4) & 0x3ffc0ff;
33                 poly.r[3] = (u32::from_le_bytes(key[ 9..13].try_into().expect("len is 4")) >> 6) & 0x3f03fff;
34                 poly.r[4] = (u32::from_le_bytes(key[12..16].try_into().expect("len is 4")) >> 8) & 0x00fffff;
35
36                 poly.pad[0] = u32::from_le_bytes(key[16..20].try_into().expect("len is 4"));
37                 poly.pad[1] = u32::from_le_bytes(key[20..24].try_into().expect("len is 4"));
38                 poly.pad[2] = u32::from_le_bytes(key[24..28].try_into().expect("len is 4"));
39                 poly.pad[3] = u32::from_le_bytes(key[28..32].try_into().expect("len is 4"));
40
41                 poly
42         }
43
44         fn block(&mut self, m: &[u8]) {
45                 let hibit : u32 = if self.finalized { 0 } else { 1 << 24 };
46
47                 let r0 = self.r[0];
48                 let r1 = self.r[1];
49                 let r2 = self.r[2];
50                 let r3 = self.r[3];
51                 let r4 = self.r[4];
52
53                 let s1 = r1 * 5;
54                 let s2 = r2 * 5;
55                 let s3 = r3 * 5;
56                 let s4 = r4 * 5;
57
58                 let mut h0 = self.h[0];
59                 let mut h1 = self.h[1];
60                 let mut h2 = self.h[2];
61                 let mut h3 = self.h[3];
62                 let mut h4 = self.h[4];
63
64                 // h += m
65                 h0 += (u32::from_le_bytes(m[ 0.. 4].try_into().expect("len is 4"))     ) & 0x3ffffff;
66                 h1 += (u32::from_le_bytes(m[ 3.. 7].try_into().expect("len is 4")) >> 2) & 0x3ffffff;
67                 h2 += (u32::from_le_bytes(m[ 6..10].try_into().expect("len is 4")) >> 4) & 0x3ffffff;
68                 h3 += (u32::from_le_bytes(m[ 9..13].try_into().expect("len is 4")) >> 6) & 0x3ffffff;
69                 h4 += (u32::from_le_bytes(m[12..16].try_into().expect("len is 4")) >> 8) | hibit;
70
71                 // h *= r
72                 let     d0 = (h0 as u64 * r0 as u64) + (h1 as u64 * s4 as u64) + (h2 as u64 * s3 as u64) + (h3 as u64 * s2 as u64) + (h4 as u64 * s1 as u64);
73                 let mut d1 = (h0 as u64 * r1 as u64) + (h1 as u64 * r0 as u64) + (h2 as u64 * s4 as u64) + (h3 as u64 * s3 as u64) + (h4 as u64 * s2 as u64);
74                 let mut d2 = (h0 as u64 * r2 as u64) + (h1 as u64 * r1 as u64) + (h2 as u64 * r0 as u64) + (h3 as u64 * s4 as u64) + (h4 as u64 * s3 as u64);
75                 let mut d3 = (h0 as u64 * r3 as u64) + (h1 as u64 * r2 as u64) + (h2 as u64 * r1 as u64) + (h3 as u64 * r0 as u64) + (h4 as u64 * s4 as u64);
76                 let mut d4 = (h0 as u64 * r4 as u64) + (h1 as u64 * r3 as u64) + (h2 as u64 * r2 as u64) + (h3 as u64 * r1 as u64) + (h4 as u64 * r0 as u64);
77
78                 // (partial) h %= p
79                 let mut c : u32;
80                                 c = (d0 >> 26) as u32; h0 = d0 as u32 & 0x3ffffff;
81                 d1 += c as u64; c = (d1 >> 26) as u32; h1 = d1 as u32 & 0x3ffffff;
82                 d2 += c as u64; c = (d2 >> 26) as u32; h2 = d2 as u32 & 0x3ffffff;
83                 d3 += c as u64; c = (d3 >> 26) as u32; h3 = d3 as u32 & 0x3ffffff;
84                 d4 += c as u64; c = (d4 >> 26) as u32; h4 = d4 as u32 & 0x3ffffff;
85                 h0 += c * 5;    c = h0 >> 26; h0 = h0 & 0x3ffffff;
86                 h1 += c;
87
88                 self.h[0] = h0;
89                 self.h[1] = h1;
90                 self.h[2] = h2;
91                 self.h[3] = h3;
92                 self.h[4] = h4;
93         }
94
95         pub fn finish(&mut self) {
96                 if self.leftover > 0 {
97                         self.buffer[self.leftover] = 1;
98                         for i in self.leftover+1..16 {
99                                 self.buffer[i] = 0;
100                         }
101                         self.finalized = true;
102                         let tmp = self.buffer;
103                         self.block(&tmp);
104                 }
105
106                 // fully carry h
107                 let mut h0 = self.h[0];
108                 let mut h1 = self.h[1];
109                 let mut h2 = self.h[2];
110                 let mut h3 = self.h[3];
111                 let mut h4 = self.h[4];
112
113                 let mut c : u32;
114                              c = h1 >> 26; h1 = h1 & 0x3ffffff;
115                 h2 +=     c; c = h2 >> 26; h2 = h2 & 0x3ffffff;
116                 h3 +=     c; c = h3 >> 26; h3 = h3 & 0x3ffffff;
117                 h4 +=     c; c = h4 >> 26; h4 = h4 & 0x3ffffff;
118                 h0 += c * 5; c = h0 >> 26; h0 = h0 & 0x3ffffff;
119                 h1 +=     c;
120
121                 // compute h + -p
122                 let mut g0 = h0.wrapping_add(5); c = g0 >> 26; g0 &= 0x3ffffff;
123                 let mut g1 = h1.wrapping_add(c); c = g1 >> 26; g1 &= 0x3ffffff;
124                 let mut g2 = h2.wrapping_add(c); c = g2 >> 26; g2 &= 0x3ffffff;
125                 let mut g3 = h3.wrapping_add(c); c = g3 >> 26; g3 &= 0x3ffffff;
126                 let mut g4 = h4.wrapping_add(c).wrapping_sub(1 << 26);
127
128                 // select h if h < p, or h + -p if h >= p
129                 let mut mask = (g4 >> (32 - 1)).wrapping_sub(1);
130                 g0 &= mask;
131                 g1 &= mask;
132                 g2 &= mask;
133                 g3 &= mask;
134                 g4 &= mask;
135                 mask = !mask;
136                 h0 = (h0 & mask) | g0;
137                 h1 = (h1 & mask) | g1;
138                 h2 = (h2 & mask) | g2;
139                 h3 = (h3 & mask) | g3;
140                 h4 = (h4 & mask) | g4;
141
142                 // h = h % (2^128)
143                 h0 = ((h0      ) | (h1 << 26)) & 0xffffffff;
144                 h1 = ((h1 >>  6) | (h2 << 20)) & 0xffffffff;
145                 h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff;
146                 h3 = ((h3 >> 18) | (h4 <<  8)) & 0xffffffff;
147
148                 // h = mac = (h + pad) % (2^128)
149                 let mut f : u64;
150                 f = h0 as u64 + self.pad[0] as u64            ; h0 = f as u32;
151                 f = h1 as u64 + self.pad[1] as u64 + (f >> 32); h1 = f as u32;
152                 f = h2 as u64 + self.pad[2] as u64 + (f >> 32); h2 = f as u32;
153                 f = h3 as u64 + self.pad[3] as u64 + (f >> 32); h3 = f as u32;
154
155                 self.h[0] = h0;
156                 self.h[1] = h1;
157                 self.h[2] = h2;
158                 self.h[3] = h3;
159         }
160
161         pub fn input(&mut self, data: &[u8]) {
162                 assert!(!self.finalized);
163                 let mut m = data;
164
165                 if self.leftover > 0 {
166                         let want = min(16 - self.leftover, m.len());
167                         for i in 0..want {
168                                 self.buffer[self.leftover+i] = m[i];
169                         }
170                         m = &m[want..];
171                         self.leftover += want;
172
173                         if self.leftover < 16 {
174                                 return;
175                         }
176
177                         // self.block(self.buffer[..]);
178                         let tmp = self.buffer;
179                         self.block(&tmp);
180
181                         self.leftover = 0;
182                 }
183
184                 while m.len() >= 16 {
185                         self.block(&m[0..16]);
186                         m = &m[16..];
187                 }
188
189                 for i in 0..m.len() {
190                         self.buffer[i] = m[i];
191                 }
192                 self.leftover = m.len();
193         }
194
195         pub fn raw_result(&mut self, output: &mut [u8]) {
196                 assert!(output.len() >= 16);
197                 if !self.finalized{
198                         self.finish();
199                 }
200                 output[0..4].copy_from_slice(&self.h[0].to_le_bytes());
201                 output[4..8].copy_from_slice(&self.h[1].to_le_bytes());
202                 output[8..12].copy_from_slice(&self.h[2].to_le_bytes());
203                 output[12..16].copy_from_slice(&self.h[3].to_le_bytes());
204         }
205 }
206
207 #[cfg(test)]
208 mod test {
209         use core::iter::repeat;
210
211         use super::Poly1305;
212
213         fn poly1305(key: &[u8], msg: &[u8], mac: &mut [u8]) {
214                 let mut poly = Poly1305::new(key);
215                 poly.input(msg);
216                 poly.raw_result(mac);
217         }
218
219         #[test]
220         fn test_nacl_vector() {
221                 let key = [
222                         0xee,0xa6,0xa7,0x25,0x1c,0x1e,0x72,0x91,
223                         0x6d,0x11,0xc2,0xcb,0x21,0x4d,0x3c,0x25,
224                         0x25,0x39,0x12,0x1d,0x8e,0x23,0x4e,0x65,
225                         0x2d,0x65,0x1f,0xa4,0xc8,0xcf,0xf8,0x80,
226                 ];
227
228                 let msg = [
229                         0x8e,0x99,0x3b,0x9f,0x48,0x68,0x12,0x73,
230                         0xc2,0x96,0x50,0xba,0x32,0xfc,0x76,0xce,
231                         0x48,0x33,0x2e,0xa7,0x16,0x4d,0x96,0xa4,
232                         0x47,0x6f,0xb8,0xc5,0x31,0xa1,0x18,0x6a,
233                         0xc0,0xdf,0xc1,0x7c,0x98,0xdc,0xe8,0x7b,
234                         0x4d,0xa7,0xf0,0x11,0xec,0x48,0xc9,0x72,
235                         0x71,0xd2,0xc2,0x0f,0x9b,0x92,0x8f,0xe2,
236                         0x27,0x0d,0x6f,0xb8,0x63,0xd5,0x17,0x38,
237                         0xb4,0x8e,0xee,0xe3,0x14,0xa7,0xcc,0x8a,
238                         0xb9,0x32,0x16,0x45,0x48,0xe5,0x26,0xae,
239                         0x90,0x22,0x43,0x68,0x51,0x7a,0xcf,0xea,
240                         0xbd,0x6b,0xb3,0x73,0x2b,0xc0,0xe9,0xda,
241                         0x99,0x83,0x2b,0x61,0xca,0x01,0xb6,0xde,
242                         0x56,0x24,0x4a,0x9e,0x88,0xd5,0xf9,0xb3,
243                         0x79,0x73,0xf6,0x22,0xa4,0x3d,0x14,0xa6,
244                         0x59,0x9b,0x1f,0x65,0x4c,0xb4,0x5a,0x74,
245                         0xe3,0x55,0xa5,
246                 ];
247
248                 let expected = [
249                         0xf3,0xff,0xc7,0x70,0x3f,0x94,0x00,0xe5,
250                         0x2a,0x7d,0xfb,0x4b,0x3d,0x33,0x05,0xd9,
251                 ];
252
253                 let mut mac = [0u8; 16];
254                 poly1305(&key, &msg, &mut mac);
255                 assert_eq!(&mac[..], &expected[..]);
256
257                 let mut poly = Poly1305::new(&key);
258                 poly.input(&msg[0..32]);
259                 poly.input(&msg[32..96]);
260                 poly.input(&msg[96..112]);
261                 poly.input(&msg[112..120]);
262                 poly.input(&msg[120..124]);
263                 poly.input(&msg[124..126]);
264                 poly.input(&msg[126..127]);
265                 poly.input(&msg[127..128]);
266                 poly.input(&msg[128..129]);
267                 poly.input(&msg[129..130]);
268                 poly.input(&msg[130..131]);
269                 poly.raw_result(&mut mac);
270                 assert_eq!(&mac[..], &expected[..]);
271         }
272
273         #[test]
274         fn donna_self_test() {
275                 let wrap_key = [
276                         0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
277                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
278                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
279                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
280                 ];
281
282                 let wrap_msg = [
283                         0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284                         0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
285                 ];
286
287                 let wrap_mac = [
288                         0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
289                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
290                 ];
291
292                 let mut mac = [0u8; 16];
293                 poly1305(&wrap_key, &wrap_msg, &mut mac);
294                 assert_eq!(&mac[..], &wrap_mac[..]);
295
296                 let total_key = [
297                         0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xff,
298                         0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xff, 0xff,
299                         0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
300                         0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
301                 ];
302
303                 let total_mac = [
304                         0x64, 0xaf, 0xe2, 0xe8, 0xd6, 0xad, 0x7b, 0xbd,
305                         0xd2, 0x87, 0xf9, 0x7c, 0x44, 0x62, 0x3d, 0x39,
306                 ];
307
308                 let mut tpoly = Poly1305::new(&total_key);
309                 for i in 0..256 {
310                         let key: Vec<u8> = repeat(i as u8).take(32).collect();
311                         let msg: Vec<u8> = repeat(i as u8).take(256).collect();
312                         let mut mac = [0u8; 16];
313                         poly1305(&key[..], &msg[0..i], &mut mac);
314                         tpoly.input(&mac);
315                 }
316                 tpoly.raw_result(&mut mac);
317                 assert_eq!(&mac[..], &total_mac[..]);
318         }
319
320         #[test]
321         fn test_tls_vectors() {
322                 // from http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04
323                 let key = b"this is 32-byte key for Poly1305";
324                 let msg = [0u8; 32];
325                 let expected = [
326                         0x49, 0xec, 0x78, 0x09, 0x0e, 0x48, 0x1e, 0xc6,
327                         0xc2, 0x6b, 0x33, 0xb9, 0x1c, 0xcc, 0x03, 0x07,
328                 ];
329                 let mut mac = [0u8; 16];
330                 poly1305(key, &msg, &mut mac);
331                 assert_eq!(&mac[..], &expected[..]);
332
333                 let msg = b"Hello world!";
334                 let expected= [
335                         0xa6, 0xf7, 0x45, 0x00, 0x8f, 0x81, 0xc9, 0x16,
336                         0xa2, 0x0d, 0xcc, 0x74, 0xee, 0xf2, 0xb2, 0xf0,
337                 ];
338                 poly1305(key, msg, &mut mac);
339                 assert_eq!(&mac[..], &expected[..]);
340         }
341 }