From 490adeaa3273fe53040c11f415782977895d0be7 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 30 Nov 2021 01:38:04 +0000 Subject: [PATCH] Check array lengths before passing them to C When users pass a static-length array to C we currently CHECK its length, asserting only if we are built in debug mode. In production, we happily call JNI's `GetByteArrayRegion` with the expected length, ignoring any errors. `GetByteArrayRegion`, however, "THROWS ArrayIndexOutOfBoundsException: if one of the indexes in the region is not valid.". While its somewhat unclear what "THROWS" means in the context of a C API, it seems safe to assume accessing return values after a "THROWS" condition is undefined. Thus, we should ensure we check array lengths before calling into C. We do this here with a simple wrapper function added to `org.ldk.util.InternalUtils` which checks an array is the correct length before returning it. --- gen_type_mapping.py | 6 +++++- java_strings.py | 3 ++- src/main/java/org/ldk/util/InternalUtils.java | 10 ++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/ldk/util/InternalUtils.java diff --git a/gen_type_mapping.py b/gen_type_mapping.py index 01bd9291..daa18dcd 100644 --- a/gen_type_mapping.py +++ b/gen_type_mapping.py @@ -47,6 +47,7 @@ class TypeMappingGenerator: (set_pfx, set_sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_len, ty_info) ret_conv = ("int8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "") arg_conv_cleanup = None + from_hu_conv = None if not arr_len.isdigit(): arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = " + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + ";\n" @@ -69,15 +70,18 @@ class TypeMappingGenerator: arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n" arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_ref." + ty_info.arr_access, arr_len, ty_info, True) + ";" ret_conv = (ret_conv[0], "." + ty_info.arr_access + set_sfx + ";") + from_hu_conv = ("InternalUtils.check_arr_len(" + arr_name + ", " + arr_len + ")", "") else: arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n" arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_arr", arr_len, ty_info, True) + ";\n" arg_conv = arg_conv + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;" ret_conv = (ret_conv[0] + "*", set_sfx + ";") + from_hu_conv = ("InternalUtils.check_arr_len(" + arr_name + ", " + arr_len + ")", "") return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, arg_conv = arg_conv, arg_conv_name = arr_name + "_ref", arg_conv_cleanup = arg_conv_cleanup, - ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", to_hu_conv = None, to_hu_conv_name = None, from_hu_conv = None) + ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", to_hu_conv = None, to_hu_conv_name = None, + from_hu_conv = from_hu_conv) else: assert not arr_len.isdigit() # fixed length arrays not implemented assert ty_info.java_ty[len(ty_info.java_ty) - 2:] == "[]" diff --git a/java_strings.py b/java_strings.py index 6e77fc34..327a1f3e 100644 --- a/java_strings.py +++ b/java_strings.py @@ -103,9 +103,10 @@ public class version { self.util_fn_pfx = """package org.ldk.structs; import org.ldk.impl.bindings; +import org.ldk.enums.*; +import org.ldk.util.*; import java.util.Arrays; import javax.annotation.Nullable; -import org.ldk.enums.*; public class UtilMethods { """ diff --git a/src/main/java/org/ldk/util/InternalUtils.java b/src/main/java/org/ldk/util/InternalUtils.java new file mode 100644 index 00000000..692639d4 --- /dev/null +++ b/src/main/java/org/ldk/util/InternalUtils.java @@ -0,0 +1,10 @@ +package org.ldk.util; + +public class InternalUtils { + public static byte[] check_arr_len(byte[] arr, int length) throws IllegalArgumentException { + if (arr.length != length) { + throw new IllegalArgumentException("Array must be of fixed size " + length + " but was of length " + arr.length); + } + return arr; + } +} -- 2.30.2