Small cleanups
[shamirs] / shamirssecret.c
index 77bc0bff8bd7d2eb2dda2f05c36c9b75e0301112..fd309f9386ea06d66feeba3d2b7a13c0f0dd5a94 100644 (file)
@@ -1,13 +1,43 @@
-#include <stdint.h>
+/*
+ * Shamir's secret sharing implementation
+ *
+ * Copyright (C) 2013 Matt Corallo <git@bluematt.me>
+ *
+ * This file is part of ASSS (Audit-friendly Shamir's Secret Sharing)
+ *
+ * ASSS is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of
+ * the License, or (at your option) any later version.
+ *
+ * ASSS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public
+ * License along with ASSS.  If not, see
+ * <http://www.gnu.org/licenses/>.
+ */
+
+#ifndef IN_KERNEL
 #include <assert.h>
+#define CHECKSTATE(x) assert(x)
+#else
+#include <linux/bug.h>
+#define CHECKSTATE(x) BUG_ON(!(x))
+#endif
 
 #include "shamirssecret.h"
 
+#ifndef noinline
+#define noinline __attribute__((noinline))
+#endif
+
 /*
  * Calculations across the finite field GF(2^8)
  */
 
-#ifndef TEST
 static uint8_t field_add(uint8_t a, uint8_t b) {
        return a ^ b;
 }
@@ -19,8 +49,10 @@ static uint8_t field_sub(uint8_t a, uint8_t b) {
 static uint8_t field_neg(uint8_t a) {
        return field_sub(0, a);
 }
-#endif
 
+//TODO: Using static tables will very likely create side-channel attacks when measuring cache hits
+//      Because these are fairly small tables, we can probably get them loaded mostly/fully into
+//      cache before use to break such attacks.
 static const uint8_t exp[P] = {
        0x01, 0x03, 0x05, 0x0f, 0x11, 0x33, 0x55, 0xff, 0x1a, 0x2e, 0x72, 0x96, 0xa1, 0xf8, 0x13, 0x35,
        0x5f, 0xe1, 0x38, 0x48, 0xd8, 0x73, 0x95, 0xa4, 0xf7, 0x02, 0x06, 0x0a, 0x1e, 0x22, 0x66, 0xaa,
@@ -58,7 +90,7 @@ static const uint8_t log[P] = {
        0x4a, 0xed, 0xde, 0xc5, 0x31, 0xfe, 0x18, 0x0d, 0x63, 0x8c, 0x80, 0xc0, 0xf7, 0x70, 0x07};
 
 // We disable lots of optimizations that result in non-constant runtime (+/- branch delays)
-static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) __attribute__((noinline));
+static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) __attribute__((optimize("-O0"))) noinline;
 static uint8_t field_mul_ret(uint8_t calc, uint8_t a, uint8_t b) {
        uint8_t ret, ret2;
        if (a == 0)
@@ -76,31 +108,21 @@ static uint8_t field_mul(uint8_t a, uint8_t b)  {
 }
 
 static uint8_t field_invert(uint8_t a) {
-       assert(a != 0);
+       CHECKSTATE(a != 0);
        return exp[0xff - log[a]]; // log[1] == 0xff
 }
 
-// We disable lots of optimizations that result in non-constant runtime (+/- branch delays)
-static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) __attribute__((optimize("-O0"))) __attribute__((noinline));
-static uint8_t field_pow_ret(uint8_t calc, uint8_t a, uint8_t e) {
-       uint8_t ret, ret2;
-       if (a == 0)
-               ret2 = 0;
-       else
-               ret2 = calc;
-       if (e == 0)
-               ret = 1;
-       else
-               ret = ret2;
-       return ret;
-}
 static uint8_t field_pow(uint8_t a, uint8_t e) {
+       uint8_t ret = exp[(log[a] * e) % 255];
 #ifndef TEST
-       // Although this function works for a==0, its not trivially obvious why,
-       // and since we never call with a==0, we just assert a != 0 (except when testing)
-       assert(a != 0);
+       // We only work for a == 0 by branching (below), but since we
+       // never call with a==0, we just assert a != 0 (except when testing)
+       CHECKSTATE(a != 0);
+#else
+       if (a == 0 && e != 0)
+               ret = 0;
 #endif
-       return field_pow_ret(exp[(log[a] * e) % 255], a, e);
+       return ret;
 }
 
 #ifdef TEST
@@ -129,18 +151,29 @@ static uint8_t field_pow_calc(uint8_t a, uint8_t e) {
 int main() {
        // Test inversion with the logarithm tables
        for (uint16_t i = 1; i < P; i++)
-               assert(field_mul_calc(i, field_invert(i)) == 1);
+               CHECKSTATE(field_mul_calc(i, field_invert(i)) == 1);
 
        // Test multiplication with the logarithm tables
        for (uint16_t i = 0; i < P; i++) {
                for (uint16_t j = 0; j < P; j++)
-                       assert(field_mul(i, j) == field_mul_calc(i, j));
+                       CHECKSTATE(field_mul(i, j) == field_mul_calc(i, j));
        }
 
        // Test exponentiation with the logarithm tables
        for (uint16_t i = 0; i < P; i++) {
                for (uint16_t j = 0; j < P; j++)
-                       assert(field_pow(i, j) == field_pow_calc(i, j));
+                       CHECKSTATE(field_pow(i, j) == field_pow_calc(i, j));
+       }
+
+       // Test invertibility of add/negate/subtract
+       for (uint16_t i = 0; i < P; i++) {
+               CHECKSTATE(field_neg(field_neg(i)) == i);
+               // Test add/sub commutativity
+               for (uint16_t j = 0; j < P; j++) {
+                       CHECKSTATE(field_add(i, j) == field_add(j, i));
+                       CHECKSTATE(field_add(i, field_neg(j)) == field_sub(i, j));
+                       CHECKSTATE(field_add(field_neg(j), i) == field_sub(i, j));
+               }
        }
 }
 #endif // defined(TEST)
@@ -156,9 +189,9 @@ int main() {
  * coefficients[0] == secret, the rest are random values
  */
 uint8_t calculateQ(uint8_t coefficients[], uint8_t shares_required, uint8_t x) {
-       assert(x != 0); // q(0) == secret, though so does a[0]
-       uint8_t ret = coefficients[0];
-       for (uint8_t i = 1; i < shares_required; i++) {
+       uint8_t ret = coefficients[0], i;
+       CHECKSTATE(x != 0); // q(0) == secret, though so does a[0]
+       for (i = 1; i < shares_required; i++) {
                ret = field_add(ret, field_mul(coefficients[i], field_pow(x, i)));
        }
        return ret;
@@ -168,12 +201,12 @@ uint8_t calculateQ(uint8_t coefficients[], uint8_t shares_required, uint8_t x) {
  * Derives the secret given a set of shares_required points (x and q coordinates)
  */
 uint8_t calculateSecret(uint8_t x[], uint8_t q[], uint8_t shares_required) {
-       // Calculate the x^0 term using a derivation of the forumula at
+       // Calculate the x^0 term using a derivation of the formula at
        // http://en.wikipedia.org/wiki/Lagrange_polynomial#Example_2
-       uint8_t ret = 0;
-       for (uint8_t i = 0; i < shares_required; i++) {
+       uint8_t ret = 0, i, j;
+       for (i = 0; i < shares_required; i++) {
                uint8_t temp = q[i];
-               for (uint8_t j = 0; j < shares_required; j++) {
+               for (j = 0; j < shares_required; j++) {
                        if (i == j)
                                continue;
                        temp = field_mul(temp, field_neg(x[j]));