Minor crap.
[shamirs] / shamirssecret.c
1 #define _GNU_SOURCE
2
3 #include <stdint.h>
4 #include <stdio.h>
5 #include <assert.h>
6 #include <string.h>
7 #include <unistd.h>
8 #include <stdlib.h>
9 #include <sys/mman.h>
10
/* Upper bound, in bytes, on the secret; the secret buffers are stack arrays of this size. */
#define MAX_LENGTH 1024
/*
 * Print a printf-style message to stderr and exit(1).
 * Expands to a bare { ... } block, so call sites invoke it WITHOUT a
 * trailing semicolon (e.g. `if (bad) ERROREXIT("...") else ...`).
 * Uses ISO C99 __VA_ARGS__ rather than the GNU-only named `str...` form.
 */
#define ERROREXIT(...) {fprintf(stderr, __VA_ARGS__); exit(1);}
13
/*
 * Calculations across the finite field GF(2^8)
 */

/* Number of elements of GF(2^8); also the length of the exp/log tables. */
#define P (256)
18
/* Addition in GF(2^8): coefficients are added mod 2, i.e. a bitwise XOR. */
static uint8_t field_add(uint8_t a, uint8_t b) {
        uint8_t sum = (uint8_t)(a ^ b);
        return sum;
}
22
/* Subtraction in GF(2^8). In characteristic 2, -b == b, so this is the
 * same XOR as addition. */
static uint8_t field_sub(uint8_t a, uint8_t b) {
        const uint8_t difference = b ^ a;
        return difference;
}
26
/* Additive inverse in GF(2^8), computed as 0 - a via field_sub(). */
static uint8_t field_neg(uint8_t a) {
        const uint8_t zero = 0;
        return field_sub(zero, a);
}
30
31 static const uint8_t exp[P] = {
32         0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
33         0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
34         0xe5, 0x34, 0x5c, 0xe4, 0x37, 0x59, 0xeb, 0x26, 0x6a, 0xbe, 0xd9, 0x70, 0x90, 0xab, 0xe6, 0x31,
35         0x53, 0xf5, 0x04, 0x0c, 0x14, 0x3c, 0x44, 0xcc, 0x4f, 0xd1, 0x68, 0xb8, 0xd3, 0x6e, 0xb2, 0xcd,
36         0x4c, 0xd4, 0x67, 0xa9, 0xe0, 0x3b, 0x4d, 0xd7, 0x62, 0xa6, 0xf1, 0x08, 0x18, 0x28, 0x78, 0x88,
37         0x83, 0x9e, 0xb9, 0xd0, 0x6b, 0xbd, 0xdc, 0x7f, 0x81, 0x98, 0xb3, 0xce, 0x49, 0xdb, 0x76, 0x9a,
38         0xb5, 0xc4, 0x57, 0xf9, 0x10, 0x30, 0x50, 0xf0, 0x0b, 0x1d, 0x27, 0x69, 0xbb, 0xd6, 0x61, 0xa3,
39         0xfe, 0x19, 0x2b, 0x7d, 0x87, 0x92, 0xad, 0xec, 0x2f, 0x71, 0x93, 0xae, 0xe9, 0x20, 0x60, 0xa0,
40         0xfb, 0x16, 0x3a, 0x4e, 0xd2, 0x6d, 0xb7, 0xc2, 0x5d, 0xe7, 0x32, 0x56, 0xfa, 0x15, 0x3f, 0x41,
41         0xc3, 0x5e, 0xe2, 0x3d, 0x47, 0xc9, 0x40, 0xc0, 0x5b, 0xed, 0x2c, 0x74, 0x9c, 0xbf, 0xda, 0x75,
42         0x9f, 0xba, 0xd5, 0x64, 0xac, 0xef, 0x2a, 0x7e, 0x82, 0x9d, 0xbc, 0xdf, 0x7a, 0x8e, 0x89, 0x80,
43         0x9b, 0xb6, 0xc1, 0x58, 0xe8, 0x23, 0x65, 0xaf, 0xea, 0x25, 0x6f, 0xb1, 0xc8, 0x43, 0xc5, 0x54,
44         0xfc, 0x1f, 0x21, 0x63, 0xa5, 0xf4, 0x07, 0x09, 0x1b, 0x2d, 0x77, 0x99, 0xb0, 0xcb, 0x46, 0xca,
45         0x45, 0xcf, 0x4a, 0xde, 0x79, 0x8b, 0x86, 0x91, 0xa8, 0xe3, 0x3e, 0x42, 0xc6, 0x51, 0xf3, 0x0e,
46         0x12, 0x36, 0x5a, 0xee, 0x29, 0x7b, 0x8d, 0x8c, 0x8f, 0x8a, 0x85, 0x94, 0xa7, 0xf2, 0x0d, 0x17,
47         0x39, 0x4b, 0xdd, 0x7c, 0x84, 0x97, 0xa2, 0xfd, 0x1c, 0x24, 0x6c, 0xb4, 0xc7, 0x52, 0xf6, 0x01};
48 static const uint8_t log[P] = {
49         0x00, // log(0) is not defined
50         0xff, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7, 0x1b, 0x68, 0x33, 0xee, 0xdf, 0x03, 0x64,
51         0x04, 0xe0, 0x0e, 0x34, 0x8d, 0x81, 0xef, 0x4c, 0x71, 0x08, 0xc8, 0xf8, 0x69, 0x1c, 0xc1, 0x7d,
52         0xc2, 0x1d, 0xb5, 0xf9, 0xb9, 0x27, 0x6a, 0x4d, 0xe4, 0xa6, 0x72, 0x9a, 0xc9, 0x09, 0x78, 0x65,
53         0x2f, 0x8a, 0x05, 0x21, 0x0f, 0xe1, 0x24, 0x12, 0xf0, 0x82, 0x45, 0x35, 0x93, 0xda, 0x8e, 0x96,
54         0x8f, 0xdb, 0xbd, 0x36, 0xd0, 0xce, 0x94, 0x13, 0x5c, 0xd2, 0xf1, 0x40, 0x46, 0x83, 0x38, 0x66,
55         0xdd, 0xfd, 0x30, 0xbf, 0x06, 0x8b, 0x62, 0xb3, 0x25, 0xe2, 0x98, 0x22, 0x88, 0x91, 0x10, 0x7e,
56         0x6e, 0x48, 0xc3, 0xa3, 0xb6, 0x1e, 0x42, 0x3a, 0x6b, 0x28, 0x54, 0xfa, 0x85, 0x3d, 0xba, 0x2b,
57         0x79, 0x0a, 0x15, 0x9b, 0x9f, 0x5e, 0xca, 0x4e, 0xd4, 0xac, 0xe5, 0xf3, 0x73, 0xa7, 0x57, 0xaf,
58         0x58, 0xa8, 0x50, 0xf4, 0xea, 0xd6, 0x74, 0x4f, 0xae, 0xe9, 0xd5, 0xe7, 0xe6, 0xad, 0xe8, 0x2c,
59         0xd7, 0x75, 0x7a, 0xeb, 0x16, 0x0b, 0xf5, 0x59, 0xcb, 0x5f, 0xb0, 0x9c, 0xa9, 0x51, 0xa0, 0x7f,
60         0x0c, 0xf6, 0x6f, 0x17, 0xc4, 0x49, 0xec, 0xd8, 0x43, 0x1f, 0x2d, 0xa4, 0x76, 0x7b, 0xb7, 0xcc,
61         0xbb, 0x3e, 0x5a, 0xfb, 0x60, 0xb1, 0x86, 0x3b, 0x52, 0xa1, 0x6c, 0xaa, 0x55, 0x29, 0x9d, 0x97,
62         0xb2, 0x87, 0x90, 0x61, 0xbe, 0xdc, 0xfc, 0xbc, 0x95, 0xcf, 0xcd, 0x37, 0x3f, 0x5b, 0xd1, 0x53,
63         0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab, 0x44,
64         0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5, 0x67,
65         0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};
66
/*
 * Final selection step of field_mul(). Returns `calc` (the table lookup
 * exp[log a + log b]) when both a and b are non-zero, else 0, because the
 * log table has no meaningful entry for 0.
 *
 * The choice is made with arithmetic byte masks instead of if/else, so
 * there is no branch whose direction depends on the (potentially
 * sensitive) operands. The noinline and -O0 attributes are kept from the
 * original so this body is compiled as written. Note that the exp/log
 * table lookups done by the caller are still data-dependent memory
 * accesses.
 */
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) {
        /* (v + 0xff) >> 8 is 1 for v in 1..255 and 0 for v == 0; subtracting
         * that from 0 gives an all-ones (0xff) or all-zeros byte mask. */
        uint8_t a_mask = (uint8_t)(0u - (((unsigned int)a + 0xffu) >> 8));
        uint8_t b_mask = (uint8_t)(0u - (((unsigned int)b + 0xffu) >> 8));
        return (uint8_t)(calc & a_mask & b_mask);
}
81 static uint8_t field_mul(uint8_t a, uint8_t b)  {
82         return field_mul_ret(exp[(log[a] + log[b]) % 255], a, b);
83 }
84
85 static uint8_t field_invert(uint8_t a) {
86         assert(a != 0);
87         return exp[0xff - log[a]]; // log[1] == 0xff
88 }
89
/*
 * Final selection step of field_pow(). Returns:
 *   1     when e == 0 (a^0 == 1, including 0^0 as in the reference test),
 *   0     when e != 0 and a == 0,
 *   calc  otherwise (the table lookup exp[log a * e]).
 *
 * The selection uses arithmetic byte masks instead of if/else, so there
 * is no branch whose direction depends on the operands. The noinline and
 * -O0 attributes are kept from the original so this body is compiled as
 * written.
 */
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) {
        /* (v + 0xff) >> 8 is 1 for v != 0 and 0 for v == 0; 0 - that is a
         * 0xff / 0x00 byte mask. */
        uint8_t a_mask = (uint8_t)(0u - (((unsigned int)a + 0xffu) >> 8));
        uint8_t e_mask = (uint8_t)(0u - (((unsigned int)e + 0xffu) >> 8));
        /* With e != 0, keep calc only if a != 0. With e == 0, contribute the constant 1. */
        return (uint8_t)((calc & a_mask & e_mask) | (1u & (uint8_t)~e_mask));
}
104 static uint8_t field_pow(uint8_t a, uint8_t e) {
105         return field_pow_ret(exp[(log[a] * e) % 255], a, e);
106 }
107
#ifdef TEST
/*
 * Self-test build (-DTEST): checks the table-driven field operations
 * against straightforward reference implementations.
 *
 * <assert.h> re-reads NDEBUG every time it is included. Undefining it and
 * re-including guarantees the checks below still run if the test build
 * also passes -DNDEBUG; otherwise the test would "pass" by doing nothing.
 */
#undef NDEBUG
#include <assert.h>

/*
 * Reference GF(2^8) multiply: shift-and-add, reducing by
 * x^8 + x^4 + x^3 + x + 1. It branches on its inputs (side-channel leaky),
 * so it is only used to check the table-driven version.
 */
static uint8_t field_mul_calc(uint8_t a, uint8_t b) {
        uint8_t ret = 0;
        uint8_t counter;
        uint8_t carry;
        for (counter = 0; counter < 8; counter++) {
                if (b & 1)
                        ret ^= a;
                carry = (a & 0x80);
                a <<= 1;
                if (carry)
                        a ^= 0x1b; // what x^8 is modulo x^8 + x^4 + x^3 + x + 1
                b >>= 1;
        }
        return ret;
}

/* Reference exponentiation by repeated multiplication (a^0 == 1, including 0^0). */
static uint8_t field_pow_calc(uint8_t a, uint8_t e) {
        uint8_t ret = 1;
        for (uint8_t i = 0; i < e; i++)
                ret = field_mul_calc(ret, a);
        return ret;
}

int main(void) {
        // Inversion with the logarithm tables
        for (uint16_t i = 1; i < P; i++)
                assert(field_mul_calc(i, field_invert(i)) == 1);

        for (uint16_t i = 0; i < P; i++) {
                // Characteristic 2: every element is its own negative
                assert(field_neg(i) == i);
                for (uint16_t j = 0; j < P; j++) {
                        // Addition and subtraction coincide
                        assert(field_add(i, j) == field_sub(i, j));
                        // Multiplication with the logarithm tables
                        assert(field_mul(i, j) == field_mul_calc(i, j));
                        // Exponentiation with the logarithm tables
                        assert(field_pow(i, j) == field_pow_calc(i, j));
                }
        }

        puts("All field tests passed.");
        return 0;
}
#endif // defined(TEST)
149
150
151
152 /*
153  * Calculations across the polynomial q
154  */
155 #ifndef TEST
/*
 * Evaluate q(x) = a[0] + a[1]*x + ... + a[k-1]*x^(k-1) over GF(2^8).
 * a[0] is the constant term (the secret byte in split mode). x must be
 * non-zero, since q(0) would just be a[0]. This is checked only by
 * assert().
 */
static uint8_t calculateQ(uint8_t a[], uint8_t k, uint8_t x) {
        assert(x != 0);
        uint8_t acc = a[0];
        uint8_t degree = 1;
        while (degree < k) {
                uint8_t term = field_mul(a[degree], field_pow(x, degree));
                acc = field_add(acc, term);
                degree++;
        }
        return acc;
}
164
/*
 * Recover q(0) from the k points (x[i], q[i]) by Lagrange interpolation
 * evaluated at 0 (derived from the formula at
 * http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2):
 *
 *   q(0) = sum_i q[i] * prod_{j != i} (0 - x[j]) / (x[i] - x[j])
 *
 * The x values must be pairwise distinct. A repeated x makes
 * field_invert() receive 0, which is caught only by its assert().
 */
uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t k) {
        uint8_t secret = 0;
        for (uint8_t i = 0; i < k; i++) {
                uint8_t term = q[i];
                for (uint8_t j = 0; j < k; j++) {
                        if (j != i) {
                                uint8_t numerator = field_neg(x[j]);
                                uint8_t inv_denominator = field_invert(field_sub(x[i], x[j]));
                                term = field_mul(field_mul(term, numerator), inv_denominator);
                        }
                }
                secret = field_add(secret, term);
        }
        return secret;
}
181
182
183
184 int main(int argc, char* argv[]) {
185         assert(mlockall(MCL_CURRENT | MCL_FUTURE) == 0);
186
187         char split = 0;
188         uint8_t n = 0, k = 0;
189         char* files[P]; uint8_t files_count = 0;
190         char *in_file = (void*)0, *out_file_param = (void*)0;
191
192         int i;
193         while((i = getopt(argc, argv, "scn:k:f:o:i:h?")) != -1)
194                 switch(i) {
195                 case 's':
196                         if ((split & 0x2) && !(split & 0x1))
197                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
198                         else
199                                 split = (0x2 | 0x1);
200                         break;
201                 case 'c':
202                         if ((split & 0x2) && (split & 0x1))
203                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
204                         else
205                                 split = 0x2;
206                         break;
207                 case 'n': {
208                         int t = atoi(optarg);
209                         if (t <= 0 || t >= P)
210                                 ERROREXIT("n must be > 0 and < %u\n", P)
211                         else
212                                 n = t;
213                         break;
214                 }
215                 case 'k': {
216                         int t = atoi(optarg);
217                         if (t <= 0 || t >= P)
218                                 ERROREXIT("n must be > 0 and < %u\n", P)
219                         else
220                                 k = t;
221                         break;
222                 }
223                 case 'i':
224                         in_file = optarg;
225                         break;
226                 case 'o':
227                         out_file_param = optarg;
228                         break;
229                 case 'f':
230                         if (files_count >= P-1)
231                                 ERROREXIT("May only specify up to %u files\n", P-1)
232                         files[files_count++] = optarg;
233                         break;
234                 case 'h':
235                 case '?':
236                         printf("Split usage: -s -n <total shares> -k <shares required> -i <input file> -o <output file path base>\n");
237                         printf("Combine usage: -c -k <shares provided == shares required> <-f <share>>*k -o <output file>\n");
238                         exit(0);
239                         break;
240                 default:
241                         ERROREXIT("getopt failed?\n")
242                 }
243         if (!(split & 0x2))
244                 ERROREXIT("Must specify one of -c, -s or -?\n")
245         split &= 0x1;
246
247         if (argc != optind)
248                 ERROREXIT("Invalid argument\n")
249
250         if (split) {
251                 if (!n || !k)
252                         ERROREXIT("n and k must be set.\n")
253
254                 if (k > n)
255                         ERROREXIT("k must be <= n\n")
256
257                 if (files_count != 0 || !in_file || !out_file_param)
258                         ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
259
260                 FILE* random = fopen("/dev/random", "r");
261                 assert(random);
262                 FILE* secret_file = fopen(in_file, "r");
263                 if (!secret_file)
264                         ERROREXIT("Could not open %s for reading.\n", in_file)
265
266                 uint8_t secret[MAX_LENGTH];
267
268                 size_t secret_length = fread(secret, 1, MAX_LENGTH*sizeof(uint8_t), secret_file);
269                 if (secret_length == 0)
270                         ERROREXIT("Error reading secret\n")
271                 if (fread(secret, 1, 1, secret_file) > 0)
272                         ERROREXIT("Secret may not be longer than %u\n", MAX_LENGTH)
273                 fclose(secret_file);
274                 printf("Using secret of length %lu\n", secret_length);
275
276                 uint8_t a[k], D[n][secret_length];
277
278                 for (uint32_t i = 0; i < secret_length; i++) {
279                         a[0] = secret[i];
280
281                         for (uint8_t j = 1; j < k; j++)
282                                 assert(fread(&a[j], sizeof(uint8_t), 1, random) == 1);
283                         for (uint8_t j = 0; j < n; j++)
284                                 D[j][i] = calculateQ(a, k, j+1);
285
286                         if (i % 32 == 0 && i != 0)
287                                 printf("Finished processing %u bytes.\n", i);
288                 }
289
290                 char out_file_name_buf[strlen(out_file_param) + 4];
291                 strcpy(out_file_name_buf, out_file_param);
292                 for (uint8_t i = 0; i < n; i++) {
293                         /*printf("%u-", i);
294                         for (uint8_t j = 0; j < secret_length; j++)
295                                 printf("%02x", D[i][j]);
296                         printf("\n");*/
297
298                         sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
299                         FILE* out_file = fopen(out_file_name_buf, "w+");
300                         if (!out_file)
301                                 ERROREXIT("Could not open output file %s\n", out_file_name_buf)
302
303                         uint8_t x = i+1;
304                         if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
305                                 ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
306
307                         if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
308                                 ERROREXIT("Could not write %lu bytes to %s\n", secret_length, out_file_name_buf)
309
310                         fclose(out_file);
311                 }
312                 /*printf("secret = ");
313                 for (uint8_t i = 0; i < secret_length; i++)
314                         printf("%02x", secret[i]);
315                 printf("\n");*/
316
317                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
318                 memset(secret, 0, sizeof(uint8_t)*secret_length);
319                 memset(a, 0, sizeof(uint8_t)*k);
320                 memset(in_file, 0, strlen(in_file));
321
322                 fclose(random);
323         } else {
324                 if (!k)
325                         ERROREXIT("k must be set.\n")
326
327                 if (files_count != k || in_file || !out_file_param)
328                         ERROREXIT("Must not specify -i and must specify -o and exactly k -f <input file>s in combine mode.\n")
329
330                 uint8_t x[k], q[k];
331                 FILE* files_fps[k];
332
333                 for (uint8_t i = 0; i < k; i++) {
334                         files_fps[i] = fopen(files[i], "r");
335                         if (!files_fps[i])
336                                 ERROREXIT("Couldn't open file %s for reading.\n", files[i])
337                         if (fread(&x[i], sizeof(uint8_t), 1, files_fps[i]) != 1)
338                                 ERROREXIT("Couldn't read the x byte of %s\n", files[i])
339                 }
340
341                 uint8_t secret[MAX_LENGTH];
342
343                 uint32_t i = 0;
344                 while (fread(&q[0], sizeof(uint8_t), 1, files_fps[0]) == 1) {
345                         for (uint8_t j = 1; j < k; j++) {
346                                 if (fread(&q[j], sizeof(uint8_t), 1, files_fps[j]) != 1)
347                                         ERROREXIT("Couldn't read next byte from %s\n", files[j])
348                         }
349                         secret[i++] = calculateSecret(x, q, k);
350                 }
351                 printf("Got secret of length %u\n", i);
352
353                 FILE* out_file = fopen(out_file_param, "w+");
354                 fwrite(secret, sizeof(uint8_t), i, out_file);
355                 fclose(out_file);
356
357                 for (uint8_t i = 0; i < k; i++)
358                         fclose(files_fps[i]);
359
360                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
361                 memset(secret, 0, sizeof(uint8_t)*i);
362                 memset(q, 0, sizeof(uint8_t)*k);
363                 memset(out_file_param, 0, strlen(out_file_param));
364                 for (uint8_t i = 0; i < k; i++)
365                         memset(files[i], 0, strlen(files[i]));
366                 memset(x, 0, sizeof(uint8_t)*k);
367         }
368
369         return 0;
370 }
371 #endif // !defined(TEST)