Check array lengths before passing them to C
authorMatt Corallo <git@bluematt.me>
Tue, 30 Nov 2021 01:38:04 +0000 (01:38 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 1 Dec 2021 14:14:13 +0000 (14:14 +0000)
When users pass a static-length array to C we currently CHECK its
length, asserting only if we are built in debug mode. In
production, we happily call JNI's `GetByteArrayRegion` with the
expected length, ignoring any errors. `GetByteArrayRegion`,
however, "THROWS ArrayIndexOutOfBoundsException: if one of the
indexes in the region is not valid.". While its somewhat unclear
what "THROWS" means in the context of a C API, it seems safe to
assume accessing return values after a "THROWS" condition is
undefined. Thus, we should ensure we check array lengths before
calling into C.

We do this here with a simple wrapper function added to
`org.ldk.util.InternalUtils` which checks an array is the correct
length before returning it.

gen_type_mapping.py
java_strings.py
src/main/java/org/ldk/util/InternalUtils.java [new file with mode: 0644]

index 01bd92914b01dc857a5ca1b9a432889b5c8ef147..daa18dcdd5c20b8a1519def49f7ca36445afd221 100644 (file)
@@ -47,6 +47,7 @@ class TypeMappingGenerator:
                 (set_pfx, set_sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_len, ty_info)
                 ret_conv = ("int8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "")
                 arg_conv_cleanup = None
+                from_hu_conv = None
                 if not arr_len.isdigit():
                     arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
                     arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = " +  self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + ";\n"
@@ -69,15 +70,18 @@ class TypeMappingGenerator:
                     arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n"
                     arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_ref." + ty_info.arr_access, arr_len, ty_info, True) + ";"
                     ret_conv = (ret_conv[0], "." + ty_info.arr_access + set_sfx + ";")
+                    from_hu_conv = ("InternalUtils.check_arr_len(" + arr_name + ", " + arr_len + ")", "")
                 else:
                     arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n"
                     arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n"
                     arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_arr", arr_len, ty_info, True) + ";\n"
                     arg_conv = arg_conv + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
                     ret_conv = (ret_conv[0] + "*", set_sfx + ";")
+                    from_hu_conv = ("InternalUtils.check_arr_len(" + arr_name + ", " + arr_len + ")", "")
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = arg_conv, arg_conv_name = arr_name + "_ref", arg_conv_cleanup = arg_conv_cleanup,
-                    ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", to_hu_conv = None, to_hu_conv_name = None, from_hu_conv = None)
+                    ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", to_hu_conv = None, to_hu_conv_name = None,
+                    from_hu_conv = from_hu_conv)
             else:
                 assert not arr_len.isdigit() # fixed length arrays not implemented
                 assert ty_info.java_ty[len(ty_info.java_ty) - 2:] == "[]"
index 6e77fc347569a5145ef50e6e48b2575c9965b177..327a1f3e60903550f59a7166bafcd58a2a47d169 100644 (file)
@@ -103,9 +103,10 @@ public class version {
 
         self.util_fn_pfx = """package org.ldk.structs;
 import org.ldk.impl.bindings;
+import org.ldk.enums.*;
+import org.ldk.util.*;
 import java.util.Arrays;
 import javax.annotation.Nullable;
-import org.ldk.enums.*;
 
 public class UtilMethods {
 """
diff --git a/src/main/java/org/ldk/util/InternalUtils.java b/src/main/java/org/ldk/util/InternalUtils.java
new file mode 100644 (file)
index 0000000..692639d
--- /dev/null
@@ -0,0 +1,10 @@
+package org.ldk.util;
+
+public class InternalUtils {
+    public static byte[] check_arr_len(byte[] arr, int length) throws IllegalArgumentException {
+        if (arr.length != length) {
+            throw new IllegalArgumentException("Array must be of fixed size " + length + " but was of length " + arr.length);
+        }
+        return arr;
+    }
+}