80f835c46318bb2ab825a40135af853d51b96821
[shamirs] / shamirssecret.c
1 #define _GNU_SOURCE
2
3 #include <stdint.h>
4 #include <stdio.h>
5 #include <assert.h>
6 #include <string.h>
7 #include <unistd.h>
8 #include <stdlib.h>
9 #include <sys/mman.h>
10
11 #define MAX_LENGTH 1024
12 #define ERROREXIT(str...) {fprintf(stderr, str); exit(1);}
13
14 /*
15  * Calculations across the finite field GF(2^8)
16  */
17 #define P 256
18
// Field addition: in a characteristic-2 field this is a plain bitwise XOR.
static uint8_t field_add(uint8_t a, uint8_t b) {
        uint8_t sum = a;
        sum ^= b;
        return sum;
}
22
// Field subtraction: identical to addition here, since XOR is its own inverse.
static uint8_t field_sub(uint8_t a, uint8_t b) {
        uint8_t difference = a;
        difference ^= b;
        return difference;
}
26
// Additive inverse of a, expressed as (0 - a) through field_sub.
static uint8_t field_neg(uint8_t a) {
        const uint8_t zero = 0;
        uint8_t negated = field_sub(zero, a);
        return negated;
}
30
// Antilogarithm table: exp[i] holds the i-th power of 0x03 in the field
// (exp[0] == 0x01, exp[1] == 0x03, exp[2] == 0x05, ...). The last entry,
// exp[255], repeats exp[0] == 0x01, so indices 0 and 255 both map to 1.
// NOTE(review): the name matches the <math.h> function exp(); that header is not
// included here, but some compilers may still warn about the built-in.
static const uint8_t exp[P] = {
        0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
        0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
        0xe5, 0x34, 0x5c, 0xe4, 0x37, 0x59, 0xeb, 0x26, 0x6a, 0xbe, 0xd9, 0x70, 0x90, 0xab, 0xe6, 0x31,
        0x53, 0xf5, 0x04, 0x0c, 0x14, 0x3c, 0x44, 0xcc, 0x4f, 0xd1, 0x68, 0xb8, 0xd3, 0x6e, 0xb2, 0xcd,
        0x4c, 0xd4, 0x67, 0xa9, 0xe0, 0x3b, 0x4d, 0xd7, 0x62, 0xa6, 0xf1, 0x08, 0x18, 0x28, 0x78, 0x88,
        0x83, 0x9e, 0xb9, 0xd0, 0x6b, 0xbd, 0xdc, 0x7f, 0x81, 0x98, 0xb3, 0xce, 0x49, 0xdb, 0x76, 0x9a,
        0xb5, 0xc4, 0x57, 0xf9, 0x10, 0x30, 0x50, 0xf0, 0x0b, 0x1d, 0x27, 0x69, 0xbb, 0xd6, 0x61, 0xa3,
        0xfe, 0x19, 0x2b, 0x7d, 0x87, 0x92, 0xad, 0xec, 0x2f, 0x71, 0x93, 0xae, 0xe9, 0x20, 0x60, 0xa0,
        0xfb, 0x16, 0x3a, 0x4e, 0xd2, 0x6d, 0xb7, 0xc2, 0x5d, 0xe7, 0x32, 0x56, 0xfa, 0x15, 0x3f, 0x41,
        0xc3, 0x5e, 0xe2, 0x3d, 0x47, 0xc9, 0x40, 0xc0, 0x5b, 0xed, 0x2c, 0x74, 0x9c, 0xbf, 0xda, 0x75,
        0x9f, 0xba, 0xd5, 0x64, 0xac, 0xef, 0x2a, 0x7e, 0x82, 0x9d, 0xbc, 0xdf, 0x7a, 0x8e, 0x89, 0x80,
        0x9b, 0xb6, 0xc1, 0x58, 0xe8, 0x23, 0x65, 0xaf, 0xea, 0x25, 0x6f, 0xb1, 0xc8, 0x43, 0xc5, 0x54,
        0xfc, 0x1f, 0x21, 0x63, 0xa5, 0xf4, 0x07, 0x09, 0x1b, 0x2d, 0x77, 0x99, 0xb0, 0xcb, 0x46, 0xca,
        0x45, 0xcf, 0x4a, 0xde, 0x79, 0x8b, 0x86, 0x91, 0xa8, 0xe3, 0x3e, 0x42, 0xc6, 0x51, 0xf3, 0x0e,
        0x12, 0x36, 0x5a, 0xee, 0x29, 0x7b, 0x8d, 0x8c, 0x8f, 0x8a, 0x85, 0x94, 0xa7, 0xf2, 0x0d, 0x17,
        0x39, 0x4b, 0xdd, 0x7c, 0x84, 0x97, 0xa2, 0xfd, 0x1c, 0x24, 0x6c, 0xb4, 0xc7, 0x52, 0xf6, 0x01};
// Logarithm table, the inverse of exp: for x != 0, exp[log[x]] == x.
// log[1] is stored as 0xff rather than 0; both are valid because
// exp[0xff] == exp[0] == 0x01. Callers reduce sums of logs modulo 255.
// NOTE(review): the name matches the <math.h> function log(); that header is not
// included here, but some compilers may still warn about the built-in.
static const uint8_t log[P] = {
        0x00, // log(0) is not defined; placeholder so the table can be indexed by any byte
        0xff, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7, 0x1b, 0x68, 0x33, 0xee, 0xdf, 0x03, 0x64,
        0x04, 0xe0, 0x0e, 0x34, 0x8d, 0x81, 0xef, 0x4c, 0x71, 0x08, 0xc8, 0xf8, 0x69, 0x1c, 0xc1, 0x7d,
        0xc2, 0x1d, 0xb5, 0xf9, 0xb9, 0x27, 0x6a, 0x4d, 0xe4, 0xa6, 0x72, 0x9a, 0xc9, 0x09, 0x78, 0x65,
        0x2f, 0x8a, 0x05, 0x21, 0x0f, 0xe1, 0x24, 0x12, 0xf0, 0x82, 0x45, 0x35, 0x93, 0xda, 0x8e, 0x96,
        0x8f, 0xdb, 0xbd, 0x36, 0xd0, 0xce, 0x94, 0x13, 0x5c, 0xd2, 0xf1, 0x40, 0x46, 0x83, 0x38, 0x66,
        0xdd, 0xfd, 0x30, 0xbf, 0x06, 0x8b, 0x62, 0xb3, 0x25, 0xe2, 0x98, 0x22, 0x88, 0x91, 0x10, 0x7e,
        0x6e, 0x48, 0xc3, 0xa3, 0xb6, 0x1e, 0x42, 0x3a, 0x6b, 0x28, 0x54, 0xfa, 0x85, 0x3d, 0xba, 0x2b,
        0x79, 0x0a, 0x15, 0x9b, 0x9f, 0x5e, 0xca, 0x4e, 0xd4, 0xac, 0xe5, 0xf3, 0x73, 0xa7, 0x57, 0xaf,
        0x58, 0xa8, 0x50, 0xf4, 0xea, 0xd6, 0x74, 0x4f, 0xae, 0xe9, 0xd5, 0xe7, 0xe6, 0xad, 0xe8, 0x2c,
        0xd7, 0x75, 0x7a, 0xeb, 0x16, 0x0b, 0xf5, 0x59, 0xcb, 0x5f, 0xb0, 0x9c, 0xa9, 0x51, 0xa0, 0x7f,
        0x0c, 0xf6, 0x6f, 0x17, 0xc4, 0x49, 0xec, 0xd8, 0x43, 0x1f, 0x2d, 0xa4, 0x76, 0x7b, 0xb7, 0xcc,
        0xbb, 0x3e, 0x5a, 0xfb, 0x60, 0xb1, 0x86, 0x3b, 0x52, 0xa1, 0x6c, 0xaa, 0x55, 0x29, 0x9d, 0x97,
        0xb2, 0x87, 0x90, 0x61, 0xbe, 0xdc, 0xfc, 0xbc, 0x95, 0xcf, 0xcd, 0x37, 0x3f, 0x5b, 0xd1, 0x53,
        0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab, 0x44,
        0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5, 0x67,
        0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};
66
// Picks the final result of a table-based multiplication: calc when both
// operands are non-zero, 0 when either operand is zero (the log table has no
// real entry for 0, so calc is meaningless in that case).
//
// The selection is done with bit masks instead of if/else so that a zero
// operand does not take a different code path than a non-zero one.
// Optimizations stay disabled and the function is kept out of line so the
// compiler does not turn the arithmetic back into branches.
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) {
        // Each mask is 0xff when its operand is non-zero and 0x00 when it is zero
        uint8_t a_mask = (uint8_t)(0 - (a != 0));
        uint8_t b_mask = (uint8_t)(0 - (b != 0));
        return (uint8_t)(calc & a_mask & b_mask);
}
81 static uint8_t field_mul(uint8_t a, uint8_t b)  {
82         return field_mul_ret(exp[(log[a] + log[b]) % 255], a, b);
83 }
84
85 static uint8_t field_invert(uint8_t a) {
86         assert(a != 0);
87         return exp[0xff - log[a]]; // log[1] == 0xff
88 }
89
// Picks the final result of a table-based exponentiation a^e:
//   e == 0            -> 1 (regardless of a)
//   e != 0 and a == 0 -> 0
//   otherwise         -> calc
//
// The selection is done with bit masks instead of if/else so that zero inputs
// do not take a different code path than non-zero ones. Optimizations stay
// disabled and the function is kept out of line so the compiler does not turn
// the arithmetic back into branches.
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) {
        // Each mask is 0xff when its input is non-zero and 0x00 when it is zero
        uint8_t a_mask = (uint8_t)(0 - (a != 0));
        uint8_t e_mask = (uint8_t)(0 - (e != 0));
        // With e == 0 the first term vanishes and the second contributes the 1
        return (uint8_t)((calc & a_mask & e_mask) | (1 & (uint8_t)~e_mask));
}
104 static uint8_t field_pow(uint8_t a, uint8_t e) {
105 #ifndef TEST
106         // Although this function works for a==0, its not trivially obvious why,
107         // and since we never call with a==0, we just assert a != 0 (except when testing)
108         assert(a != 0);
109 #endif
110         return field_pow_ret(exp[(log[a] * e) % 255], a, e);
111 }
112
113 #ifdef TEST
// Reference multiplication by shift-and-add, used only to check the tables.
// Not constant time: it branches on the bits of both operands.
static uint8_t field_mul_calc(uint8_t a, uint8_t b) {
        uint8_t product = 0;
        for (int bit = 0; bit < 8; bit++) {
                // Add the current multiple of a if this bit of b is set
                if ((b >> bit) & 1)
                        product ^= a;
                // Multiply a by x, reducing modulo x^8 + x^4 + x^3 + x + 1
                uint8_t overflow = a & 0x80;
                a = (uint8_t)(a << 1);
                if (overflow)
                        a ^= 0x1b;
        }
        return product;
}
// Reference exponentiation by repeated multiplication: a multiplied into 1
// exactly e times, so e == 0 yields 1.
static uint8_t field_pow_calc(uint8_t a, uint8_t e) {
        uint8_t power = 1;
        uint8_t remaining = e;
        while (remaining > 0) {
                power = field_mul_calc(power, a);
                remaining--;
        }
        return power;
}
136 int main() {
137         // Test inversion with the logarithm tables
138         for (uint16_t i = 1; i < P; i++)
139                 assert(field_mul_calc(i, field_invert(i)) == 1);
140
141         // Test multiplication with the logarithm tables
142         for (uint16_t i = 0; i < P; i++) {
143                 for (uint16_t j = 0; j < P; j++)
144                         assert(field_mul(i, j) == field_mul_calc(i, j));
145         }
146
147         // Test exponentiation with the logarithm tables
148         for (uint16_t i = 0; i < P; i++) {
149                 for (uint16_t j = 0; j < P; j++)
150                         assert(field_pow(i, j) == field_pow_calc(i, j));
151         }
152 }
153 #endif // defined(TEST)
154
155
156
157 /*
158  * Calculations across the polynomial q
159  */
160 #ifndef TEST
// Evaluates the polynomial q(x) = a[0] + a[1]*x + ... + a[k-1]*x^(k-1) over the field.
//   a: the k coefficients, a[0] being the constant term
//   k: number of coefficients
//   x: evaluation point, must be non-zero (q(0) would just be a[0])
static uint8_t calculateQ(uint8_t a[], uint8_t k, uint8_t x) {
        assert(x != 0);
        uint8_t sum = a[0];
        uint8_t degree = 1;
        while (degree < k) {
                uint8_t term = field_mul(a[degree], field_pow(x, degree));
                sum = field_add(sum, term);
                degree++;
        }
        return sum;
}
169
// Recovers the constant term q(0) from k points (x[i], q[i]) by evaluating the
// Lagrange interpolation polynomial at 0; derived from the formula at
// http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2
//   x: the k x-coordinates; any two equal entries make field_invert() assert
//   q: the k matching y-coordinates
//   k: number of points
uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t k) {
        uint8_t constant_term = 0;
        for (uint8_t i = 0; i < k; i++) {
                // term = q[i] * product over j != i of (-x[j]) / (x[i] - x[j])
                uint8_t term = q[i];
                for (uint8_t j = 0; j < k; j++) {
                        if (j != i) {
                                uint8_t numerator = field_neg(x[j]);
                                term = field_mul(term, numerator);
                                uint8_t denominator = field_sub(x[i], x[j]);
                                term = field_mul(term, field_invert(denominator));
                        }
                }
                constant_term = field_add(constant_term, term);
        }
        return constant_term;
}
186
187
188
189 int main(int argc, char* argv[]) {
190         assert(mlockall(MCL_CURRENT | MCL_FUTURE) == 0);
191
192         char split = 0;
193         uint8_t total_shares = 0, shares_required = 0;
194         char* files[P]; uint8_t files_count = 0;
195         char *in_file = (void*)0, *out_file_param = (void*)0;
196
197         int i;
198         while((i = getopt(argc, argv, "scn:k:f:o:i:h?")) != -1)
199                 switch(i) {
200                 case 's':
201                         if ((split & 0x2) && !(split & 0x1))
202                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
203                         else
204                                 split = (0x2 | 0x1);
205                         break;
206                 case 'c':
207                         if ((split & 0x2) && (split & 0x1))
208                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
209                         else
210                                 split = 0x2;
211                         break;
212                 case 'n': {
213                         int t = atoi(optarg);
214                         if (t <= 0 || t >= P)
215                                 ERROREXIT("n must be > 0 and < %u\n", P)
216                         else
217                                 total_shares = t;
218                         break;
219                 }
220                 case 'k': {
221                         int t = atoi(optarg);
222                         if (t <= 0 || t >= P)
223                                 ERROREXIT("n must be > 0 and < %u\n", P)
224                         else
225                                 shares_required = t;
226                         break;
227                 }
228                 case 'i':
229                         in_file = optarg;
230                         break;
231                 case 'o':
232                         out_file_param = optarg;
233                         break;
234                 case 'f':
235                         if (files_count >= P-1)
236                                 ERROREXIT("May only specify up to %u files\n", P-1)
237                         files[files_count++] = optarg;
238                         break;
239                 case 'h':
240                 case '?':
241                         printf("Split usage: -s -n <total shares> -k <shares required> -i <input file> -o <output file path base>\n");
242                         printf("Combine usage: -c -k <shares provided == shares required> <-f <share>>*k -o <output file>\n");
243                         exit(0);
244                         break;
245                 default:
246                         ERROREXIT("getopt failed?\n")
247                 }
248         if (!(split & 0x2))
249                 ERROREXIT("Must specify one of -c, -s or -?\n")
250         split &= 0x1;
251
252         if (argc != optind)
253                 ERROREXIT("Invalid argument\n")
254
255         if (split) {
256                 if (!total_shares || !shares_required)
257                         ERROREXIT("n and k must be set.\n")
258
259                 if (shares_required > total_shares)
260                         ERROREXIT("k must be <= n\n")
261
262                 if (files_count != 0 || !in_file || !out_file_param)
263                         ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
264
265                 FILE* random = fopen("/dev/random", "r");
266                 assert(random);
267                 FILE* secret_file = fopen(in_file, "r");
268                 if (!secret_file)
269                         ERROREXIT("Could not open %s for reading.\n", in_file)
270
271                 uint8_t secret[MAX_LENGTH];
272
273                 size_t secret_length = fread(secret, 1, MAX_LENGTH*sizeof(uint8_t), secret_file);
274                 if (secret_length == 0)
275                         ERROREXIT("Error reading secret\n")
276                 if (fread(secret, 1, 1, secret_file) > 0)
277                         ERROREXIT("Secret may not be longer than %u\n", MAX_LENGTH)
278                 fclose(secret_file);
279                 printf("Using secret of length %lu\n", secret_length);
280
281                 uint8_t a[shares_required], D[total_shares][secret_length];
282
283                 for (uint32_t i = 0; i < secret_length; i++) {
284                         a[0] = secret[i];
285
286                         for (uint8_t j = 1; j < shares_required; j++)
287                                 assert(fread(&a[j], sizeof(uint8_t), 1, random) == 1);
288                         for (uint8_t j = 0; j < total_shares; j++)
289                                 D[j][i] = calculateQ(a, shares_required, j+1);
290
291                         if (i % 32 == 0 && i != 0)
292                                 printf("Finished processing %u bytes.\n", i);
293                 }
294
295                 char out_file_name_buf[strlen(out_file_param) + 4];
296                 strcpy(out_file_name_buf, out_file_param);
297                 for (uint8_t i = 0; i < total_shares; i++) {
298                         /*printf("%u-", i);
299                         for (uint8_t j = 0; j < secret_length; j++)
300                                 printf("%02x", D[i][j]);
301                         printf("\n");*/
302
303                         sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
304                         FILE* out_file = fopen(out_file_name_buf, "w+");
305                         if (!out_file)
306                                 ERROREXIT("Could not open output file %s\n", out_file_name_buf)
307
308                         uint8_t x = i+1;
309                         if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
310                                 ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
311
312                         if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
313                                 ERROREXIT("Could not write %lu bytes to %s\n", secret_length, out_file_name_buf)
314
315                         fclose(out_file);
316                 }
317                 /*printf("secret = ");
318                 for (uint8_t i = 0; i < secret_length; i++)
319                         printf("%02x", secret[i]);
320                 printf("\n");*/
321
322                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
323                 memset(secret, 0, sizeof(uint8_t)*secret_length);
324                 memset(a, 0, sizeof(uint8_t)*shares_required);
325                 memset(in_file, 0, strlen(in_file));
326
327                 fclose(random);
328         } else {
329                 if (!shares_required)
330                         ERROREXIT("k must be set.\n")
331
332                 if (files_count != shares_required || in_file || !out_file_param)
333                         ERROREXIT("Must not specify -i and must specify -o and exactly k -f <input file>s in combine mode.\n")
334
335                 uint8_t x[shares_required], q[shares_required];
336                 FILE* files_fps[shares_required];
337
338                 for (uint8_t i = 0; i < shares_required; i++) {
339                         files_fps[i] = fopen(files[i], "r");
340                         if (!files_fps[i])
341                                 ERROREXIT("Couldn't open file %s for reading.\n", files[i])
342                         if (fread(&x[i], sizeof(uint8_t), 1, files_fps[i]) != 1)
343                                 ERROREXIT("Couldn't read the x byte of %s\n", files[i])
344                 }
345
346                 uint8_t secret[MAX_LENGTH];
347
348                 uint32_t i = 0;
349                 while (fread(&q[0], sizeof(uint8_t), 1, files_fps[0]) == 1) {
350                         for (uint8_t j = 1; j < shares_required; j++) {
351                                 if (fread(&q[j], sizeof(uint8_t), 1, files_fps[j]) != 1)
352                                         ERROREXIT("Couldn't read next byte from %s\n", files[j])
353                         }
354                         secret[i++] = calculateSecret(x, q, shares_required);
355                 }
356                 printf("Got secret of length %u\n", i);
357
358                 FILE* out_file = fopen(out_file_param, "w+");
359                 fwrite(secret, sizeof(uint8_t), i, out_file);
360                 fclose(out_file);
361
362                 for (uint8_t i = 0; i < shares_required; i++)
363                         fclose(files_fps[i]);
364
365                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
366                 memset(secret, 0, sizeof(uint8_t)*i);
367                 memset(q, 0, sizeof(uint8_t)*shares_required);
368                 memset(out_file_param, 0, strlen(out_file_param));
369                 for (uint8_t i = 0; i < shares_required; i++)
370                         memset(files[i], 0, strlen(files[i]));
371                 memset(x, 0, sizeof(uint8_t)*shares_required);
372         }
373
374         return 0;
375 }
376 #endif // !defined(TEST)