893d9787f88a5cdd558e0557027f6ffffc08c74a
[shamirs] / shamirssecret.c
1 #define _GNU_SOURCE
2
3 #include <stdint.h>
4 #include <stdio.h>
5 #include <assert.h>
6 #include <string.h>
7 #include <unistd.h>
8 #include <stdlib.h>
9 #include <sys/mman.h>
10
11 #define MAX_LENGTH 1024
12
13 /*
14  * Calculations across the finite field GF(p)
15  * Lots of side-channel attacks in here
16  */
17 /*
18 const uint8_t p = 251; // Largest 8-bit prime
19 static uint8_t field_add(uint8_t a, uint8_t b) {
20         assert(a < p && b < p);
21         return (((uint16_t)a) + ((uint16_t)b)) % p;
22 }
23
24 static uint8_t field_sub(uint8_t a, uint8_t b) {
25         assert(a < p && b < p);
26         return (((uint16_t)p) + ((uint16_t)a) - ((uint16_t)b)) % p;
27 }
28
29 static uint8_t field_mul(uint8_t a, uint8_t b) {
30         assert(a < p && b < p);
31         return (((uint16_t)a) * ((uint16_t)b)) % p;
32 }
33
34 static uint8_t field_pow(uint8_t a, uint8_t e) {
35         assert(a < p);
36         uint8_t ret = 1;
37         for (uint8_t i = 0; i < e; i++) {
38                 ret = field_mul(ret, a);
39         }
40         return ret;
41 }
42
43 static uint8_t field_neg(uint8_t a) {
44         assert(a < p);
45         if (a == 0)
46                 return 0;
47         return p - a;
48 }
49
50 static uint8_t field_invert(uint8_t a) {
51         // Brute force, yay!
52         assert(a < p);
53         for (uint8_t i = 0; i < p; i++) {
54                 if (field_mul(i, a) == 1)
55                         return i;
56         }
57         assert(0);
58 }*/
59
60
61
62 /*
63  * Calculations across the finite field GF(2^8)
64  */
65 const uint16_t p = 256;
66 static uint8_t field_add(uint8_t a, uint8_t b) {
67         return a ^ b;
68 }
69
70 static uint8_t field_sub(uint8_t a, uint8_t b) {
71         return a ^ b;
72 }
73
74 static uint8_t field_mul(uint8_t a, uint8_t b) {
75         // TODO side-channel attacks here?
76         uint8_t ret = 0;
77         uint8_t counter;
78         uint8_t carry;
79         for (counter = 0; counter < 8; counter++) {
80                 if (b & 1)
81                         ret ^= a;
82                 carry = (a & 0x80);
83                 a <<= 1;
84                 if (carry)
85                         a ^= 0x1b; /* what x^8 is modulo x^8 + x^4 + x^3 + x + 1 */
86                 b >>= 1;
87         }
88         return ret;
89 }
90
91 // WARNING: Do not use if e is secret (potential side-channel attacks)
92 static uint8_t field_pow(uint8_t a, uint8_t e) {
93         // TODO: This could be sped up pretty trivially
94         uint8_t ret = 1;
95         for (uint8_t i = 0; i < e; i++)
96                 ret = field_mul(ret, a);
97         return ret;
98 }
99
100 static uint8_t field_neg(uint8_t a) {
101         return field_sub(0, a);
102 }
103
104 static const uint8_t inverse[] = { // Multiplicative inverse of each element in the field
105         0xff, // 0 has no inverse
106         0x01, 0x8d, 0xf6, 0xcb, 0x52, 0x7b, 0xd1, 0xe8, 0x4f, 0x29, 0xc0, 0xb0, 0xe1, 0xe5, 0xc7, 0x74,
107         0xb4, 0xaa, 0x4b, 0x99, 0x2b, 0x60, 0x5f, 0x58, 0x3f, 0xfd, 0xcc, 0xff, 0x40, 0xee, 0xb2, 0x3a,
108         0x6e, 0x5a, 0xf1, 0x55, 0x4d, 0xa8, 0xc9, 0xc1, 0x0a, 0x98, 0x15, 0x30, 0x44, 0xa2, 0xc2, 0x2c,
109         0x45, 0x92, 0x6c, 0xf3, 0x39, 0x66, 0x42, 0xf2, 0x35, 0x20, 0x6f, 0x77, 0xbb, 0x59, 0x19, 0x1d,
110         0xfe, 0x37, 0x67, 0x2d, 0x31, 0xf5, 0x69, 0xa7, 0x64, 0xab, 0x13, 0x54, 0x25, 0xe9, 0x09, 0xed,
111         0x5c, 0x05, 0xca, 0x4c, 0x24, 0x87, 0xbf, 0x18, 0x3e, 0x22, 0xf0, 0x51, 0xec, 0x61, 0x17, 0x16,
112         0x5e, 0xaf, 0xd3, 0x49, 0xa6, 0x36, 0x43, 0xf4, 0x47, 0x91, 0xdf, 0x33, 0x93, 0x21, 0x3b, 0x79,
113         0xb7, 0x97, 0x85, 0x10, 0xb5, 0xba, 0x3c, 0xb6, 0x70, 0xd0, 0x06, 0xa1, 0xfa, 0x81, 0x82, 0x83,
114         0x7e, 0x7f, 0x80, 0x96, 0x73, 0xbe, 0x56, 0x9b, 0x9e, 0x95, 0xd9, 0xf7, 0x02, 0xb9, 0xa4, 0xde,
115         0x6a, 0x32, 0x6d, 0xd8, 0x8a, 0x84, 0x72, 0x2a, 0x14, 0x9f, 0x88, 0xf9, 0xdc, 0x89, 0x9a, 0xfb,
116         0x7c, 0x2e, 0xc3, 0x8f, 0xb8, 0x65, 0x48, 0x26, 0xc8, 0x12, 0x4a, 0xce, 0xe7, 0xd2, 0x62, 0x0c,
117         0xe0, 0x1f, 0xef, 0x11, 0x75, 0x78, 0x71, 0xa5, 0x8e, 0x76, 0x3d, 0xbd, 0xbc, 0x86, 0x57, 0x0b,
118         0x28, 0x2f, 0xa3, 0xda, 0xd4, 0xe4, 0x0f, 0xa9, 0x27, 0x53, 0x04, 0x1b, 0xfc, 0xac, 0xe6, 0x7a,
119         0x07, 0xae, 0x63, 0xc5, 0xdb, 0xe2, 0xea, 0x94, 0x8b, 0xc4, 0xd5, 0x9d, 0xf8, 0x90, 0x6b, 0xb1,
120         0x0d, 0xd6, 0xeb, 0xc6, 0x0e, 0xcf, 0xad, 0x08, 0x4e, 0xd7, 0xe3, 0x5d, 0x50, 0x1e, 0xb3, 0x5b,
121         0x23, 0x38, 0x34, 0x68, 0x46, 0x03, 0x8c, 0xdd, 0x9c, 0x7d, 0xa0, 0xcd, 0x1a, 0x41, 0x1c};
122 static uint8_t field_invert(uint8_t a) {
123         assert(a != 0);
124         return inverse[a];
125 }
126
127
128
129 /*
130  * Calculations across the polynomial q
131  */
132 static uint8_t calculateQ(uint8_t a[], uint8_t k, uint8_t x) {
133         assert(x != 0); // q(0) == secret, though so does a[0]
134         uint8_t ret = a[0];
135         for (uint8_t i = 1; i < k; i++) {
136                 ret = field_add(ret, field_mul(a[i], field_pow(x, i)));
137         }
138         return ret;
139 }
140
141 uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t k) {
142         // Calculate the x^0 term using a derivation of the forumula at
143         // http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2
144         uint8_t ret = 0;
145         for (uint8_t i = 0; i < k; i++) {
146                 uint8_t temp = q[i];
147                 for (uint8_t j = 0; j < k; j++) {
148                         if (i == j)
149                                 continue;
150                         temp = field_mul(temp, field_neg(x[j]));
151                         temp = field_mul(temp, field_invert(field_sub(x[i], x[j])));
152                 }
153                 ret = field_add(ret, temp);
154         }
155         return ret;
156 }
157
158 #define ERROREXIT(str...) {fprintf(stderr, str); exit(1);}
159
160 int main(int argc, char* argv[]) {
161         assert(mlockall(MCL_CURRENT | MCL_FUTURE) == 0);
162
163         char split = 0;
164         uint8_t n = 0, k = 0;
165         char* files[p]; uint8_t files_count = 0;
166         char *in_file = (void*)0, *out_file_param = (void*)0;
167
168         int i;
169         while((i = getopt(argc, argv, "scn:k:f:o:i:h?")) != -1)
170                 switch(i) {
171                 case 's':
172                         if ((split & 0x2) && !(split & 0x1))
173                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
174                         else
175                                 split = (0x2 | 0x1);
176                         break;
177                 case 'c':
178                         if ((split & 0x2) && (split & 0x1))
179                                 ERROREXIT("-s (split) and -c (combine) are mutually exclusive\n")
180                         else
181                                 split = 0x2;
182                         break;
183                 case 'n': {
184                         int t = atoi(optarg);
185                         if (t <= 0 || t >= p)
186                                 ERROREXIT("n must be > 0 and < %u\n", p)
187                         else
188                                 n = t;
189                         break;
190                 }
191                 case 'k': {
192                         int t = atoi(optarg);
193                         if (t <= 0 || t >= p)
194                                 ERROREXIT("n must be > 0 and < %u\n", p)
195                         else
196                                 k = t;
197                         break;
198                 }
199                 case 'i':
200                         in_file = optarg;
201                         break;
202                 case 'o':
203                         out_file_param = optarg;
204                         break;
205                 case 'f':
206                         if (files_count >= p-1)
207                                 ERROREXIT("May only specify up to %u files\n", p-1)
208                         files[files_count++] = optarg;
209                         break;
210                 case 'h':
211                 case '?':
212                         printf("Split usage: -s -n <total shares> -k <shares required> -i <input file> -o <output file path base>\n");
213                         printf("Combine usage: -c -k <shares provided == shares required> <-f <share>>*k -o <output file>\n");
214                         exit(0);
215                         break;
216                 default:
217                         ERROREXIT("getopt failed?\n")
218                 }
219         if (!(split & 0x2))
220                 ERROREXIT("Must specify either -c or -s\n")
221         split &= 0x1;
222
223         if (argc != optind)
224                 ERROREXIT("Invalid argument\n")
225
226         if (split) {
227                 if (!n || !k)
228                         ERROREXIT("n and k must be set.\n")
229
230                 if (k > n)
231                         ERROREXIT("k must be <= n\n")
232
233                 if (files_count != 0 || !in_file || !out_file_param)
234                         ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
235
236                 FILE* random = fopen("/dev/random", "r");
237                 assert(random);
238                 FILE* secret_file = fopen(in_file, "r");
239                 if (!secret_file)
240                         ERROREXIT("Could not open %s for reading.\n", in_file)
241
242                 uint8_t secret[MAX_LENGTH];
243                 memset(secret, 0, MAX_LENGTH*sizeof(uint8_t));
244
245                 size_t secret_length = fread(secret, 1, MAX_LENGTH*sizeof(uint8_t), secret_file);
246                 if (secret_length == 0)
247                         ERROREXIT("Error reading secret\n")
248                 if (fread(secret, 1, 1, secret_file) > 0)
249                         ERROREXIT("Secret may not be longer than %u\n", MAX_LENGTH)
250                 fclose(secret_file);
251                 printf("Using secret of length %lu\n", secret_length);
252
253                 uint8_t a[secret_length][k], D[n][secret_length];
254
255                 for (uint8_t i = 0; i < secret_length; i++) {
256                         a[i][0] = secret[i];
257
258                         for (uint8_t j = 1; j < k; j++) {
259                                 do
260                                         assert(fread(&a[i][j], sizeof(uint8_t), 1, random) == 1);
261                                 while (a[i][j] >= p);
262                         }
263                         for (uint8_t j = 0; j < n; j++)
264                                 D[j][i] = calculateQ(a[i], k, j+1);
265                 }
266
267                 char out_file_name_buf[strlen(out_file_param) + 4];
268                 strcpy(out_file_name_buf, out_file_param);
269                 for (uint8_t i = 0; i < n; i++) {
270                         /*printf("%u-", i);
271                         for (uint8_t j = 0; j < secret_length; j++)
272                                 printf("%02x", D[i][j]);
273                         printf("\n");*/
274
275                         sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
276                         FILE* out_file = fopen(out_file_name_buf, "w+");
277                         if (!out_file)
278                                 ERROREXIT("Could not open output file %s\n", out_file_name_buf)
279
280                         uint8_t x = i+1;
281                         if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
282                                 ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
283
284                         if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
285                                 ERROREXIT("Could not write %lu bytes to %s\n", secret_length, out_file_name_buf)
286
287                         fclose(out_file);
288                 }
289                 /*printf("secret = ");
290                 for (uint8_t i = 0; i < secret_length; i++)
291                         printf("%02x", secret[i]);
292                 printf("\n");*/
293
294                 fclose(random);
295         } else {
296                 if (!k)
297                         ERROREXIT("k must be set.\n")
298
299                 if (files_count != k || in_file || !out_file_param)
300                         ERROREXIT("Must not specify -i and must specify -o and exactly k -f <input file>s in combine mode.\n")
301
302                 uint8_t x[k], q[k];
303                 FILE* files_fps[k];
304
305                 for (uint8_t i = 0; i < k; i++) {
306                         files_fps[i] = fopen(files[i], "r");
307                         if (!files_fps[i])
308                                 ERROREXIT("Couldn't open file %s for reading.\n", files[i])
309                         if (fread(&x[i], sizeof(uint8_t), 1, files_fps[i]) != 1)
310                                 ERROREXIT("Couldn't read the x byte of %s\n", files[i])
311                 }
312
313                 uint8_t secret[MAX_LENGTH];
314
315                 uint8_t i = 0;
316                 while (fread(&q[0], sizeof(uint8_t), 1, files_fps[0]) == 1) {
317                         for (uint8_t j = 1; j < k; j++) {
318                                 if (fread(&q[j], sizeof(uint8_t), 1, files_fps[j]) != 1)
319                                         ERROREXIT("Couldn't read next byte from %s\n", files[j])
320                         }
321                         secret[i++] = calculateSecret(x, q, k);
322                 }
323                 printf("Got secret of length %u\n", i);
324
325                 FILE* out_file = fopen(out_file_param, "w+");
326                 fwrite(secret, sizeof(uint8_t), i, out_file);
327                 fclose(out_file);
328
329                 for (uint8_t i = 0; i < k; i++)
330                         fclose(files_fps[i]);
331         }
332
333         return 0;
334 }