From: Matt Corallo Date: Fri, 3 Mar 2023 01:12:36 +0000 (+0000) Subject: Support (fixed-length) arrays of 16-bit integers X-Git-Tag: v0.0.114.0~16 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=1d4aa806117a28f8a5899ed0dc4e9ad7ed117787;p=ldk-java Support (fixed-length) arrays of 16-bit integers --- diff --git a/gen_type_mapping.py b/gen_type_mapping.py index d24ed230..aebb290d 100644 --- a/gen_type_mapping.py +++ b/gen_type_mapping.py @@ -43,14 +43,14 @@ class TypeMappingGenerator: else: arr_name = "ret" arr_len = ret_arr_len - if ty_info.c_ty == "int8_tArray": + if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray": (set_pfx, set_sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_len, ty_info) - ret_conv = ("int8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "") + ret_conv = (ty_info.c_ty + " " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "") arg_conv_cleanup = None if not arr_len.isdigit(): arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = " + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + ";\n" - if (not ty_info.is_ptr or not holds_ref) and ty_info.rust_obj != "LDKu8slice": + if (not ty_info.is_ptr or not holds_ref) and (ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKu16slice"): arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = MALLOC(" + arr_name + "_ref." + arr_len + ", \"" + ty_info.rust_obj + " Bytes\");\n" arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_ref." + ty_info.arr_access, arr_name + "_ref." + arr_len, ty_info, True) + ";" else: @@ -59,10 +59,10 @@ class TypeMappingGenerator: if ty_info.rust_obj == "LDKTransaction" or ty_info.rust_obj == "LDKWitness": arg_conv = arg_conv + "\n" + arr_name + "_ref.data_is_owned = " + str(holds_ref).lower() + ";" ret_conv = (ty_info.rust_obj + " " + arr_name + "_var = ", "") - ret_conv = (ret_conv[0], ";\nint8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_name + "_var." + arr_len, ty_info) + ";\n") + ret_conv = (ret_conv[0], ";\n" + ty_info.c_ty + " " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_name + "_var." + arr_len, ty_info) + ";\n") (pfx, sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_name + "_var." + arr_len, ty_info) ret_conv = (ret_conv[0], ret_conv[1] + pfx + arr_name + "_var." + ty_info.arr_access + sfx + ";") - if not holds_ref and ty_info.rust_obj != "LDKu8slice": + if not holds_ref and (ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKu16slice"): ret_conv = (ret_conv[0], ret_conv[1] + "\n" + ty_info.rust_obj.replace("LDK", "") + "_free(" + arr_name + "_var);") from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, None, arr_name) to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv") @@ -74,10 +74,11 @@ class TypeMappingGenerator: from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, arr_len, arr_name) to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv") else: - arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" + # Note that we just blindly assume we should be using unsigned integers here. + arg_conv = "u" + ty_info.subty.c_ty + " " + arr_name + "_arr[" + arr_len + "];\n" arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n" arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_arr", arr_len, ty_info, True) + ";\n" - arg_conv = arg_conv + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;" + arg_conv = arg_conv + "u" + ty_info.subty.c_ty + " (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;" ret_conv = (ret_conv[0] + "*", set_sfx + ";") from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, arr_len, arr_name) to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv") diff --git a/genbindings.py b/genbindings.py index c9ae45ca..77f9b472 100755 --- a/genbindings.py +++ b/genbindings.py @@ -143,6 +143,11 @@ def java_c_types(fn_arg, ret_arr_len): assert var_is_arr_regex.match(fn_arg[8:]) rust_obj = "LDKThirtyTwoBytes" arr_access = "data" + elif fn_arg.startswith("LDKEightU16s"): + fn_arg = "uint16_t (*" + fn_arg[13:] + ")[8]" + assert var_is_arr_regex.match(fn_arg[9:]) + rust_obj = "LDKEightU16s" + arr_access = "data" elif fn_arg.startswith("LDKU128"): if fn_arg == "LDKU128": fn_arg = "LDKU128 arg" diff --git a/java_strings.py b/java_strings.py index d956e0e1..650e36bf 100644 --- a/java_strings.py +++ b/java_strings.py @@ -426,6 +426,7 @@ _Static_assert(sizeof(void*) <= 8, "Pointers must fit into 64 bits"); typedef jlongArray int64_tArray; typedef jbyteArray int8_tArray; +typedef jshortArray int16_tArray; static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) { // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator @@ -525,14 +526,17 @@ import javax.annotation.Nullable; def set_native_arr_contents(self, arr_name, arr_len, ty_info): if ty_info.c_ty == "int8_tArray": return ("(*env)->SetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")") + elif ty_info.c_ty == "int16_tArray": + return ("(*env)->SetShortArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")") else: assert False def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy): - if ty_info.c_ty == "int8_tArray": + if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray": + fn_ty = "Byte" if ty_info.c_ty == "int8_tArray" else "Short" if copy: - return "(*env)->GetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")" + return "(*env)->Get" + fn_ty + "ArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")" else: - return "(*env)->GetByteArrayElements (env, " + arr_name + ", NULL)" + return "(*env)->Get" + fn_ty + "ArrayElements (env, " + arr_name + ", NULL)" elif not ty_info.java_ty[:len(ty_info.java_ty) - 2].endswith("[]"): return "(*env)->Get" + ty_info.subty.java_ty.title() + "ArrayElements (env, " + arr_name + ", NULL)" else: @@ -616,7 +620,13 @@ import javax.annotation.Nullable; if arr_ty.rust_obj == "LDKU128": return ("" + arr_name + ".getLEBytes()", "") if fixed_len is not None: - return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "") + if mapped_ty.c_ty == "int8_t" or mapped_ty.c_ty == "uint8_t": + return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "") + elif mapped_ty.c_ty == "int16_t" or mapped_ty.c_ty == "uint16_t": + return ("InternalUtils.check_arr_16_len(" + arr_name + ", " + fixed_len + ")", "") + else: + print(arr_ty.c_ty) + assert False return None def primitive_arr_to_hu(self, arr_ty, fixed_len, arr_name, conv_name): if arr_ty.rust_obj == "LDKU128": diff --git a/src/main/java/org/ldk/util/InternalUtils.java b/src/main/java/org/ldk/util/InternalUtils.java index d9c3f6ef..4e706ec1 100644 --- a/src/main/java/org/ldk/util/InternalUtils.java +++ b/src/main/java/org/ldk/util/InternalUtils.java @@ -8,6 +8,13 @@ public class InternalUtils { return arr; } + public static short[] check_arr_16_len(short[] arr, int length) throws IllegalArgumentException { + if (arr != null && arr.length != length) { + throw new IllegalArgumentException("Array must be of fixed size " + length + " but was of length " + arr.length); + } + return arr; + } + public static byte[] convUInt5Array(UInt5[] u5s) { byte[] res = new byte[u5s.length]; for (int i = 0; i < u5s.length; i++) { diff --git a/typescript_strings.py b/typescript_strings.py index 402e6399..6706ff39 100644 --- a/typescript_strings.py +++ b/typescript_strings.py @@ -188,6 +188,16 @@ export function encodeUint8Array (inputArray: Uint8Array|null): number { return cArrayPointer; } /* @internal */ +export function encodeUint16Array (inputArray: Uint16Array|Array|null): number { + if (inputArray == null) return 0; + const cArrayPointer = wasm.TS_malloc((inputArray.length + 4) * 2); + const arrayLengthView = new BigUint64Array(wasm.memory.buffer, cArrayPointer, 1); + arrayLengthView[0] = BigInt(inputArray.length); + const arrayMemoryView = new Uint16Array(wasm.memory.buffer, cArrayPointer + 8, inputArray.length); + arrayMemoryView.set(inputArray); + return cArrayPointer; +} +/* @internal */ export function encodeUint32Array (inputArray: Uint32Array|Array|null): number { if (inputArray == null) return 0; const cArrayPointer = wasm.TS_malloc((inputArray.length + 2) * 4); @@ -213,6 +223,12 @@ export function check_arr_len(arr: Uint8Array|null, len: number): Uint8Array|nul return arr; } +/* @internal */ +export function check_16_arr_len(arr: Uint16Array|null, len: number): Uint16Array|null { + if (arr !== null && arr.length != len) { throw new Error("Expected array of length " + len + " got " + arr.length); } + return arr; +} + /* @internal */ export function getArrayLength(arrayPointer: number): number { const arraySizeViewer = new BigUint64Array(wasm.memory.buffer, arrayPointer, 1); @@ -248,15 +264,13 @@ export function decodeUint8Array (arrayPointer: number, free = true): Uint8Array } return actualArray; } -const decodeUint32Array = (arrayPointer: number, free = true) => { +/* @internal */ +export function decodeUint16Array (arrayPointer: number, free = true): Uint16Array { const arraySize = getArrayLength(arrayPointer); - const actualArrayViewer = new Uint32Array( - wasm.memory.buffer, // value - arrayPointer + 8, // offset (ignoring length bytes) - arraySize // uint32 count - ); + const actualArrayViewer = new Uint16Array(wasm.memory.buffer, arrayPointer + 8, arraySize); // Clone the contents, TODO: In the future we should wrap the Viewer in a class that // will free the underlying memory when it becomes unreachable instead of copying here. + // Note that doing so may have edge-case interactions with memory resizing (invalidating the buffer). const actualArray = actualArrayViewer.slice(0, arraySize); if (free) { wasm.TS_free(arrayPointer); @@ -643,6 +657,7 @@ _Static_assert(sizeof(void*) == 4, "Pointers mut be 32 bits"); DECL_ARR_TYPE(int64_t, int64_t); DECL_ARR_TYPE(uint64_t, uint64_t); DECL_ARR_TYPE(int8_t, int8_t); +DECL_ARR_TYPE(int16_t, int16_t); DECL_ARR_TYPE(uint32_t, uint32_t); DECL_ARR_TYPE(void*, ptr); DECL_ARR_TYPE(char, char); @@ -724,12 +739,17 @@ import * as bindings from '../bindings.mjs' def set_native_arr_contents(self, arr_name, arr_len, ty_info): if ty_info.c_ty == "int8_tArray": return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + ")") + elif ty_info.c_ty == "int16_tArray": + return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + " * 2)") else: assert False def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy): - if ty_info.c_ty == "int8_tArray": + if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray": if copy: - return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + arr_len + "); FREE(" + arr_name + ")" + byte_len = arr_len + if ty_info.c_ty == "int16_tArray": + byte_len = arr_len + " * 2" + return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + byte_len + "); FREE(" + arr_name + ")" assert not copy if ty_info.c_ty == "ptrArray": return "(void*) " + arr_name + "->elems" @@ -792,12 +812,16 @@ import * as bindings from '../bindings.mjs' if arr_ty.rust_obj == "LDKU128": return ("bindings.encodeUint128(" + inner + ")", "") if fixed_len is not None: - assert mapped_ty.c_ty == "int8_t" - inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")" + if mapped_ty.c_ty == "int8_t": + inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")" + elif mapped_ty.c_ty == "int16_t": + inner = "bindings.check_16_arr_len(" + arr_name + ", " + fixed_len + ")" if mapped_ty.c_ty.endswith("Array"): return ("bindings.encodeUint32Array(" + inner + ")", "") elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t": return ("bindings.encodeUint8Array(" + inner + ")", "") + elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t": + return ("bindings.encodeUint16Array(" + inner + ")", "") elif mapped_ty.c_ty == "uint32_t": return ("bindings.encodeUint32Array(" + inner + ")", "") elif mapped_ty.c_ty == "int64_t" or mapped_ty.c_ty == "uint64_t": @@ -812,6 +836,8 @@ import * as bindings from '../bindings.mjs' return "const " + conv_name + ": bigint = bindings.decodeUint128(" + arr_name + ");" elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t": return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");" + elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t": + return "const " + conv_name + ": Uint16Array = bindings.decodeUint16Array(" + arr_name + ");" elif mapped_ty.c_ty == "uint64_t" or mapped_ty.c_ty == "int64_t": return "const " + conv_name + ": bigint[] = bindings.decodeUint64Array(" + arr_name + ");" else: