During split, test that a set of shares missing any one required part leaks no information about the secret.
[shamirs] / shamirssecret.c
1 #define _GNU_SOURCE
2
3 #include <stdint.h>
4 #include <stdio.h>
5 #include <assert.h>
6 #include <string.h>
7 #include <unistd.h>
8 #include <stdlib.h>
9 #include <sys/mman.h>
10 #include <stdbool.h>
11
/* Largest secret, in bytes, that split mode accepts (and combine mode produces). */
#define MAX_LENGTH 1024

/*
 * Print a printf-style message to stderr and exit with status 1.
 * Deliberately expands to a bare { ... } block rather than do { } while (0):
 * call sites write "if (c) ERROREXIT(...) else ..." with no trailing semicolon.
 */
#define ERROREXIT(...) {fprintf(stderr, __VA_ARGS__); exit(1);}

/*
 * Calculations across the finite field GF(2^8)
 */
/* Number of elements in GF(2^8); also the length of the exp/log tables. */
#define P 256
19
#ifndef TEST
/* GF(2^8) addition: coefficients are added mod 2, i.e. a bitwise XOR. */
static uint8_t field_add(uint8_t a, uint8_t b) {
        uint8_t sum = (uint8_t)(b ^ a);
        return sum;
}

/* GF(2^8) subtraction: in characteristic 2 this is the same XOR as addition. */
static uint8_t field_sub(uint8_t a, uint8_t b) {
        uint8_t difference = (uint8_t)(b ^ a);
        return difference;
}

/* GF(2^8) additive inverse, computed as 0 - a (which works out to a itself). */
static uint8_t field_neg(uint8_t a) {
        const uint8_t zero = 0;
        return field_sub(zero, a);
}
#endif
33
/*
 * exp: antilog table for GF(2^8). exp[i] is the generator 0x03 raised to the
 * i-th power, reduced modulo x^8 + x^4 + x^3 + x + 1 (the reduction used by
 * field_mul_calc in the TEST build, whose self-test cross-checks these tables).
 * The multiplicative group has order 255, so exp[255] wraps back to 0x01.
 * NOTE(review): the names exp/log coincide with <math.h> functions; that
 * header is not included here, but GCC may still warn about shadowing its
 * built-ins.
 */
34 static const uint8_t exp[P] = {
35         0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
36         0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
37         0xe5, 0x34, 0x5c, 0xe4, 0x37, 0x59, 0xeb, 0x26, 0x6a, 0xbe, 0xd9, 0x70, 0x90, 0xab, 0xe6, 0x31,
38         0x53, 0xf5, 0x04, 0x0c, 0x14, 0x3c, 0x44, 0xcc, 0x4f, 0xd1, 0x68, 0xb8, 0xd3, 0x6e, 0xb2, 0xcd,
39         0x4c, 0xd4, 0x67, 0xa9, 0xe0, 0x3b, 0x4d, 0xd7, 0x62, 0xa6, 0xf1, 0x08, 0x18, 0x28, 0x78, 0x88,
40         0x83, 0x9e, 0xb9, 0xd0, 0x6b, 0xbd, 0xdc, 0x7f, 0x81, 0x98, 0xb3, 0xce, 0x49, 0xdb, 0x76, 0x9a,
41         0xb5, 0xc4, 0x57, 0xf9, 0x10, 0x30, 0x50, 0xf0, 0x0b, 0x1d, 0x27, 0x69, 0xbb, 0xd6, 0x61, 0xa3,
42         0xfe, 0x19, 0x2b, 0x7d, 0x87, 0x92, 0xad, 0xec, 0x2f, 0x71, 0x93, 0xae, 0xe9, 0x20, 0x60, 0xa0,
43         0xfb, 0x16, 0x3a, 0x4e, 0xd2, 0x6d, 0xb7, 0xc2, 0x5d, 0xe7, 0x32, 0x56, 0xfa, 0x15, 0x3f, 0x41,
44         0xc3, 0x5e, 0xe2, 0x3d, 0x47, 0xc9, 0x40, 0xc0, 0x5b, 0xed, 0x2c, 0x74, 0x9c, 0xbf, 0xda, 0x75,
45         0x9f, 0xba, 0xd5, 0x64, 0xac, 0xef, 0x2a, 0x7e, 0x82, 0x9d, 0xbc, 0xdf, 0x7a, 0x8e, 0x89, 0x80,
46         0x9b, 0xb6, 0xc1, 0x58, 0xe8, 0x23, 0x65, 0xaf, 0xea, 0x25, 0x6f, 0xb1, 0xc8, 0x43, 0xc5, 0x54,
47         0xfc, 0x1f, 0x21, 0x63, 0xa5, 0xf4, 0x07, 0x09, 0x1b, 0x2d, 0x77, 0x99, 0xb0, 0xcb, 0x46, 0xca,
48         0x45, 0xcf, 0x4a, 0xde, 0x79, 0x8b, 0x86, 0x91, 0xa8, 0xe3, 0x3e, 0x42, 0xc6, 0x51, 0xf3, 0x0e,
49         0x12, 0x36, 0x5a, 0xee, 0x29, 0x7b, 0x8d, 0x8c, 0x8f, 0x8a, 0x85, 0x94, 0xa7, 0xf2, 0x0d, 0x17,
50         0x39, 0x4b, 0xdd, 0x7c, 0x84, 0x97, 0xa2, 0xfd, 0x1c, 0x24, 0x6c, 0xb4, 0xc7, 0x52, 0xf6, 0x01};
/*
 * log: discrete logarithm base 0x03, so exp[log[a]] == a for every a != 0.
 * log[1] is stored as 0xff rather than 0; every user in this file reduces
 * exponents mod 255 (or computes 0xff - log[a]), where 255 and 0 coincide.
 * log[0] is only a placeholder: the callers special-case zero operands.
 */
51 static const uint8_t log[P] = {
52         0x00, // log(0) is not defined; placeholder only
53         0xff, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7, 0x1b, 0x68, 0x33, 0xee, 0xdf, 0x03, 0x64,
54         0x04, 0xe0, 0x0e, 0x34, 0x8d, 0x81, 0xef, 0x4c, 0x71, 0x08, 0xc8, 0xf8, 0x69, 0x1c, 0xc1, 0x7d,
55         0xc2, 0x1d, 0xb5, 0xf9, 0xb9, 0x27, 0x6a, 0x4d, 0xe4, 0xa6, 0x72, 0x9a, 0xc9, 0x09, 0x78, 0x65,
56         0x2f, 0x8a, 0x05, 0x21, 0x0f, 0xe1, 0x24, 0x12, 0xf0, 0x82, 0x45, 0x35, 0x93, 0xda, 0x8e, 0x96,
57         0x8f, 0xdb, 0xbd, 0x36, 0xd0, 0xce, 0x94, 0x13, 0x5c, 0xd2, 0xf1, 0x40, 0x46, 0x83, 0x38, 0x66,
58         0xdd, 0xfd, 0x30, 0xbf, 0x06, 0x8b, 0x62, 0xb3, 0x25, 0xe2, 0x98, 0x22, 0x88, 0x91, 0x10, 0x7e,
59         0x6e, 0x48, 0xc3, 0xa3, 0xb6, 0x1e, 0x42, 0x3a, 0x6b, 0x28, 0x54, 0xfa, 0x85, 0x3d, 0xba, 0x2b,
60         0x79, 0x0a, 0x15, 0x9b, 0x9f, 0x5e, 0xca, 0x4e, 0xd4, 0xac, 0xe5, 0xf3, 0x73, 0xa7, 0x57, 0xaf,
61         0x58, 0xa8, 0x50, 0xf4, 0xea, 0xd6, 0x74, 0x4f, 0xae, 0xe9, 0xd5, 0xe7, 0xe6, 0xad, 0xe8, 0x2c,
62         0xd7, 0x75, 0x7a, 0xeb, 0x16, 0x0b, 0xf5, 0x59, 0xcb, 0x5f, 0xb0, 0x9c, 0xa9, 0x51, 0xa0, 0x7f,
63         0x0c, 0xf6, 0x6f, 0x17, 0xc4, 0x49, 0xec, 0xd8, 0x43, 0x1f, 0x2d, 0xa4, 0x76, 0x7b, 0xb7, 0xcc,
64         0xbb, 0x3e, 0x5a, 0xfb, 0x60, 0xb1, 0x86, 0x3b, 0x52, 0xa1, 0x6c, 0xaa, 0x55, 0x29, 0x9d, 0x97,
65         0xb2, 0x87, 0x90, 0x61, 0xbe, 0xdc, 0xfc, 0xbc, 0x95, 0xcf, 0xcd, 0x37, 0x3f, 0x5b, 0xd1, 0x53,
66         0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab, 0x44,
67         0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5, 0x67,
68         0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};
69
/*
 * Branch-balanced select used by field_mul: returns calc (the table-derived
 * product) unless a or b is zero, in which case the product is 0. It lives in
 * its own noinline function compiled at -O0 so the optimizer does not reshape
 * this selection.
 * NOTE(review): best-effort only. The branches are still data-dependent, and
 * the secret-indexed exp/log lookups in field_mul may leak through cache
 * timing; constant-time behaviour has not been verified.
 */
70 // We disable lots of optimizations that result in non-constant runtime (+/- branch delays)
71 static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) __attribute__((noinline));
72 static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) {
73         uint8_t ret, ret2;
74         if (a == 0)
75                 ret2 = 0;
76         else
77                 ret2 = calc;
78         if (b == 0)
79                 ret = 0;
80         else
81                 ret = ret2;
82         return ret;
83 }
/*
 * GF(2^8) multiplication via logarithms: a*b = 3^((log a + log b) mod 255).
 * The table expression is meaningless when an operand is zero (log[0] is a
 * placeholder), so field_mul_ret substitutes 0 in that case.
 */
84 static uint8_t field_mul(uint8_t a, uint8_t b)  {
85         return field_mul_ret(exp[(log[a] + log[b]) % 255], a, b);
86 }
87
/*
 * GF(2^8) multiplicative inverse: a^-1 = 3^(255 - log a).
 * a must be nonzero. With NDEBUG the assert disappears, and a == 0 then
 * silently yields exp[0xff] == 0x01.
 */
88 static uint8_t field_invert(uint8_t a) {
89         assert(a != 0);
90         return exp[0xff - log[a]]; // log[1] == 0xff
91 }
92
/*
 * Branch-balanced select used by field_pow: returns 1 when e == 0 (so 0^0 is
 * 1), 0 when a == 0 and e != 0, and otherwise calc (the table-derived power).
 * Compiled at -O0 and noinline so the optimizer does not reshape the branches.
 * NOTE(review): best-effort only; not verified to run in constant time.
 */
93 // We disable lots of optimizations that result in non-constant runtime (+/- branch delays)
94 static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) __attribute__((optimize("-O0"))) __attribute__((noinline));
95 static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) {
96         uint8_t ret, ret2;
97         if (a == 0)
98                 ret2 = 0;
99         else
100                 ret2 = calc;
101         if (e == 0)
102                 ret = 1;
103         else
104                 ret = ret2;
105         return ret;
106 }
/*
 * GF(2^8) exponentiation: a^e = 3^((log a * e) mod 255), with the zero cases
 * fixed up by field_pow_ret. log[a] * e is evaluated in int after promotion,
 * so it cannot overflow (at most 255 * 255).
 */
107 static uint8_t field_pow(uint8_t a, uint8_t e) {
108 #ifndef TEST
109         // Although this function works for a==0, it's not trivially obvious why,
110         // and since we never call with a==0, we just assert a != 0 (except when testing)
111         assert(a != 0);
112 #endif
113         return field_pow_ret(exp[(log[a] * e) % 255], a, e);
114 }
115
#ifdef TEST
/*
 * Reference GF(2^8) multiply by shift-and-add: for every set bit of b, XOR in
 * the current a, then double a, reducing by x^8 + x^4 + x^3 + x + 1 whenever
 * its top bit falls off. Branches on the operands, so it is for testing only.
 */
static uint8_t field_mul_calc(uint8_t a, uint8_t b) {
        // side-channel attacks here
        uint8_t acc = 0;
        for (int bit = 0; bit < 8; bit++) {
                if ((b >> bit) & 1)
                        acc ^= a;
                const uint8_t top_bit_set = a & 0x80;
                a = (uint8_t)(a << 1);
                if (top_bit_set)
                        a ^= 0x1b; // what x^8 is modulo x^8 + x^4 + x^3 + x + 1
        }
        return acc;
}

/* Reference power: start from 1 and multiply by a, e times. */
static uint8_t field_pow_calc(uint8_t a, uint8_t e) {
        uint8_t acc = 1;
        for (unsigned int remaining = e; remaining > 0; remaining--)
                acc = field_mul_calc(acc, a);
        return acc;
}

/* Self-test (built with -DTEST): checks the exp/log tables against the reference routines. */
int main() {
        // Every nonzero element times its table inverse must be 1.
        for (unsigned int v = 1; v < P; v++)
                assert(field_mul_calc(v, field_invert(v)) == 1);

        // The table multiply must agree with the reference on all 256 x 256 pairs.
        for (unsigned int u = 0; u < P; u++)
                for (unsigned int v = 0; v < P; v++)
                        assert(field_mul(u, v) == field_mul_calc(u, v));

        // The table power must agree with repeated multiplication for every base/exponent.
        for (unsigned int base = 0; base < P; base++)
                for (unsigned int e = 0; e < P; e++)
                        assert(field_pow(base, e) == field_pow_calc(base, e));
}
#endif // defined(TEST)
157
158
159
160 /*
161  * Calculations across the polynomial q
162  */
163 #ifndef TEST
/*
 * Evaluate the share polynomial q(x) = a[0] + a[1]*x + ... + a[k-1]*x^(k-1)
 * over GF(2^8). a[0] is the secret byte itself, so x must be nonzero.
 */
static uint8_t calculateQ(uint8_t a[], uint8_t k, uint8_t x) {
        assert(x != 0); // q(0) == secret, though so does a[0]
        uint8_t value = a[0];
        uint8_t degree = 1;
        while (degree < k) {
                uint8_t x_to_degree = field_pow(x, degree);
                uint8_t term = field_mul(a[degree], x_to_degree);
                value = field_add(value, term);
                degree++;
        }
        return value;
}
172
/*
 * Recover q(0), the secret byte, from k points (x[i], q[i]) by Lagrange
 * interpolation evaluated at 0:
 *   q(0) = sum_i q[i] * prod_{j != i} (0 - x[j]) / (x[i] - x[j])
 * (derived from http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2).
 * The x[i] must be pairwise distinct; a repeated x makes field_invert assert.
 */
uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t k) {
        uint8_t secret = 0;
        uint8_t i = 0;
        while (i < k) {
                uint8_t term = q[i];
                for (uint8_t j = 0; j < k; j++) {
                        if (j != i) {
                                uint8_t numerator = field_neg(x[j]);
                                uint8_t inv_denominator = field_invert(field_sub(x[i], x[j]));
                                term = field_mul(field_mul(term, numerator), inv_denominator);
                        }
                }
                secret = field_add(secret, term);
                i++;
        }
        return secret;
}
189
190
191 void derive_missing_part(uint8_t total_shares, uint8_t shares_required, bool parts_have[], uint8_t* split_version, uint8_t split_index, uint8_t split_size) {
192         uint8_t (*D)[split_size] = (uint8_t (*)[split_size])split_version;
193         uint8_t x[shares_required], q[shares_required];
194
195         // Fill in x/q with the selected shares
196         uint16_t x_pos = 0;
197         for (uint8_t i = 0; i < P-1; i++) {
198                 if (parts_have[i]) {
199                         x[x_pos] = i+1;
200                         q[x_pos++] = D[i][split_index];
201                 }
202         }
203         assert(x_pos == shares_required - 1);
204
205         // Now loop through ALL x we didn't already set (despite not having that many
206         // shares, because more shares could be added arbitrarily, any x should not be
207         // able to rule out any possible secrets) and try each possible q, making sure
208         // that each q gives us a new possibility for the secret.
209         bool impossible_secrets[P];
210         memset(impossible_secrets, 0, sizeof(impossible_secrets));
211         for (uint16_t final_x = 1; final_x < P; final_x++) { 
212                 bool x_already_used = false;
213                 for (uint8_t j = 0; j < shares_required; j++) {
214                         if (x[j] == final_x)
215                                 x_already_used = true;
216                 }
217                 if (x_already_used)
218                         continue;
219
220                 x[shares_required-1] = final_x;
221                 bool possible_secrets[P];
222                 memset(possible_secrets, 0, sizeof(possible_secrets));
223                 for (uint16_t new_q = 0; new_q < P; new_q++) {
224                         q[shares_required-1] = new_q;
225                         possible_secrets[calculateSecret(x, q, shares_required)] = 1;
226                 }
227
228                 for (uint16_t i = 0; i < P; i++)
229                         assert(possible_secrets[i]);
230         }
231
232         //TODO: Check that gcc isn't optimizing this one away
233         memset(q, 0, sizeof(q));
234 }
235
/*
 * Recursively enumerate every subset of exactly shares_required - 1 shares out
 * of total_shares and run derive_missing_part on each. parts_have[] records the
 * current subset; progress is the index of the next share to decide on, and
 * parts_included is how many shares have been chosen so far. parts_have is
 * restored to its incoming state before returning.
 */
void check_possible_missing_part_derivations_intern(uint8_t total_shares, uint8_t shares_required, bool parts_have[], uint8_t parts_included, uint16_t progress, uint8_t* split_version, size_t split_index, size_t split_size) {
        const int still_needed = (int)shares_required - 1 - (int)parts_included;
        if (still_needed <= 0) {
                derive_missing_part(total_shares, shares_required, parts_have, split_version, split_index, split_size);
                return;
        }

        // Prune only when the undecided shares can no longer complete the subset.
        // A stricter bound (e.g. requiring shares_required undecided shares) would
        // silently skip subsets that contain the last shares.
        if ((int)total_shares - (int)progress < still_needed)
                return;

        // Subsets that leave share `progress` out...
        check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, parts_included, progress + 1, split_version, split_index, split_size);
        // ...and subsets that include it.
        parts_have[progress] = true;
        check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, parts_included + 1, progress + 1, split_version, split_index, split_size);
        parts_have[progress] = false;
}
248
249 void check_possible_missing_part_derivations(uint8_t total_shares, uint8_t shares_required, uint8_t* split_version, uint8_t split_index, uint8_t split_size) {
250         bool parts_have[P];
251         memset(parts_have, 0, sizeof(parts_have));
252         check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, 0, 0, split_version, split_index, split_size);
253 }
254
255
256
257 int main(int argc, char* argv[]) {
258         assert(mlockall(MCL_CURRENT | MCL_FUTURE) == 0);
259
260         char split = 0;
261         uint8_t total_shares = 0, shares_required = 0;
262         char* files[P]; uint8_t files_count = 0;
263         char *in_file = (void*)0, *out_file_param = (void*)0;
264
265         int i;
266         while((i = getopt(argc, argv, "scn:k:f:o:i:h?")) != -1)
267                 switch(i) {
268                 case 's':
269                         if ((split & 0x2) && !(split & 0x1))
270                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
271                         else
272                                 split = (0x2 | 0x1);
273                         break;
274                 case 'c':
275                         if ((split & 0x2) && (split & 0x1))
276                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
277                         else
278                                 split = 0x2;
279                         break;
280                 case 'n': {
281                         int t = atoi(optarg);
282                         if (t <= 0 || t >= P)
283                                 ERROREXIT("n must be > 0 and < %u\n", P)
284                         else
285                                 total_shares = t;
286                         break;
287                 }
288                 case 'k': {
289                         int t = atoi(optarg);
290                         if (t <= 0 || t >= P)
291                                 ERROREXIT("n must be > 0 and < %u\n", P)
292                         else
293                                 shares_required = t;
294                         break;
295                 }
296                 case 'i':
297                         in_file = optarg;
298                         break;
299                 case 'o':
300                         out_file_param = optarg;
301                         break;
302                 case 'f':
303                         if (files_count >= P-1)
304                                 ERROREXIT("May only specify up to %u files\n", P-1)
305                         files[files_count++] = optarg;
306                         break;
307                 case 'h':
308                 case '?':
309                         printf("Split usage: -s -n <total shares> -k <shares required> -i <input file> -o <output file path base>\n");
310                         printf("Combine usage: -c -k <shares provided == shares required> <-f <share>>*k -o <output file>\n");
311                         exit(0);
312                         break;
313                 default:
314                         ERROREXIT("getopt failed?\n")
315                 }
316         if (!(split & 0x2))
317                 ERROREXIT("Must specify one of -c, -s or -?\n")
318         split &= 0x1;
319
320         if (argc != optind)
321                 ERROREXIT("Invalid argument\n")
322
323         if (split) {
324                 if (!total_shares || !shares_required)
325                         ERROREXIT("n and k must be set.\n")
326
327                 if (shares_required > total_shares)
328                         ERROREXIT("k must be <= n\n")
329
330                 if (files_count != 0 || !in_file || !out_file_param)
331                         ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
332
333                 FILE* random = fopen("/dev/random", "r");
334                 assert(random);
335                 FILE* secret_file = fopen(in_file, "r");
336                 if (!secret_file)
337                         ERROREXIT("Could not open %s for reading.\n", in_file)
338
339                 uint8_t secret[MAX_LENGTH];
340
341                 size_t secret_length = fread(secret, 1, MAX_LENGTH*sizeof(uint8_t), secret_file);
342                 if (secret_length == 0)
343                         ERROREXIT("Error reading secret\n")
344                 if (fread(secret, 1, 1, secret_file) > 0)
345                         ERROREXIT("Secret may not be longer than %u\n", MAX_LENGTH)
346                 fclose(secret_file);
347                 printf("Using secret of length %lu\n", secret_length);
348
349                 uint8_t a[shares_required], D[total_shares][secret_length];
350
351                 for (uint32_t i = 0; i < secret_length; i++) {
352                         a[0] = secret[i];
353
354                         for (uint8_t j = 1; j < shares_required; j++)
355                                 assert(fread(&a[j], sizeof(uint8_t), 1, random) == 1);
356                         for (uint8_t j = 0; j < total_shares; j++)
357                                 D[j][i] = calculateQ(a, shares_required, j+1);
358
359                         // Now, for sanity's sake, we ensure that no matter which piece we are missing, we can derive no information about the secret
360                         check_possible_missing_part_derivations(total_shares, shares_required, &(D[0][0]), i, secret_length);
361
362                         if (i % 32 == 0 && i != 0)
363                                 printf("Finished processing %u bytes.\n", i);
364                 }
365
366                 char out_file_name_buf[strlen(out_file_param) + 4];
367                 strcpy(out_file_name_buf, out_file_param);
368                 for (uint8_t i = 0; i < total_shares; i++) {
369                         /*printf("%u-", i);
370                         for (uint8_t j = 0; j < secret_length; j++)
371                                 printf("%02x", D[i][j]);
372                         printf("\n");*/
373
374                         sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
375                         FILE* out_file = fopen(out_file_name_buf, "w+");
376                         if (!out_file)
377                                 ERROREXIT("Could not open output file %s\n", out_file_name_buf)
378
379                         uint8_t x = i+1;
380                         if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
381                                 ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
382
383                         if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
384                                 ERROREXIT("Could not write %lu bytes to %s\n", secret_length, out_file_name_buf)
385
386                         fclose(out_file);
387                 }
388                 /*printf("secret = ");
389                 for (uint8_t i = 0; i < secret_length; i++)
390                         printf("%02x", secret[i]);
391                 printf("\n");*/
392
393                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
394                 memset(secret, 0, sizeof(uint8_t)*secret_length);
395                 memset(a, 0, sizeof(uint8_t)*shares_required);
396                 memset(in_file, 0, strlen(in_file));
397
398                 fclose(random);
399         } else {
400                 if (!shares_required)
401                         ERROREXIT("k must be set.\n")
402
403                 if (files_count != shares_required || in_file || !out_file_param)
404                         ERROREXIT("Must not specify -i and must specify -o and exactly k -f <input file>s in combine mode.\n")
405
406                 uint8_t x[shares_required], q[shares_required];
407                 FILE* files_fps[shares_required];
408
409                 for (uint8_t i = 0; i < shares_required; i++) {
410                         files_fps[i] = fopen(files[i], "r");
411                         if (!files_fps[i])
412                                 ERROREXIT("Couldn't open file %s for reading.\n", files[i])
413                         if (fread(&x[i], sizeof(uint8_t), 1, files_fps[i]) != 1)
414                                 ERROREXIT("Couldn't read the x byte of %s\n", files[i])
415                 }
416
417                 uint8_t secret[MAX_LENGTH];
418
419                 uint32_t i = 0;
420                 while (fread(&q[0], sizeof(uint8_t), 1, files_fps[0]) == 1) {
421                         for (uint8_t j = 1; j < shares_required; j++) {
422                                 if (fread(&q[j], sizeof(uint8_t), 1, files_fps[j]) != 1)
423                                         ERROREXIT("Couldn't read next byte from %s\n", files[j])
424                         }
425                         secret[i++] = calculateSecret(x, q, shares_required);
426                 }
427                 printf("Got secret of length %u\n", i);
428
429                 FILE* out_file = fopen(out_file_param, "w+");
430                 fwrite(secret, sizeof(uint8_t), i, out_file);
431                 fclose(out_file);
432
433                 for (uint8_t i = 0; i < shares_required; i++)
434                         fclose(files_fps[i]);
435
436                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
437                 memset(secret, 0, sizeof(uint8_t)*i);
438                 memset(q, 0, sizeof(uint8_t)*shares_required);
439                 memset(out_file_param, 0, strlen(out_file_param));
440                 for (uint8_t i = 0; i < shares_required; i++)
441                         memset(files[i], 0, strlen(files[i]));
442                 memset(x, 0, sizeof(uint8_t)*shares_required);
443         }
444
445         return 0;
446 }
447 #endif // !defined(TEST)