// Build flags: -Wall -Werror
// File: shamirs / shamirssecret.c
1 #define _GNU_SOURCE
2
3 #include <stdint.h>
4 #include <stdio.h>
5 #include <assert.h>
6 #include <string.h>
7 #include <unistd.h>
8 #include <stdlib.h>
9 #include <sys/mman.h>
10
/* Largest secret, in bytes, that split mode will read (and combine mode will hold). */
#define MAX_LENGTH 1024
/*
 * Print a printf-style message to stderr and exit(1).
 * Expands to a braced block, not do { } while (0), so call sites omit the
 * trailing semicolon: `if (bad) ERROREXIT("...\n") else ...`.
 */
#define ERROREXIT(...) {fprintf(stderr, __VA_ARGS__); exit(1);}
13
/*
 * Calculations across the finite field GF(2^8)
 */
/* Number of elements in GF(2^8): the size of the exp/log tables.
 * Despite the letter, this is 2^8, not a prime. */
#define P (1 << 8)
18
#ifndef TEST
/*
 * Addition, subtraction and negation in GF(2^8).
 * Elements are polynomials over GF(2), so coefficients add modulo 2. All three
 * operations therefore reduce to XOR, and every element is its own additive
 * inverse.
 *
 * Only compiled outside the TEST build, which does not call them (presumably
 * so -Wall -Werror does not reject unused static functions).
 */
static uint8_t field_add(uint8_t a, uint8_t b) {
        const uint8_t sum = (uint8_t)(b ^ a);
        return sum;
}

static uint8_t field_sub(uint8_t a, uint8_t b) {
        /* a - b == a + (-b) == a + b in characteristic 2. */
        return field_add(a, b);
}

static uint8_t field_neg(uint8_t a) {
        /* -a == 0 - a, which equals a itself. */
        const uint8_t zero = 0;
        return field_sub(zero, a);
}
#endif
32
/*
 * Antilogarithm table: exp[i] = g^i in GF(2^8) for the generator g = 0x03
 * (exp[0] == 0x01, exp[1] == 0x03).
 * The values agree with the reduction polynomial x^8 + x^4 + x^3 + x + 1 used
 * by field_mul_calc; the TEST build checks field_mul, field_invert and
 * field_pow against it.
 * The multiplicative group has order 255, so the last entry, exp[255], wraps
 * back to 0x01.
 */
static const uint8_t exp[P] = {
        0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
        0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
        0xe5, 0x34, 0x5c, 0xe4, 0x37, 0x59, 0xeb, 0x26, 0x6a, 0xbe, 0xd9, 0x70, 0x90, 0xab, 0xe6, 0x31,
        0x53, 0xf5, 0x04, 0x0c, 0x14, 0x3c, 0x44, 0xcc, 0x4f, 0xd1, 0x68, 0xb8, 0xd3, 0x6e, 0xb2, 0xcd,
        0x4c, 0xd4, 0x67, 0xa9, 0xe0, 0x3b, 0x4d, 0xd7, 0x62, 0xa6, 0xf1, 0x08, 0x18, 0x28, 0x78, 0x88,
        0x83, 0x9e, 0xb9, 0xd0, 0x6b, 0xbd, 0xdc, 0x7f, 0x81, 0x98, 0xb3, 0xce, 0x49, 0xdb, 0x76, 0x9a,
        0xb5, 0xc4, 0x57, 0xf9, 0x10, 0x30, 0x50, 0xf0, 0x0b, 0x1d, 0x27, 0x69, 0xbb, 0xd6, 0x61, 0xa3,
        0xfe, 0x19, 0x2b, 0x7d, 0x87, 0x92, 0xad, 0xec, 0x2f, 0x71, 0x93, 0xae, 0xe9, 0x20, 0x60, 0xa0,
        0xfb, 0x16, 0x3a, 0x4e, 0xd2, 0x6d, 0xb7, 0xc2, 0x5d, 0xe7, 0x32, 0x56, 0xfa, 0x15, 0x3f, 0x41,
        0xc3, 0x5e, 0xe2, 0x3d, 0x47, 0xc9, 0x40, 0xc0, 0x5b, 0xed, 0x2c, 0x74, 0x9c, 0xbf, 0xda, 0x75,
        0x9f, 0xba, 0xd5, 0x64, 0xac, 0xef, 0x2a, 0x7e, 0x82, 0x9d, 0xbc, 0xdf, 0x7a, 0x8e, 0x89, 0x80,
        0x9b, 0xb6, 0xc1, 0x58, 0xe8, 0x23, 0x65, 0xaf, 0xea, 0x25, 0x6f, 0xb1, 0xc8, 0x43, 0xc5, 0x54,
        0xfc, 0x1f, 0x21, 0x63, 0xa5, 0xf4, 0x07, 0x09, 0x1b, 0x2d, 0x77, 0x99, 0xb0, 0xcb, 0x46, 0xca,
        0x45, 0xcf, 0x4a, 0xde, 0x79, 0x8b, 0x86, 0x91, 0xa8, 0xe3, 0x3e, 0x42, 0xc6, 0x51, 0xf3, 0x0e,
        0x12, 0x36, 0x5a, 0xee, 0x29, 0x7b, 0x8d, 0x8c, 0x8f, 0x8a, 0x85, 0x94, 0xa7, 0xf2, 0x0d, 0x17,
        0x39, 0x4b, 0xdd, 0x7c, 0x84, 0x97, 0xa2, 0xfd, 0x1c, 0x24, 0x6c, 0xb4, 0xc7, 0x52, 0xf6, 0x01};
/*
 * Logarithm table: log[a] is the discrete logarithm of a to base g = 0x03, so
 * exp[log[a]] == a for every a != 0.
 * There are two deliberate quirks:
 *  - log[0] is only a placeholder (0x00). Zero has no logarithm, so
 *    field_mul/field_pow override the table result when an operand is 0.
 *  - log[1] is stored as 0xff (g^255 == 1) rather than 0. This keeps
 *    0xff - log[a] in field_invert within 0..254 for every non-zero a.
 *    field_mul and field_pow reduce exponents mod 255, so the choice does not
 *    affect them.
 * Layout: after the lone log[0] entry, each row of 16 starts at index 1, 17, 33, ...
 */
static const uint8_t log[P] = {
        0x00, // log(0) is not defined
        0xff, 0x19, 0x01, 0x32, 0x02, 0x1a, 0xc6, 0x4b, 0xc7, 0x1b, 0x68, 0x33, 0xee, 0xdf, 0x03, 0x64,
        0x04, 0xe0, 0x0e, 0x34, 0x8d, 0x81, 0xef, 0x4c, 0x71, 0x08, 0xc8, 0xf8, 0x69, 0x1c, 0xc1, 0x7d,
        0xc2, 0x1d, 0xb5, 0xf9, 0xb9, 0x27, 0x6a, 0x4d, 0xe4, 0xa6, 0x72, 0x9a, 0xc9, 0x09, 0x78, 0x65,
        0x2f, 0x8a, 0x05, 0x21, 0x0f, 0xe1, 0x24, 0x12, 0xf0, 0x82, 0x45, 0x35, 0x93, 0xda, 0x8e, 0x96,
        0x8f, 0xdb, 0xbd, 0x36, 0xd0, 0xce, 0x94, 0x13, 0x5c, 0xd2, 0xf1, 0x40, 0x46, 0x83, 0x38, 0x66,
        0xdd, 0xfd, 0x30, 0xbf, 0x06, 0x8b, 0x62, 0xb3, 0x25, 0xe2, 0x98, 0x22, 0x88, 0x91, 0x10, 0x7e,
        0x6e, 0x48, 0xc3, 0xa3, 0xb6, 0x1e, 0x42, 0x3a, 0x6b, 0x28, 0x54, 0xfa, 0x85, 0x3d, 0xba, 0x2b,
        0x79, 0x0a, 0x15, 0x9b, 0x9f, 0x5e, 0xca, 0x4e, 0xd4, 0xac, 0xe5, 0xf3, 0x73, 0xa7, 0x57, 0xaf,
        0x58, 0xa8, 0x50, 0xf4, 0xea, 0xd6, 0x74, 0x4f, 0xae, 0xe9, 0xd5, 0xe7, 0xe6, 0xad, 0xe8, 0x2c,
        0xd7, 0x75, 0x7a, 0xeb, 0x16, 0x0b, 0xf5, 0x59, 0xcb, 0x5f, 0xb0, 0x9c, 0xa9, 0x51, 0xa0, 0x7f,
        0x0c, 0xf6, 0x6f, 0x17, 0xc4, 0x49, 0xec, 0xd8, 0x43, 0x1f, 0x2d, 0xa4, 0x76, 0x7b, 0xb7, 0xcc,
        0xbb, 0x3e, 0x5a, 0xfb, 0x60, 0xb1, 0x86, 0x3b, 0x52, 0xa1, 0x6c, 0xaa, 0x55, 0x29, 0x9d, 0x97,
        0xb2, 0x87, 0x90, 0x61, 0xbe, 0xdc, 0xfc, 0xbc, 0x95, 0xcf, 0xcd, 0x37, 0x3f, 0x5b, 0xd1, 0x53,
        0x39, 0x84, 0x3c, 0x41, 0xa2, 0x6d, 0x47, 0x14, 0x2a, 0x9e, 0x5d, 0x56, 0xf2, 0xd3, 0xab, 0x44,
        0x11, 0x92, 0xd9, 0x23, 0x20, 0x2e, 0x89, 0xb4, 0x7c, 0xb8, 0x26, 0x77, 0x99, 0xe3, 0xa5, 0x67,
        0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};
68
/*
 * Select the final GF(2^8) product.
 * `calc` is the table value exp[(log[a] + log[b]) % 255]. It is only
 * meaningful when both operands are non-zero, because zero has no logarithm,
 * so the result is forced to 0 when a == 0 or b == 0.
 *
 * An if/else on a and b would branch on secret data and defeat the purpose of
 * the attributes below. The choice is therefore made with all-ones/all-zeros
 * masks, so the source has no data-dependent branch. Optimisation is disabled
 * and inlining is blocked so the compiler does not fold the masks back into a
 * branch. Check the generated assembly if timing really matters.
 */
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) {
        /* 0xff when the operand is non-zero, 0x00 when it is zero. */
        const uint8_t a_mask = (uint8_t)-(a != 0);
        const uint8_t b_mask = (uint8_t)-(b != 0);
        return (uint8_t)(calc & a_mask & b_mask);
}
83 static uint8_t field_mul(uint8_t a, uint8_t b)  {
84         return field_mul_ret(exp[(log[a] + log[b]) % 255], a, b);
85 }
86
87 static uint8_t field_invert(uint8_t a) {
88         assert(a != 0);
89         return exp[0xff - log[a]]; // log[1] == 0xff
90 }
91
/*
 * Select the final value of a^e in GF(2^8).
 * `calc` is the table value exp[(log[a] * e) % 255], which is only valid for
 * a != 0. The rules applied are:
 *   e == 0           -> 1 (including 0^0, matching field_pow_calc's empty product)
 *   a == 0, e != 0   -> 0
 *   otherwise        -> calc
 *
 * An if/else on a and e would branch on secret data, so the choice is made
 * with all-ones/all-zeros masks. Optimisation is disabled and inlining is
 * blocked so the compiler does not turn the masks back into a branch. Check
 * the generated assembly if timing really matters.
 */
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) __attribute__((optimize("-O0"))) __attribute__((noinline));
static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) {
        const uint8_t a_mask = (uint8_t)-(a != 0);      /* 0xff iff a != 0 */
        const uint8_t e_zero_mask = (uint8_t)-(e == 0); /* 0xff iff e == 0 */
        const uint8_t nonzero_exp_result = (uint8_t)(calc & a_mask);
        return (uint8_t)((nonzero_exp_result & (uint8_t)~e_zero_mask) | (0x01 & e_zero_mask));
}
106 static uint8_t field_pow(uint8_t a, uint8_t e) {
107 #ifndef TEST
108         // Although this function works for a==0, its not trivially obvious why,
109         // and since we never call with a==0, we just assert a != 0 (except when testing)
110         assert(a != 0);
111 #endif
112         return field_pow_ret(exp[(log[a] * e) % 255], a, e);
113 }
114
115 #ifdef TEST
/*
 * Reference GF(2^8) multiply (shift-and-add, reducing by
 * x^8 + x^4 + x^3 + x + 1). It is used only to check the lookup tables.
 * It branches on operand bits, so it is NOT side-channel safe.
 */
static uint8_t field_mul_calc(uint8_t a, uint8_t b) {
        uint8_t product = 0;
        for (unsigned int bit = 0; bit < 8; bit++) {
                if ((b >> bit) & 1)
                        product ^= a;
                /* a *= x: shift left, and if x^8 fell off, add back x^4 + x^3 + x + 1. */
                a = (uint8_t)((a << 1) ^ ((a & 0x80) ? 0x1b : 0x00));
        }
        return product;
}
/* Reference a^e: repeated multiplication starting from 1 (so a^0 == 1, even for a == 0). */
static uint8_t field_pow_calc(uint8_t a, uint8_t e) {
        uint8_t acc = 1;
        for (unsigned int remaining = e; remaining > 0; remaining--)
                acc = field_mul_calc(acc, a);
        return acc;
}
138 int main() {
139         // Test inversion with the logarithm tables
140         for (uint16_t i = 1; i < P; i++)
141                 assert(field_mul_calc(i, field_invert(i)) == 1);
142
143         // Test multiplication with the logarithm tables
144         for (uint16_t i = 0; i < P; i++) {
145                 for (uint16_t j = 0; j < P; j++)
146                         assert(field_mul(i, j) == field_mul_calc(i, j));
147         }
148
149         // Test exponentiation with the logarithm tables
150         for (uint16_t i = 0; i < P; i++) {
151                 for (uint16_t j = 0; j < P; j++)
152                         assert(field_pow(i, j) == field_pow_calc(i, j));
153         }
154 }
155 #endif // defined(TEST)
156
157
158
159 /*
160  * Calculations across the polynomial q
161  */
162 #ifndef TEST
/*
 * Evaluate the share polynomial
 *   q(x) = a[0] + a[1] x + a[2] x^2 + ... + a[k-1] x^(k-1)
 * over GF(2^8). a[0] is the secret byte, so evaluating at x == 0 would hand
 * out the secret itself; that is asserted against.
 */
static uint8_t calculateQ(uint8_t a[], uint8_t k, uint8_t x) {
        assert(x != 0);
        uint8_t y = a[0];
        for (uint8_t term = 1; term < k; term++) {
                const uint8_t monomial = field_pow(x, term);
                y = field_add(y, field_mul(a[term], monomial));
        }
        return y;
}
171
/*
 * Recover the secret q(0) from k shares (x[i], q[i]) by Lagrange
 * interpolation evaluated at 0 (derived from
 * http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2):
 *
 *   q(0) = sum_i q[i] * prod_{j != i} (0 - x[j]) / (x[i] - x[j])
 *
 * The x values must be pairwise distinct; otherwise field_invert(0) is reached.
 */
uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t k) {
        uint8_t secret = 0;
        for (uint8_t i = 0; i < k; i++) {
                uint8_t term = q[i];
                for (uint8_t j = 0; j < k; j++) {
                        if (j != i) {
                                const uint8_t numerator = field_neg(x[j]);
                                const uint8_t denominator_inv = field_invert(field_sub(x[i], x[j]));
                                term = field_mul(field_mul(term, numerator), denominator_inv);
                        }
                }
                secret = field_add(secret, term);
        }
        return secret;
}
188
189
190
191 int main(int argc, char* argv[]) {
192         assert(mlockall(MCL_CURRENT | MCL_FUTURE) == 0);
193
194         char split = 0;
195         uint8_t total_shares = 0, shares_required = 0;
196         char* files[P]; uint8_t files_count = 0;
197         char *in_file = (void*)0, *out_file_param = (void*)0;
198
199         int i;
200         while((i = getopt(argc, argv, "scn:k:f:o:i:h?")) != -1)
201                 switch(i) {
202                 case 's':
203                         if ((split & 0x2) && !(split & 0x1))
204                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
205                         else
206                                 split = (0x2 | 0x1);
207                         break;
208                 case 'c':
209                         if ((split & 0x2) && (split & 0x1))
210                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
211                         else
212                                 split = 0x2;
213                         break;
214                 case 'n': {
215                         int t = atoi(optarg);
216                         if (t <= 0 || t >= P)
217                                 ERROREXIT("n must be > 0 and < %u\n", P)
218                         else
219                                 total_shares = t;
220                         break;
221                 }
222                 case 'k': {
223                         int t = atoi(optarg);
224                         if (t <= 0 || t >= P)
225                                 ERROREXIT("n must be > 0 and < %u\n", P)
226                         else
227                                 shares_required = t;
228                         break;
229                 }
230                 case 'i':
231                         in_file = optarg;
232                         break;
233                 case 'o':
234                         out_file_param = optarg;
235                         break;
236                 case 'f':
237                         if (files_count >= P-1)
238                                 ERROREXIT("May only specify up to %u files\n", P-1)
239                         files[files_count++] = optarg;
240                         break;
241                 case 'h':
242                 case '?':
243                         printf("Split usage: -s -n <total shares> -k <shares required> -i <input file> -o <output file path base>\n");
244                         printf("Combine usage: -c -k <shares provided == shares required> <-f <share>>*k -o <output file>\n");
245                         exit(0);
246                         break;
247                 default:
248                         ERROREXIT("getopt failed?\n")
249                 }
250         if (!(split & 0x2))
251                 ERROREXIT("Must specify one of -c, -s or -?\n")
252         split &= 0x1;
253
254         if (argc != optind)
255                 ERROREXIT("Invalid argument\n")
256
257         if (split) {
258                 if (!total_shares || !shares_required)
259                         ERROREXIT("n and k must be set.\n")
260
261                 if (shares_required > total_shares)
262                         ERROREXIT("k must be <= n\n")
263
264                 if (files_count != 0 || !in_file || !out_file_param)
265                         ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
266
267                 FILE* random = fopen("/dev/random", "r");
268                 assert(random);
269                 FILE* secret_file = fopen(in_file, "r");
270                 if (!secret_file)
271                         ERROREXIT("Could not open %s for reading.\n", in_file)
272
273                 uint8_t secret[MAX_LENGTH];
274
275                 size_t secret_length = fread(secret, 1, MAX_LENGTH*sizeof(uint8_t), secret_file);
276                 if (secret_length == 0)
277                         ERROREXIT("Error reading secret\n")
278                 if (fread(secret, 1, 1, secret_file) > 0)
279                         ERROREXIT("Secret may not be longer than %u\n", MAX_LENGTH)
280                 fclose(secret_file);
281                 printf("Using secret of length %lu\n", secret_length);
282
283                 uint8_t a[shares_required], D[total_shares][secret_length];
284
285                 for (uint32_t i = 0; i < secret_length; i++) {
286                         a[0] = secret[i];
287
288                         for (uint8_t j = 1; j < shares_required; j++)
289                                 assert(fread(&a[j], sizeof(uint8_t), 1, random) == 1);
290                         for (uint8_t j = 0; j < total_shares; j++)
291                                 D[j][i] = calculateQ(a, shares_required, j+1);
292
293                         if (i % 32 == 0 && i != 0)
294                                 printf("Finished processing %u bytes.\n", i);
295                 }
296
297                 char out_file_name_buf[strlen(out_file_param) + 4];
298                 strcpy(out_file_name_buf, out_file_param);
299                 for (uint8_t i = 0; i < total_shares; i++) {
300                         /*printf("%u-", i);
301                         for (uint8_t j = 0; j < secret_length; j++)
302                                 printf("%02x", D[i][j]);
303                         printf("\n");*/
304
305                         sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
306                         FILE* out_file = fopen(out_file_name_buf, "w+");
307                         if (!out_file)
308                                 ERROREXIT("Could not open output file %s\n", out_file_name_buf)
309
310                         uint8_t x = i+1;
311                         if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
312                                 ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
313
314                         if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
315                                 ERROREXIT("Could not write %lu bytes to %s\n", secret_length, out_file_name_buf)
316
317                         fclose(out_file);
318                 }
319                 /*printf("secret = ");
320                 for (uint8_t i = 0; i < secret_length; i++)
321                         printf("%02x", secret[i]);
322                 printf("\n");*/
323
324                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
325                 memset(secret, 0, sizeof(uint8_t)*secret_length);
326                 memset(a, 0, sizeof(uint8_t)*shares_required);
327                 memset(in_file, 0, strlen(in_file));
328
329                 fclose(random);
330         } else {
331                 if (!shares_required)
332                         ERROREXIT("k must be set.\n")
333
334                 if (files_count != shares_required || in_file || !out_file_param)
335                         ERROREXIT("Must not specify -i and must specify -o and exactly k -f <input file>s in combine mode.\n")
336
337                 uint8_t x[shares_required], q[shares_required];
338                 FILE* files_fps[shares_required];
339
340                 for (uint8_t i = 0; i < shares_required; i++) {
341                         files_fps[i] = fopen(files[i], "r");
342                         if (!files_fps[i])
343                                 ERROREXIT("Couldn't open file %s for reading.\n", files[i])
344                         if (fread(&x[i], sizeof(uint8_t), 1, files_fps[i]) != 1)
345                                 ERROREXIT("Couldn't read the x byte of %s\n", files[i])
346                 }
347
348                 uint8_t secret[MAX_LENGTH];
349
350                 uint32_t i = 0;
351                 while (fread(&q[0], sizeof(uint8_t), 1, files_fps[0]) == 1) {
352                         for (uint8_t j = 1; j < shares_required; j++) {
353                                 if (fread(&q[j], sizeof(uint8_t), 1, files_fps[j]) != 1)
354                                         ERROREXIT("Couldn't read next byte from %s\n", files[j])
355                         }
356                         secret[i++] = calculateSecret(x, q, shares_required);
357                 }
358                 printf("Got secret of length %u\n", i);
359
360                 FILE* out_file = fopen(out_file_param, "w+");
361                 fwrite(secret, sizeof(uint8_t), i, out_file);
362                 fclose(out_file);
363
364                 for (uint8_t i = 0; i < shares_required; i++)
365                         fclose(files_fps[i]);
366
367                 // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
368                 memset(secret, 0, sizeof(uint8_t)*i);
369                 memset(q, 0, sizeof(uint8_t)*shares_required);
370                 memset(out_file_param, 0, strlen(out_file_param));
371                 for (uint8_t i = 0; i < shares_required; i++)
372                         memset(files[i], 0, strlen(files[i]));
373                 memset(x, 0, sizeof(uint8_t)*shares_required);
374         }
375
376         return 0;
377 }
378 #endif // !defined(TEST)