a1b9fbac5160c5ddd8d090e63844f92e6da744ae
[rust-lightning] / lightning / src / crypto / poly1305.rs
1 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
2 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
4 // You may not use this file except in accordance with one or both of these
5 // licenses.
6
7 // This is a port of Andrew Moons poly1305-donna
8 // https://github.com/floodyberry/poly1305-donna
9
10 use core::cmp::min;
11 use core::convert::TryInto;
12
13 #[derive(Clone, Copy)]
14 pub struct Poly1305 {
15         r         : [u32; 5],
16         h         : [u32; 5],
17         pad       : [u32; 4],
18         leftover  : usize,
19         buffer    : [u8; 16],
20         finalized : bool,
21 }
22
23 impl Poly1305 {
24         pub fn new(key: &[u8]) -> Poly1305 {
25                 assert!(key.len() == 32);
26                 let mut poly = Poly1305{ r: [0u32; 5], h: [0u32; 5], pad: [0u32; 4], leftover: 0, buffer: [0u8; 16], finalized: false };
27
28                 // r &= 0xffffffc0ffffffc0ffffffc0fffffff
29                 poly.r[0] = (u32::from_le_bytes(key[ 0.. 4].try_into().expect("len is 4"))     ) & 0x3ffffff;
30                 poly.r[1] = (u32::from_le_bytes(key[ 3.. 7].try_into().expect("len is 4")) >> 2) & 0x3ffff03;
31                 poly.r[2] = (u32::from_le_bytes(key[ 6..10].try_into().expect("len is 4")) >> 4) & 0x3ffc0ff;
32                 poly.r[3] = (u32::from_le_bytes(key[ 9..13].try_into().expect("len is 4")) >> 6) & 0x3f03fff;
33                 poly.r[4] = (u32::from_le_bytes(key[12..16].try_into().expect("len is 4")) >> 8) & 0x00fffff;
34
35                 poly.pad[0] = u32::from_le_bytes(key[16..20].try_into().expect("len is 4"));
36                 poly.pad[1] = u32::from_le_bytes(key[20..24].try_into().expect("len is 4"));
37                 poly.pad[2] = u32::from_le_bytes(key[24..28].try_into().expect("len is 4"));
38                 poly.pad[3] = u32::from_le_bytes(key[28..32].try_into().expect("len is 4"));
39
40                 poly
41         }
42
43         fn block(&mut self, m: &[u8]) {
44                 let hibit : u32 = if self.finalized { 0 } else { 1 << 24 };
45
46                 let r0 = self.r[0];
47                 let r1 = self.r[1];
48                 let r2 = self.r[2];
49                 let r3 = self.r[3];
50                 let r4 = self.r[4];
51
52                 let s1 = r1 * 5;
53                 let s2 = r2 * 5;
54                 let s3 = r3 * 5;
55                 let s4 = r4 * 5;
56
57                 let mut h0 = self.h[0];
58                 let mut h1 = self.h[1];
59                 let mut h2 = self.h[2];
60                 let mut h3 = self.h[3];
61                 let mut h4 = self.h[4];
62
63                 // h += m
64                 h0 += (u32::from_le_bytes(m[ 0.. 4].try_into().expect("len is 4"))     ) & 0x3ffffff;
65                 h1 += (u32::from_le_bytes(m[ 3.. 7].try_into().expect("len is 4")) >> 2) & 0x3ffffff;
66                 h2 += (u32::from_le_bytes(m[ 6..10].try_into().expect("len is 4")) >> 4) & 0x3ffffff;
67                 h3 += (u32::from_le_bytes(m[ 9..13].try_into().expect("len is 4")) >> 6) & 0x3ffffff;
68                 h4 += (u32::from_le_bytes(m[12..16].try_into().expect("len is 4")) >> 8) | hibit;
69
70                 // h *= r
71                 let     d0 = (h0 as u64 * r0 as u64) + (h1 as u64 * s4 as u64) + (h2 as u64 * s3 as u64) + (h3 as u64 * s2 as u64) + (h4 as u64 * s1 as u64);
72                 let mut d1 = (h0 as u64 * r1 as u64) + (h1 as u64 * r0 as u64) + (h2 as u64 * s4 as u64) + (h3 as u64 * s3 as u64) + (h4 as u64 * s2 as u64);
73                 let mut d2 = (h0 as u64 * r2 as u64) + (h1 as u64 * r1 as u64) + (h2 as u64 * r0 as u64) + (h3 as u64 * s4 as u64) + (h4 as u64 * s3 as u64);
74                 let mut d3 = (h0 as u64 * r3 as u64) + (h1 as u64 * r2 as u64) + (h2 as u64 * r1 as u64) + (h3 as u64 * r0 as u64) + (h4 as u64 * s4 as u64);
75                 let mut d4 = (h0 as u64 * r4 as u64) + (h1 as u64 * r3 as u64) + (h2 as u64 * r2 as u64) + (h3 as u64 * r1 as u64) + (h4 as u64 * r0 as u64);
76
77                 // (partial) h %= p
78                 let mut c : u32;
79                                 c = (d0 >> 26) as u32; h0 = d0 as u32 & 0x3ffffff;
80                 d1 += c as u64; c = (d1 >> 26) as u32; h1 = d1 as u32 & 0x3ffffff;
81                 d2 += c as u64; c = (d2 >> 26) as u32; h2 = d2 as u32 & 0x3ffffff;
82                 d3 += c as u64; c = (d3 >> 26) as u32; h3 = d3 as u32 & 0x3ffffff;
83                 d4 += c as u64; c = (d4 >> 26) as u32; h4 = d4 as u32 & 0x3ffffff;
84                 h0 += c * 5;    c = h0 >> 26; h0 = h0 & 0x3ffffff;
85                 h1 += c;
86
87                 self.h[0] = h0;
88                 self.h[1] = h1;
89                 self.h[2] = h2;
90                 self.h[3] = h3;
91                 self.h[4] = h4;
92         }
93
94         pub fn finish(&mut self) {
95                 if self.leftover > 0 {
96                         self.buffer[self.leftover] = 1;
97                         for i in self.leftover+1..16 {
98                                 self.buffer[i] = 0;
99                         }
100                         self.finalized = true;
101                         let tmp = self.buffer;
102                         self.block(&tmp);
103                 }
104
105                 // fully carry h
106                 let mut h0 = self.h[0];
107                 let mut h1 = self.h[1];
108                 let mut h2 = self.h[2];
109                 let mut h3 = self.h[3];
110                 let mut h4 = self.h[4];
111
112                 let mut c : u32;
113                              c = h1 >> 26; h1 = h1 & 0x3ffffff;
114                 h2 +=     c; c = h2 >> 26; h2 = h2 & 0x3ffffff;
115                 h3 +=     c; c = h3 >> 26; h3 = h3 & 0x3ffffff;
116                 h4 +=     c; c = h4 >> 26; h4 = h4 & 0x3ffffff;
117                 h0 += c * 5; c = h0 >> 26; h0 = h0 & 0x3ffffff;
118                 h1 +=     c;
119
120                 // compute h + -p
121                 let mut g0 = h0.wrapping_add(5); c = g0 >> 26; g0 &= 0x3ffffff;
122                 let mut g1 = h1.wrapping_add(c); c = g1 >> 26; g1 &= 0x3ffffff;
123                 let mut g2 = h2.wrapping_add(c); c = g2 >> 26; g2 &= 0x3ffffff;
124                 let mut g3 = h3.wrapping_add(c); c = g3 >> 26; g3 &= 0x3ffffff;
125                 let mut g4 = h4.wrapping_add(c).wrapping_sub(1 << 26);
126
127                 // select h if h < p, or h + -p if h >= p
128                 let mut mask = (g4 >> (32 - 1)).wrapping_sub(1);
129                 g0 &= mask;
130                 g1 &= mask;
131                 g2 &= mask;
132                 g3 &= mask;
133                 g4 &= mask;
134                 mask = !mask;
135                 h0 = (h0 & mask) | g0;
136                 h1 = (h1 & mask) | g1;
137                 h2 = (h2 & mask) | g2;
138                 h3 = (h3 & mask) | g3;
139                 h4 = (h4 & mask) | g4;
140
141                 // h = h % (2^128)
142                 h0 = ((h0      ) | (h1 << 26)) & 0xffffffff;
143                 h1 = ((h1 >>  6) | (h2 << 20)) & 0xffffffff;
144                 h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff;
145                 h3 = ((h3 >> 18) | (h4 <<  8)) & 0xffffffff;
146
147                 // h = mac = (h + pad) % (2^128)
148                 let mut f : u64;
149                 f = h0 as u64 + self.pad[0] as u64            ; h0 = f as u32;
150                 f = h1 as u64 + self.pad[1] as u64 + (f >> 32); h1 = f as u32;
151                 f = h2 as u64 + self.pad[2] as u64 + (f >> 32); h2 = f as u32;
152                 f = h3 as u64 + self.pad[3] as u64 + (f >> 32); h3 = f as u32;
153
154                 self.h[0] = h0;
155                 self.h[1] = h1;
156                 self.h[2] = h2;
157                 self.h[3] = h3;
158         }
159
160         pub fn input(&mut self, data: &[u8]) {
161                 assert!(!self.finalized);
162                 let mut m = data;
163
164                 if self.leftover > 0 {
165                         let want = min(16 - self.leftover, m.len());
166                         for i in 0..want {
167                                 self.buffer[self.leftover+i] = m[i];
168                         }
169                         m = &m[want..];
170                         self.leftover += want;
171
172                         if self.leftover < 16 {
173                                 return;
174                         }
175
176                         // self.block(self.buffer[..]);
177                         let tmp = self.buffer;
178                         self.block(&tmp);
179
180                         self.leftover = 0;
181                 }
182
183                 while m.len() >= 16 {
184                         self.block(&m[0..16]);
185                         m = &m[16..];
186                 }
187
188                 for i in 0..m.len() {
189                         self.buffer[i] = m[i];
190                 }
191                 self.leftover = m.len();
192         }
193
194         pub fn raw_result(&mut self, output: &mut [u8]) {
195                 assert!(output.len() >= 16);
196                 if !self.finalized{
197                         self.finish();
198                 }
199                 output[0..4].copy_from_slice(&self.h[0].to_le_bytes());
200                 output[4..8].copy_from_slice(&self.h[1].to_le_bytes());
201                 output[8..12].copy_from_slice(&self.h[2].to_le_bytes());
202                 output[12..16].copy_from_slice(&self.h[3].to_le_bytes());
203         }
204 }
205
206 #[cfg(test)]
207 mod test {
208         use core::iter::repeat;
209         use alloc::vec::Vec;
210
211         use super::Poly1305;
212
213         fn poly1305(key: &[u8], msg: &[u8], mac: &mut [u8]) {
214                 let mut poly = Poly1305::new(key);
215                 poly.input(msg);
216                 poly.raw_result(mac);
217         }
218
219         #[test]
220         fn test_nacl_vector() {
221                 let key = [
222                         0xee,0xa6,0xa7,0x25,0x1c,0x1e,0x72,0x91,
223                         0x6d,0x11,0xc2,0xcb,0x21,0x4d,0x3c,0x25,
224                         0x25,0x39,0x12,0x1d,0x8e,0x23,0x4e,0x65,
225                         0x2d,0x65,0x1f,0xa4,0xc8,0xcf,0xf8,0x80,
226                 ];
227
228                 let msg = [
229                         0x8e,0x99,0x3b,0x9f,0x48,0x68,0x12,0x73,
230                         0xc2,0x96,0x50,0xba,0x32,0xfc,0x76,0xce,
231                         0x48,0x33,0x2e,0xa7,0x16,0x4d,0x96,0xa4,
232                         0x47,0x6f,0xb8,0xc5,0x31,0xa1,0x18,0x6a,
233                         0xc0,0xdf,0xc1,0x7c,0x98,0xdc,0xe8,0x7b,
234                         0x4d,0xa7,0xf0,0x11,0xec,0x48,0xc9,0x72,
235                         0x71,0xd2,0xc2,0x0f,0x9b,0x92,0x8f,0xe2,
236                         0x27,0x0d,0x6f,0xb8,0x63,0xd5,0x17,0x38,
237                         0xb4,0x8e,0xee,0xe3,0x14,0xa7,0xcc,0x8a,
238                         0xb9,0x32,0x16,0x45,0x48,0xe5,0x26,0xae,
239                         0x90,0x22,0x43,0x68,0x51,0x7a,0xcf,0xea,
240                         0xbd,0x6b,0xb3,0x73,0x2b,0xc0,0xe9,0xda,
241                         0x99,0x83,0x2b,0x61,0xca,0x01,0xb6,0xde,
242                         0x56,0x24,0x4a,0x9e,0x88,0xd5,0xf9,0xb3,
243                         0x79,0x73,0xf6,0x22,0xa4,0x3d,0x14,0xa6,
244                         0x59,0x9b,0x1f,0x65,0x4c,0xb4,0x5a,0x74,
245                         0xe3,0x55,0xa5,
246                 ];
247
248                 let expected = [
249                         0xf3,0xff,0xc7,0x70,0x3f,0x94,0x00,0xe5,
250                         0x2a,0x7d,0xfb,0x4b,0x3d,0x33,0x05,0xd9,
251                 ];
252
253                 let mut mac = [0u8; 16];
254                 poly1305(&key, &msg, &mut mac);
255                 assert_eq!(&mac[..], &expected[..]);
256
257                 let mut poly = Poly1305::new(&key);
258                 poly.input(&msg[0..32]);
259                 poly.input(&msg[32..96]);
260                 poly.input(&msg[96..112]);
261                 poly.input(&msg[112..120]);
262                 poly.input(&msg[120..124]);
263                 poly.input(&msg[124..126]);
264                 poly.input(&msg[126..127]);
265                 poly.input(&msg[127..128]);
266                 poly.input(&msg[128..129]);
267                 poly.input(&msg[129..130]);
268                 poly.input(&msg[130..131]);
269                 poly.raw_result(&mut mac);
270                 assert_eq!(&mac[..], &expected[..]);
271         }
272
273         #[test]
274         fn donna_self_test() {
275                 let wrap_key = [
276                         0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
277                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
278                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
279                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
280                 ];
281
282                 let wrap_msg = [
283                         0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
284                         0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
285                 ];
286
287                 let wrap_mac = [
288                         0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
289                         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
290                 ];
291
292                 let mut mac = [0u8; 16];
293                 poly1305(&wrap_key, &wrap_msg, &mut mac);
294                 assert_eq!(&mac[..], &wrap_mac[..]);
295
296                 let total_key = [
297                         0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xff,
298                         0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xff, 0xff,
299                         0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
300                         0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
301                 ];
302
303                 let total_mac = [
304                         0x64, 0xaf, 0xe2, 0xe8, 0xd6, 0xad, 0x7b, 0xbd,
305                         0xd2, 0x87, 0xf9, 0x7c, 0x44, 0x62, 0x3d, 0x39,
306                 ];
307
308                 let mut tpoly = Poly1305::new(&total_key);
309                 for i in 0..256 {
310                         let key: Vec<u8> = repeat(i as u8).take(32).collect();
311                         let msg: Vec<u8> = repeat(i as u8).take(256).collect();
312                         let mut mac = [0u8; 16];
313                         poly1305(&key[..], &msg[0..i], &mut mac);
314                         tpoly.input(&mac);
315                 }
316                 tpoly.raw_result(&mut mac);
317                 assert_eq!(&mac[..], &total_mac[..]);
318         }
319
320         #[test]
321         fn test_tls_vectors() {
322                 // from http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04
323                 let key = b"this is 32-byte key for Poly1305";
324                 let msg = [0u8; 32];
325                 let expected = [
326                         0x49, 0xec, 0x78, 0x09, 0x0e, 0x48, 0x1e, 0xc6,
327                         0xc2, 0x6b, 0x33, 0xb9, 0x1c, 0xcc, 0x03, 0x07,
328                 ];
329                 let mut mac = [0u8; 16];
330                 poly1305(key, &msg, &mut mac);
331                 assert_eq!(&mac[..], &expected[..]);
332
333                 let msg = b"Hello world!";
334                 let expected= [
335                         0xa6, 0xf7, 0x45, 0x00, 0x8f, 0x81, 0xc9, 0x16,
336                         0xa2, 0x0d, 0xcc, 0x74, 0xee, 0xf2, 0xb2, 0xf0,
337                 ];
338                 poly1305(key, msg, &mut mac);
339                 assert_eq!(&mac[..], &expected[..]);
340         }
341 }