Small cleanups
[shamirs] / main.c
diff --git a/main.c b/main.c
index af352fd09fb692da34db9b573460410ff5fba4b7..f79f074aec31327ca19b8cc0df7a151a0b8ecc81 100644 (file)
--- a/main.c
+++ b/main.c
@@ -3,18 +3,21 @@
  *
  * Copyright (C) 2013 Matt Corallo <git@bluematt.me>
  *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms and conditions of the GNU General Public License,
- * version 2, as published by the Free Software Foundation.
+ * This file is part of ASSS (Audit-friendly Shamir's Secret Sharing)
  *
- * This program is distributed in the hope it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
- * more details.
+ * ASSS is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of
+ * the License, or (at your option) any later version.
  *
- * You should have received a copy of the GNU General Public License along with
- * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
- * Place - Suite 330, Boston, MA 02111-1307 USA.
+ * ASSS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public
+ * License along with ASSS.  If not, see
+ * <http://www.gnu.org/licenses/>.
  */
 
 #define _GNU_SOURCE
 #define MAX_LENGTH 1024
 #define ERROREXIT(str...) {fprintf(stderr, str); exit(1);}
 
+#ifndef RAND_SOURCE
+#define RAND_SOURCE "/dev/random"
+#endif
+
 #ifndef TEST
-static void derive_missing_part(uint8_t total_shares, uint8_t shares_required, bool parts_have[], uint8_t* split_version, uint8_t split_index, uint8_t split_size) {
-       uint8_t (*D)[split_size] = (uint8_t (*)[split_size])split_version;
+static void derive_missing_part(uint8_t total_shares, uint8_t shares_required, bool parts_have[], const uint8_t* split_version, const uint8_t* split_x, uint8_t split_index, uint8_t split_size) {
+       const uint8_t (*D)[split_size] = (const uint8_t (*)[split_size])split_version;
        uint8_t x[shares_required], q[shares_required];
 
        // Fill in x/q with the selected shares
        uint16_t x_pos = 0;
        for (uint8_t i = 0; i < P-1; i++) {
                if (parts_have[i]) {
-                       x[x_pos] = i+1;
+                       x[x_pos] = split_x[i];
                        q[x_pos++] = D[i][split_index];
                }
        }
@@ -51,8 +58,6 @@ static void derive_missing_part(uint8_t total_shares, uint8_t shares_required, b
        // shares, because more shares could be added arbitrarily, any x should not be
        // able to rule out any possible secrets) and try each possible q, making sure
        // that each q gives us a new possibility for the secret.
-       bool impossible_secrets[P];
-       memset(impossible_secrets, 0, sizeof(impossible_secrets));
        for (uint16_t final_x = 1; final_x < P; final_x++) { 
                bool x_already_used = false;
                for (uint8_t j = 0; j < shares_required; j++) {
@@ -78,23 +83,23 @@ static void derive_missing_part(uint8_t total_shares, uint8_t shares_required, b
        memset(q, 0, sizeof(q));
 }
 
-static void check_possible_missing_part_derivations_intern(uint8_t total_shares, uint8_t shares_required, bool parts_have[], uint8_t parts_included, uint16_t progress, uint8_t* split_version, uint8_t split_index, uint8_t split_size) {
+static void check_possible_missing_part_derivations_intern(uint8_t total_shares, uint8_t shares_required, bool parts_have[], uint8_t parts_included, uint16_t progress, const uint8_t* split_version, const uint8_t* x, uint8_t split_index, uint8_t split_size) {
        if (parts_included == shares_required-1)
-               return derive_missing_part(total_shares, shares_required, parts_have, split_version, split_index, split_size);
+               return derive_missing_part(total_shares, shares_required, parts_have, split_version, x, split_index, split_size);
 
        if (total_shares - progress < shares_required)
                return;
 
-       check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, parts_included, progress+1, split_version, split_index, split_size);
+       check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, parts_included, progress+1, split_version, x, split_index, split_size);
        parts_have[progress] = 1;
-       check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, parts_included+1, progress+1, split_version, split_index, split_size);
+       check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, parts_included+1, progress+1, split_version, x, split_index, split_size);
        parts_have[progress] = 0;
 }
 
-static void check_possible_missing_part_derivations(uint8_t total_shares, uint8_t shares_required, uint8_t* split_version, uint8_t split_index, uint8_t split_size) {
+static void check_possible_missing_part_derivations(uint8_t total_shares, uint8_t shares_required, const uint8_t* split_version, const uint8_t* x, uint8_t split_index, uint8_t split_size) {
        bool parts_have[P];
        memset(parts_have, 0, sizeof(parts_have));
-       check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, 0, 0, split_version, split_index, split_size);
+       check_possible_missing_part_derivations_intern(total_shares, shares_required, parts_have, 0, 0, split_version, x, split_index, split_size);
 }
 
 
@@ -174,7 +179,7 @@ int main(int argc, char* argv[]) {
                if (files_count != 0 || !in_file || !out_file_param)
                        ERROREXIT("Must specify -i <input file> and -o <output file path base> but not -f in split mode.\n")
 
-               FILE* random = fopen("/dev/random", "r");
+               FILE* random = fopen(RAND_SOURCE, "r");
                assert(random);
                FILE* secret_file = fopen(in_file, "r");
                if (!secret_file)
@@ -190,38 +195,46 @@ int main(int argc, char* argv[]) {
                fclose(secret_file);
                printf("Using secret of length %lu\n", secret_length);
 
-               uint8_t a[shares_required], D[total_shares][secret_length];
-
+               uint8_t a[shares_required], x[total_shares], D[total_shares][secret_length];
+
+               // TODO: The following loop may take a long time and eat lots of /dev/random if total_shares is high
+               for (uint32_t i = 0; i < total_shares; i++) {
+                       int32_t j = -1;
+                       do {
+                               assert(fread(&x[i], sizeof(uint8_t), 1, random) == 1);
+                               if (x[i] == 0)
+                                       continue;
+                               for (j = 0; j < i; j++)
+                                       if (x[j] == x[i])
+                                               break;
+                       } while (j < (int32_t)i); // Inner loop will get to j = i when x[j] != x[i] for all j
+                       if (i % 32 == 31)
+                               printf("Finished picking X coordinates for %u shares\n", i+1);
+               }
                for (uint32_t i = 0; i < secret_length; i++) {
                        a[0] = secret[i];
 
                        for (uint8_t j = 1; j < shares_required; j++)
                                assert(fread(&a[j], sizeof(uint8_t), 1, random) == 1);
                        for (uint8_t j = 0; j < total_shares; j++)
-                               D[j][i] = calculateQ(a, shares_required, j+1);
+                               D[j][i] = calculateQ(a, shares_required, x[j]);
 
                        // Now, for paranoia's sake, we ensure that no matter which piece we are missing, we can derive no information about the secret
-                       check_possible_missing_part_derivations(total_shares, shares_required, &(D[0][0]), i, secret_length);
+                       check_possible_missing_part_derivations(total_shares, shares_required, &(D[0][0]), x, i, secret_length);
 
-                       if (i % 32 == 0 && i != 0)
-                               printf("Finished processing %u bytes.\n", i);
+                       if (i % 32 == 31)
+                               printf("Finished processing %u bytes.\n", i+1);
                }
 
                char out_file_name_buf[strlen(out_file_param) + 4];
                strcpy(out_file_name_buf, out_file_param);
                for (uint8_t i = 0; i < total_shares; i++) {
-                       /*printf("%u-", i);
-                       for (uint8_t j = 0; j < secret_length; j++)
-                               printf("%02x", D[i][j]);
-                       printf("\n");*/
-
                        sprintf(((char*)out_file_name_buf) + strlen(out_file_param), "%u", i);
                        FILE* out_file = fopen(out_file_name_buf, "w+");
                        if (!out_file)
                                ERROREXIT("Could not open output file %s\n", out_file_name_buf)
 
-                       uint8_t x = i+1;
-                       if (fwrite(&x, sizeof(uint8_t), 1, out_file) != 1)
+                       if (fwrite(&x[i], sizeof(uint8_t), 1, out_file) != 1)
                                ERROREXIT("Could not write 1 byte to %s\n", out_file_name_buf)
 
                        if (fwrite(D[i], 1, secret_length, out_file) != secret_length)
@@ -229,14 +242,11 @@ int main(int argc, char* argv[]) {
 
                        fclose(out_file);
                }
-               /*printf("secret = ");
-               for (uint8_t i = 0; i < secret_length; i++)
-                       printf("%02x", secret[i]);
-               printf("\n");*/
 
                // Clear sensitive data (No, GCC 4.7.2 is currently not optimizing this out)
                memset(secret, 0, sizeof(uint8_t)*secret_length);
                memset(a, 0, sizeof(uint8_t)*shares_required);
+               memset(x, 0, sizeof(uint8_t)*total_shares);
                memset(in_file, 0, strlen(in_file));
 
                fclose(random);