Memory clearing after finish.
[shamirs] / shamirssecret.c
1 #define _GNU_SOURCE
2
3 #include <stdint.h>
4 #include <stdio.h>
5 #include <assert.h>
6 #include <string.h>
7 #include <unistd.h>
8 #include <stdlib.h>
9 #include <sys/mman.h>
10
/* Upper bound, in bytes, on the secret buffer used when splitting or combining. */
#define MAX_LENGTH 1024
/*
 * Print a printf-style message to stderr and terminate with exit status 1.
 *
 * The expansion is a bare braced block (not do { } while (0)) because call
 * sites invoke the macro without a trailing semicolon, including directly
 * in front of an `else`. Uses the standard C99 __VA_ARGS__ form instead of
 * the GNU-only named variadic parameter.
 */
#define ERROREXIT(...) {fprintf(stderr, __VA_ARGS__); exit(1);}
13
14 /*
15  * Calculations across the finite field GF(2^8)
16  */
17 #define P 256
18
/* Add two elements of GF(2^8); in this field addition is a bitwise XOR. */
static uint8_t field_add(uint8_t a, uint8_t b) {
        const uint8_t sum = a ^ b;
        return sum;
}
22
/* Subtract b from a in GF(2^8); the operation is the same XOR used for addition. */
static uint8_t field_sub(uint8_t a, uint8_t b) {
        const uint8_t difference = a ^ b;
        return difference;
}
26
/* Additive inverse of a in GF(2^8), computed as 0 - a. */
static uint8_t field_neg(uint8_t a) {
        const uint8_t zero = 0;
        return field_sub(zero, a);
}
30
31 static const uint8_t exp[P] = {
32         0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
33         0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
34         0xe5, 0x34, 0x5c, 0xe4, 0x37, 0x59, 0xeb, 0x26, 0x6a, 0xbe, 0xd9, 0x70, 0x90, 0xab, 0xe6, 0x31,
35         0x53, 0xf5, 0x04, 0x0c, 0x14, 0x3c, 0x44, 0xcc, 0x4f, 0xd1, 0x68, 0xb8, 0xd3, 0x6e, 0xb2, 0xcd,
36         0x4c, 0xd4, 0x67, 0xa9, 0xe0, 0x3b, 0x4d, 0xd7, 0x62, 0xa6, 0xf1, 0x08, 0x18, 0x28, 0x78, 0x88,
37         0x83, 0x9e, 0xb9, 0xd0, 0x6b, 0xbd, 0xdc, 0x7f, 0x81, 0x98, 0xb3, 0xce, 0x49, 0xdb, 0x76, 0x9a,
38         0xb5, 0xc4, 0x57, 0xf9, 0x10, 0x30, 0x50, 0xf0, 0x0b, 0x1d, 0x27, 0x69, 0xbb, 0xd6, 0x61, 0xa3,
39         0xfe, 0x19, 0x2b, 0x7d, 0x87, 0x92, 0xad, 0xec, 0x2f, 0x71, 0x93, 0xae, 0xe9, 0x20, 0x60, 0xa0,
40         0xfb, 0x16, 0x3a, 0x4e, 0xd2, 0x6d, 0xb7, 0xc2, 0x5d, 0xe7, 0x32, 0x56, 0xfa, 0x15, 0x3f, 0x41,
41         0xc3, 0x5e, 0xe2, 0x3d, 0x47, 0xc9, 0x40, 0xc0, 0x5b, 0xed, 0x2c, 0x74, 0x9c, 0xbf, 0xda, 0x75,
42         0x9f, 0xba, 0xd5, 0x64, 0xac, 0xef, 0x2a, 0x7e, 0x82, 0x9d, 0xbc, 0xdf, 0x7a, 0x8e, 0x89, 0x80,
43         0x9b, 0xb6, 0xc1, 0x58, 0xe8, 0x23, 0x65, 0xaf, 0xea, 0x25, 0x6f, 0xb1, 0xc8, 0x43, 0xc5, 0x54,
44         0xfc, 0x1f, 0x21, 0x63, 0xa5, 0xf4, 0x07, 0x09, 0x1b, 0x2d, 0x77, 0x99, 0xb0, 0xcb, 0x46, 0xca,
45         0x45, 0xcf, 0x4a, 0xde, 0x79, 0x8b, 0x86, 0x91, 0xa8, 0xe3, 0x3e, 0x42, 0xc6, 0x51, 0xf3, 0x0e,
46         0x12, 0x36, 0x5a, 0xee, 0x29, 0x7b, 0x8d, 0x8c, 0x8f, 0x8a, 0x85, 0x94, 0xa7, 0xf2, 0x0d, 0x17,
47         0x39, 0x4b, 0xdd, 0x7c, 0x84, 0x97, 0xa2, 0xfd, 0x1c, 0x24, 0x6c, 0xb4, 0xc7, 0x52, 0xf6, 0x01};
48 static const uint8_t log[P] = {
49         0x00, // log(0) is not defined
50         0xff, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7, 0x1b, 0x68, 0x33, 0xee, 0xdf, 0x03, 0x64,
51         0x04, 0xe0, 0x0e, 0x34, 0x8d, 0x81, 0xef, 0x4c, 0x71, 0x08, 0xc8, 0xf8, 0x69, 0x1c, 0xc1, 0x7d,
52         0xc2, 0x1d, 0xb5, 0xf9, 0xb9, 0x27, 0x6a, 0x4d, 0xe4, 0xa6, 0x72, 0x9a, 0xc9, 0x09, 0x78, 0x65,
53         0x2f, 0x8a, 0x05, 0x21, 0x0f, 0xe1, 0x24, 0x12, 0xf0, 0x82, 0x45, 0x35, 0x93, 0xda, 0x8e, 0x96,
54         0x8f, 0xdb, 0xbd, 0x36, 0xd0, 0xce, 0x94, 0x13, 0x5c, 0xd2, 0xf1, 0x40, 0x46, 0x83, 0x38, 0x66,
55         0xdd, 0xfd, 0x30, 0xbf, 0x06, 0x8b, 0x62, 0xb3, 0x25, 0xe2, 0x98, 0x22, 0x88, 0x91, 0x10, 0x7e,
56         0x6e, 0x48, 0xc3, 0xa3, 0xb6, 0x1e, 0x42, 0x3a, 0x6b, 0x28, 0x54, 0xfa, 0x85, 0x3d, 0xba, 0x2b,
57         0x79, 0x0a, 0x15, 0x9b, 0x9f, 0x5e, 0xca, 0x4e, 0xd4, 0xac, 0xe5, 0xf3, 0x73, 0xa7, 0x57, 0xaf,
58         0x58, 0xa8, 0x50, 0xf4, 0xea, 0xd6, 0x74, 0x4f, 0xae, 0xe9, 0xd5, 0xe7, 0xe6, 0xad, 0xe8, 0x2c,
59         0xd7, 0x75, 0x7a, 0xeb, 0x16, 0x0b, 0xf5, 0x59, 0xcb, 0x5f, 0xb0, 0x9c, 0xa9, 0x51, 0xa0, 0x7f,
60         0x0c, 0xf6, 0x6f, 0x17, 0xc4, 0x49, 0xec, 0xd8, 0x43, 0x1f, 0x2d, 0xa4, 0x76, 0x7b, 0xb7, 0xcc,
61         0xbb, 0x3e, 0x5a, 0xfb, 0x60, 0xb1, 0x86, 0x3b, 0x52, 0xa1, 0x6c, 0xaa, 0x55, 0x29, 0x9d, 0x97,
62         0xb2, 0x87, 0x90, 0x61, 0xbe, 0xdc, 0xfc, 0xbc, 0x95, 0xcf, 0xcd, 0x37, 0x3f, 0x5b, 0xd1, 0x53,
63         0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab, 0x44,
64         0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5, 0x67,
65         0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};
66
/*
 * Final selection step of field_mul(): returns calc when both a and b are
 * non-zero, and 0 when either operand is zero.
 *
 * The selection is done with byte masks instead of if/else so that no branch
 * depends on the operands. The GCC attributes on the declaration are kept so
 * the function is neither inlined nor re-optimized into a branching form.
 *
 * calc: table-derived product candidate (only meaningful if a and b are non-zero)
 * a, b: the original multiplication operands
 */
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) {
        // (v + 0xff) >> 8 is 1 for v in 1..255 and 0 for v == 0; subtracting it
        // from zero stretches that bit into an all-ones or all-zeros byte.
        uint8_t a_mask = (uint8_t)(0u - (((unsigned)a + 0xffu) >> 8));
        uint8_t b_mask = (uint8_t)(0u - (((unsigned)b + 0xffu) >> 8));
        return (uint8_t)(calc & a_mask & b_mask);
}
81 static uint8_t field_mul(uint8_t a, uint8_t b)  {
82         return field_mul_ret(exp[(log[a] + log[b]) % 255], a, b);
83 }
84
85 static uint8_t field_invert(uint8_t a) {
86         assert(a != 0);
87         return exp[0xff - log[a]]; // log[1] == 0xff
88 }
89
/*
 * Final selection step of field_pow(): returns 1 when e == 0 (including
 * 0^0), otherwise 0 when a == 0, otherwise calc.
 *
 * The selection is done with byte masks instead of if/else so that no branch
 * depends on the operands. The GCC attributes on the declaration are kept so
 * the function is neither inlined nor re-optimized into a branching form.
 *
 * calc: table-derived power candidate (only meaningful if a and e are non-zero)
 * a:    the base
 * e:    the exponent
 */
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) {
        // (v + 0xff) >> 8 is 1 for v in 1..255 and 0 for v == 0; subtracting it
        // from zero stretches that bit into an all-ones or all-zeros byte.
        uint8_t a_mask = (uint8_t)(0u - (((unsigned)a + 0xffu) >> 8));
        uint8_t e_mask = (uint8_t)(0u - (((unsigned)e + 0xffu) >> 8));
        // Value to use when e != 0: calc, or 0 if the base is zero.
        uint8_t nonzero_e_result = (uint8_t)(calc & a_mask);
        // Pick that value when e != 0, and the constant 1 when e == 0.
        return (uint8_t)((nonzero_e_result & e_mask) | (0x01 & ~e_mask));
}
104 static uint8_t field_pow(uint8_t a, uint8_t e) {
105         return field_pow_ret(exp[(log[a] * e) % 255], a, e);
106 }
107
108 #ifdef TEST
/*
 * Reference multiplication in GF(2^8) by shift-and-add, used by the self-test
 * to validate the table-based field_mul().
 * It branches on operand bits, so it is open to side-channel attacks and is
 * only meant for testing.
 */
static uint8_t field_mul_calc(uint8_t a, uint8_t b) {
        uint8_t product = 0;
        for (uint8_t bit = 0; bit < 8; bit++) {
                // Add the current multiple of a when this bit of b is set.
                if (b & 1)
                        product ^= a;
                b >>= 1;

                // Double a, reducing when the top bit overflows.
                const uint8_t overflow = a & 0x80;
                a <<= 1;
                if (overflow)
                        a ^= 0x1b; // x^8 modulo x^8 + x^4 + x^3 + x + 1
        }
        return product;
}
/*
 * Reference exponentiation in GF(2^8) by repeated multiplication, used by the
 * self-test to validate field_pow(). Returns 1 when e == 0.
 */
static uint8_t field_pow_calc(uint8_t a, uint8_t e) {
        uint8_t result = 1;
        uint8_t remaining = e;
        while (remaining != 0) {
                result = field_mul(result, a);
                remaining--;
        }
        return result;
}
131 int main() {
132         // Test inversion with the logarithm tables
133         for (uint16_t i = 1; i < P; i++)
134                 assert(field_mul_calc(i, field_invert(i)) == 1);
135
136         // Test multiplication with the logarithm tables
137         for (uint16_t i = 0; i < 2; i++) {
138                 for (uint16_t j = 0; j < P; j++)
139                         assert(field_mul(i, j) == field_mul_calc(i, j));
140         }
141
142         // Test exponentiation with the logarithm tables
143         for (uint16_t i = 0; i < P; i++) {
144                 for (uint16_t j = 0; j < P; j++)
145                         assert(field_pow(i, j) == field_pow_calc(i, j));
146         }
147 }
148 #endif // defined(TEST)
149
150
151
152 /*
153  * Calculations across the polynomial q
154  */
155 #ifndef TEST
/*
 * Evaluate the polynomial q(x) = a[0] + a[1]*x + ... + a[k-1]*x^(k-1) over
 * GF(2^8).
 *
 * a: the k coefficients; a[0] is the secret byte
 * k: number of coefficients
 * x: evaluation point; must be non-zero
 */
static uint8_t calculateQ(uint8_t a[], uint8_t k, uint8_t x) {
        assert(x != 0); // q(0) would hand back the secret byte a[0] itself
        uint8_t sum = a[0];
        uint8_t degree = 1;
        while (degree < k) {
                const uint8_t term = field_mul(a[degree], field_pow(x, degree));
                sum = field_add(sum, term);
                degree++;
        }
        return sum;
}
164
/*
 * Recover the constant term q(0) of the polynomial passing through the k
 * points (x[i], q[i]) over GF(2^8).
 *
 * x: the k share x coordinates; field_invert() asserts on the difference of
 *    each pair, so they need to be pairwise distinct
 * q: the k share values at those coordinates
 * k: number of shares
 */
uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t k) {
        // Calculate the x^0 term using a derivation of the formula at
        // http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2
        uint8_t secret = 0;
        for (uint8_t i = 0; i < k; i++) {
                // term = q[i] * product over j != i of (-x[j]) / (x[i] - x[j])
                uint8_t term = q[i];
                for (uint8_t j = 0; j < k; j++) {
                        if (j != i) {
                                const uint8_t numerator = field_neg(x[j]);
                                term = field_mul(term, numerator);
                                const uint8_t denominator = field_sub(x[i], x[j]);
                                term = field_mul(term, field_invert(denominator));
                        }
                }
                secret = field_add(secret, term);
        }
        return secret;
}
181
182
183
184 int main(int argc, char* argv[]) {
185         assert(mlockall(MCL_CURRENT | MCL_FUTURE) == 0);
186
187         char split = 0;
188         uint8_t n = 0, k = 0;
189         char* files[P]; uint8_t files_count = 0;
190         char *in_file = (void*)0, *out_file_param = (void*)0;
191
192         int i;
193         while((i = getopt(argc, argv, "scn:k:f:o:i:h?")) != -1)
194                 switch(i) {
195                 case 's':
196                         if ((split & 0x2) && !(split & 0x1))
197                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
198                         else
199                                 split = (0x2 | 0x1);
200                         break;
201                 case 'c':
202                         if ((split & 0x2) && (split & 0x1))
203                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
204                         else
205                                 split = 0x2;
206                         break;
207                 case 'n': {
208                         int t = atoi(optarg);
209                         if (t <= 0 || t >= P)
210                                 ERROREXIT("n must be > 0 and < %u\n", P)
211                         else
212                                 n = t;
213                         break;
214                 }
215                 case 'k': {
216                         int t = atoi(optarg);
217                         if (t <= 0 || t >= P)
218                                 ERROREXIT("n must be > 0 and < %u\n", P)
219                         else
220                                 k = t;
221                         break;
222                 }
223                 case 'i':
224                         in_file = optarg;
225                         break;
226                 case 'o':
227                         out_file_param = optarg;
228                         break;
229                 case 'f':
230                         if (files_count >= P-1)
231                                 ERROREXIT("May only specify up to %u files\n", P-1)
232                         files[files_count++] = optarg;
233                         break;
234                 case 'h':
235                 case '?':
236                         printf("Split usage: -s -n <total shares> -k <shares required> -i <input file> -o <output file path base>\n");
237                         printf("Combine usage: -c -k <shares provided == shares required> <-f <share>>*k -o <output file>\n");
238                         exit(0);
239                         break;
240                 default:
241                         ERROREXIT("getopt failed?\n")
242                 }
243         if (!(split & 0x2))
244                 ERROREXIT("Must specify either -c or -s\n")
245         split &= 0x1;
246
247         if (argc != optind)
248                 ERROREXIT("Invalid argument\n")
249
250         if (split) {
251                 if (!n || !k)
252                         ERROREXIT("n and k must be set.\n")
253
254                 if (k > n)
255                         ERROREXIT("k must be <= n\n")
256
257                 if (files_count != 0 || !in_file || !out_file_param)
258                         ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
259
260                 FILE* random = fopen("/dev/random", "r");
261                 assert(random);
262                 FILE* secret_file = fopen(in_file, "r");
263                 if (!secret_file)
264                         ERROREXIT("Could not open %s for reading.\n", in_file)
265
266                 uint8_t secret[MAX_LENGTH];
267
268                 size_t secret_length = fread(secret, 1, MAX_LENGTH*sizeof(uint8_t), secret_file);
269                 if (secret_length == 0)
270                         ERROREXIT("Error reading secret\n")
271                 if (fread(secret, 1, 1, secret_file) > 0)
272                         ERROREXIT("Secret may not be longer than %u\n", MAX_LENGTH)
273                 fclose(secret_file);
274                 printf("Using secret of length %lu\n", secret_length);
275
276                 uint8_t a[k], D[n][secret_length];
277
278                 for (uint8_t i = 0; i < secret_length; i++) {
279                         a[0] = secret[i];
280
281                         for (uint8_t j = 1; j < k; j++)
282                                 assert(fread(&a[j], sizeof(uint8_t), 1, random) == 1);
283                         for (uint8_t j = 0; j < n; j++)
284                                 D[j][i] = calculateQ(a, k, j+1);
285                 }
286
287                 char out_file_name_buf[strlen(out_file_param) + 4];
288                 strcpy(out_file_name_buf, out_file_param);
289                 for (uint8_t i = 0; i < n; i++) {
290                         /*printf("%u-", i);
291                         for (uint8_t j = 0; j < secret_length; j++)
292                                 printf("%02x", D[i][j]);
293                         printf("\n");*/
294
295                         sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
296                         FILE* out_file = fopen(out_file_name_buf, "w+");
297                         if (!out_file)
298                                 ERROREXIT("Could not open output file %s\n", out_file_name_buf)
299
300                         uint8_t x = i+1;
301                         if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
302                                 ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
303
304                         if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
305                                 ERROREXIT("Could not write %lu bytes to %s\n", secret_length, out_file_name_buf)
306
307                         fclose(out_file);
308                 }
309                 /*printf("secret = ");
310                 for (uint8_t i = 0; i < secret_length; i++)
311                         printf("%02x", secret[i]);
312                 printf("\n");*/
313
314                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
315                 memset(secret, 0, sizeof(uint8_t)*secret_length);
316                 memset(a, 0, sizeof(uint8_t)*k);
317                 memset(in_file, 0, strlen(in_file));
318
319                 fclose(random);
320         } else {
321                 if (!k)
322                         ERROREXIT("k must be set.\n")
323
324                 if (files_count != k || in_file || !out_file_param)
325                         ERROREXIT("Must not specify -i and must specify -o and exactly k -f <input file>s in combine mode.\n")
326
327                 uint8_t x[k], q[k];
328                 FILE* files_fps[k];
329
330                 for (uint8_t i = 0; i < k; i++) {
331                         files_fps[i] = fopen(files[i], "r");
332                         if (!files_fps[i])
333                                 ERROREXIT("Couldn't open file %s for reading.\n", files[i])
334                         if (fread(&x[i], sizeof(uint8_t), 1, files_fps[i]) != 1)
335                                 ERROREXIT("Couldn't read the x byte of %s\n", files[i])
336                 }
337
338                 uint8_t secret[MAX_LENGTH];
339
340                 uint8_t i = 0;
341                 while (fread(&q[0], sizeof(uint8_t), 1, files_fps[0]) == 1) {
342                         for (uint8_t j = 1; j < k; j++) {
343                                 if (fread(&q[j], sizeof(uint8_t), 1, files_fps[j]) != 1)
344                                         ERROREXIT("Couldn't read next byte from %s\n", files[j])
345                         }
346                         secret[i++] = calculateSecret(x, q, k);
347                 }
348                 printf("Got secret of length %u\n", i);
349
350                 FILE* out_file = fopen(out_file_param, "w+");
351                 fwrite(secret, sizeof(uint8_t), i, out_file);
352                 fclose(out_file);
353
354                 for (uint8_t i = 0; i < k; i++)
355                         fclose(files_fps[i]);
356
357                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
358                 memset(secret, 0, sizeof(uint8_t)*i);
359                 memset(q, 0, sizeof(uint8_t)*k);
360                 memset(out_file_param, 0, strlen(out_file_param));
361                 for (uint8_t i = 0; i < k; i++)
362                         memset(files[i], 0, strlen(files[i]));
363                 memset(x, 0, sizeof(uint8_t)*k);
364         }
365
366         return 0;
367 }
368 #endif // !defined(TEST)