From dbce8d1aa627e4a632610895c5ea1fddfd9a24f1 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 10 Jan 2022 00:55:04 +0000 Subject: [PATCH] Support mapping primitive arrays as non-arrays (eg numbers) TypeScript can't pass an array through to C, so we have to pass a pointer to a constructed array. This adds support in the relevant type-conversion logic to enable this (and uses it in TS). --- gen_type_mapping.py | 33 ++++++++++---- genbindings.py | 13 ++++-- java_strings.py | 14 ++++++ typescript_strings.py | 103 ++++++++++++++++++++++++++---------------- 4 files changed, 109 insertions(+), 54 deletions(-) diff --git a/gen_type_mapping.py b/gen_type_mapping.py index b285ab34..517deb2f 100644 --- a/gen_type_mapping.py +++ b/gen_type_mapping.py @@ -47,7 +47,6 @@ class TypeMappingGenerator: (set_pfx, set_sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_len, ty_info) ret_conv = ("int8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "") arg_conv_cleanup = None - from_hu_conv = None if not arr_len.isdigit(): arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = " + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + ";\n" @@ -65,26 +64,34 @@ class TypeMappingGenerator: ret_conv = (ret_conv[0], ret_conv[1] + pfx + arr_name + "_var." + ty_info.arr_access + sfx + ";") if not holds_ref and ty_info.rust_obj != "LDKu8slice": ret_conv = (ret_conv[0], ret_conv[1] + "\n" + ty_info.rust_obj.replace("LDK", "") + "_free(" + arr_name + "_var);") + from_hu_conv = self.consts.primitive_arr_from_hu(ty_info.subty, None, arr_name) + to_hu_conv = self.consts.primitive_arr_to_hu(ty_info.subty, None, arr_name, arr_name + "_conv") elif ty_info.rust_obj is not None: arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n" arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_ref." + ty_info.arr_access, arr_len, ty_info, True) + ";" ret_conv = (ret_conv[0], "." + ty_info.arr_access + set_sfx + ";") - from_hu_conv = ("InternalUtils.check_arr_len(" + arr_name + ", " + arr_len + ")", "") + from_hu_conv = self.consts.primitive_arr_from_hu(ty_info.subty, arr_len, arr_name) + to_hu_conv = self.consts.primitive_arr_to_hu(ty_info.subty, None, arr_name, arr_name + "_conv") else: arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n" arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_arr", arr_len, ty_info, True) + ";\n" arg_conv = arg_conv + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;" ret_conv = (ret_conv[0] + "*", set_sfx + ";") - from_hu_conv = ("InternalUtils.check_arr_len(" + arr_name + ", " + arr_len + ")", "") + from_hu_conv = self.consts.primitive_arr_from_hu(ty_info.subty, arr_len, arr_name) + to_hu_conv = self.consts.primitive_arr_to_hu(ty_info.subty, None, arr_name, arr_name + "_conv") + to_hu_conv_name = None + if to_hu_conv is not None: + to_hu_conv_name = arr_name + "_conv" return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, arg_conv = arg_conv, arg_conv_name = arr_name + "_ref", arg_conv_cleanup = arg_conv_cleanup, - ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", to_hu_conv = None, to_hu_conv_name = None, + ret_conv = ret_conv, ret_conv_name = arr_name + "_arr", + to_hu_conv = to_hu_conv, to_hu_conv_name = to_hu_conv_name, from_hu_conv = from_hu_conv) else: assert not arr_len.isdigit() # fixed length arrays not implemented - assert ty_info.java_ty[len(ty_info.java_ty) - 2:] == "[]" + assert ty_info.java_hu_ty[len(ty_info.java_hu_ty) - 2:] == "[]" if arr_name == "": arr_name = "ret" conv_name = arr_name + "_conv_" + str(len(ty_info.java_hu_ty)) @@ -183,19 +190,25 @@ class TypeMappingGenerator: to_hu_conv = None to_hu_conv_name = None if subty.to_hu_conv is not None: - to_hu_conv = self.consts.var_decl_statement(ty_info.java_hu_ty, conv_name + "_arr", self.consts.constr_hu_array(ty_info, arr_name + ".length")) - to_hu_conv += ";\n" + self.consts.for_n_in_range(idxc, "0", arr_name + ".length") + "\n" - to_hu_conv += "\t" + self.consts.var_decl_statement(subty.java_ty, conv_name, arr_name + "[" + idxc + "]") + ";\n" + to_hu_conv = self.consts.var_decl_statement(self.consts.c_type_map["uint32_t"][0], conv_name + "_len", self.consts.get_java_arr_len(arr_name)) + ";\n" + to_hu_conv += self.consts.var_decl_statement(ty_info.java_hu_ty, conv_name + "_arr", self.consts.constr_hu_array(ty_info, conv_name + "_len")) + to_hu_conv += ";\n" + self.consts.for_n_in_range(idxc, "0", conv_name + "_len") + "\n" + to_hu_conv += "\t" + self.consts.var_decl_statement(subty.java_ty, conv_name, self.consts.get_java_arr_elem(subty, arr_name, idxc)) + ";\n" to_hu_conv += "\t" + subty.to_hu_conv.replace("\n", "\n\t") + "\n" to_hu_conv += "\t" + conv_name + "_arr[" + idxc + "] = " + subty.to_hu_conv_name + ";\n}" to_hu_conv_name = conv_name + "_arr" - from_hu_conv = None + from_hu_conv = self.consts.primitive_arr_from_hu(ty_info.subty, None, arr_name) if subty.from_hu_conv is not None: hu_conv_b = "" if subty.from_hu_conv[1] != "": iterator = self.consts.for_n_in_arr(conv_name, arr_name, subty) hu_conv_b = iterator[0] + subty.from_hu_conv[1] + ";" + iterator[1] - from_hu_conv = (self.consts.map_hu_array_elems(arr_name, conv_name, ty_info, subty), hu_conv_b) + if from_hu_conv is not None: + arr_conv = self.consts.primitive_arr_from_hu(ty_info.subty, None, self.consts.map_hu_array_elems(arr_name, conv_name, ty_info, subty)) + assert arr_conv[1] == "" + from_hu_conv = (arr_conv[0], hu_conv_b) + else: + from_hu_conv = (self.consts.map_hu_array_elems(arr_name, conv_name, ty_info, subty), hu_conv_b) return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, arg_conv = arg_conv, arg_conv_name = arg_conv_name, arg_conv_cleanup = arg_conv_cleanup, diff --git a/genbindings.py b/genbindings.py index f426b23f..8b067f72 100755 --- a/genbindings.py +++ b/genbindings.py @@ -214,13 +214,14 @@ def java_c_types(fn_arg, ret_arr_len): return None if is_ptr: res.pass_by_ref = True + java_ty = consts.java_arr_ty_str(res.java_ty) if res.is_native_primitive or res.passed_as_ptr: - return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=res.java_ty + "[]", java_hu_ty=res.java_hu_ty + "[]", + return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=java_ty, java_hu_ty=res.java_hu_ty + "[]", java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=res.c_ty + "Array", passed_as_ptr=False, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, is_const=is_const, var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False) else: - return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=res.java_ty + "[]", java_hu_ty=res.java_hu_ty + "[]", + return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=java_ty, java_hu_ty=res.java_hu_ty + "[]", java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=consts.ptr_arr, passed_as_ptr=False, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, is_const=is_const, var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False) @@ -351,10 +352,12 @@ def java_c_types(fn_arg, ret_arr_len): assert(not take_by_ptr) assert(not is_ptr) # is there a special case for plurals? - if len(mapped_type) == 2: + if len(mapped_type) == 3: java_ty = mapped_type[1] + java_hu_ty = mapped_type[2] else: java_ty = java_ty + "[]" + java_hu_ty = java_ty c_ty = c_ty + "Array" subty = java_c_types(arr_ty, None) @@ -366,10 +369,10 @@ def java_c_types(fn_arg, ret_arr_len): if var_is_arr is not None: if var_is_arr.group(1) == "": - return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, + return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg", subty=subty, arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait) - return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, + return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1), subty=subty, arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait) diff --git a/java_strings.py b/java_strings.py index d477aa31..8fcc3e5c 100644 --- a/java_strings.py +++ b/java_strings.py @@ -662,6 +662,10 @@ import javax.annotation.Nullable; def var_decl_statement(self, ty_string, var_name, statement): return ty_string + " " + var_name + " = " + statement + def get_java_arr_len(self, arr_name): + return arr_name + ".length" + def get_java_arr_elem(self, elem_ty, arr_name, idx): + return arr_name + "[" + idx + "]" def constr_hu_array(self, ty_info, arr_len): base_ty = ty_info.subty.java_hu_ty.split("[")[0].split("<")[0] conv = "new " + base_ty + "[" + arr_len + "]" @@ -670,6 +674,16 @@ import javax.annotation.Nullable; conv += "[" + ty_info.subty.java_hu_ty.split("<")[0].split("[")[1] return conv + def primitive_arr_from_hu(self, mapped_ty, fixed_len, arr_name): + if fixed_len is not None: + return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "") + return None + def primitive_arr_to_hu(self, primitive_ty, fixed_len, arr_name, conv_name): + return None + + def java_arr_ty_str(self, elem_ty_str): + return elem_ty_str + "[]" + def for_n_in_range(self, n, minimum, maximum): return "for (int " + n + " = " + minimum + "; " + n + " < " + maximum + "; " + n + "++) {" def for_n_in_arr(self, n, arr_name, arr_elem_ty): diff --git a/typescript_strings.py b/typescript_strings.py index 892c63b8..e3b98083 100644 --- a/typescript_strings.py +++ b/typescript_strings.py @@ -17,10 +17,10 @@ class Consts: self.function_ptr_counter = 0 self.function_ptrs = {} self.c_type_map = dict( - uint8_t = ['number', 'Uint8Array'], - uint16_t = ['number', 'Uint16Array'], - uint32_t = ['number', 'Uint32Array'], - uint64_t = ['BigInt'], + uint8_t = ['number', 'number', 'Uint8Array'], + uint16_t = ['number', 'number', 'Uint16Array'], + uint32_t = ['number', 'number', 'Uint32Array'], + uint64_t = ['BigInt', 'BigInt', 'BigUint64Array'], ) self.java_type_map = dict( String = "number" @@ -29,14 +29,6 @@ class Consts: String = "string" ) - self.wasm_decoding_map = dict( - int8_tArray = 'decodeUint8Array' - ) - - self.wasm_encoding_map = dict( - int8_tArray = 'encodeUint8Array', - ) - self.to_hu_conv_templates = dict( ptr = 'const {var_name}_hu_conv: {human_type} = new {human_type}(null, {var_name});', default = 'const {var_name}_hu_conv: {human_type} = new {human_type}(null, {var_name});', @@ -91,11 +83,11 @@ export default class CommonBase { /* @internal */ public constructor(_dummy: object, ptr: number) { super(ptr, bindings.TxOut_free); - this.script_pubkey = bindings.TxOut_get_script_pubkey(ptr); + this.script_pubkey = bindings.decodeUint8Array(bindings.TxOut_get_script_pubkey(ptr)); this.value = bindings.TxOut_get_value(ptr); } public constructor_new(value: BigInt, script_pubkey: Uint8Array): TxOut { - return new TxOut(null, bindings.TxOut_new(script_pubkey, value)); + return new TxOut(null, bindings.TxOut_new(bindings.encodeUint8Array(script_pubkey), value)); } }""" self.obj_defined(["TxOut"], "structs") @@ -353,6 +345,7 @@ import * as InternalUtils from '../InternalUtils.mjs' return None def map_hu_array_elems(self, arr_name, conv_name, arr_ty, elem_ty): + assert elem_ty.c_ty == "uint32_t" or elem_ty.c_ty.endswith("Array") return arr_name + " != null ? " + arr_name + ".map(" + conv_name + " => " + elem_ty.from_hu_conv[0] + ") : null" def str_ref_to_native_call(self, var_name, str_len): @@ -458,7 +451,7 @@ const nextMultipleOfFour = (value: number) => { return Math.ceil(value / 4) * 4; } -const encodeUint8Array = (inputArray: Uint8Array) => { +export function encodeUint8Array (inputArray: Uint8Array): number { const cArrayPointer = wasm.TS_malloc(inputArray.length + 4); const arrayLengthView = new Uint32Array(wasm.memory.buffer, cArrayPointer, 1); arrayLengthView[0] = inputArray.length; @@ -466,32 +459,37 @@ const encodeUint8Array = (inputArray: Uint8Array) => { arrayMemoryView.set(inputArray); return cArrayPointer; } - -const encodeUint32Array = (inputArray: Uint32Array) => { +export function encodeUint32Array (inputArray: Uint32Array|Array): number { const cArrayPointer = wasm.TS_malloc((inputArray.length + 1) * 4); const arrayMemoryView = new Uint32Array(wasm.memory.buffer, cArrayPointer, inputArray.length); arrayMemoryView.set(inputArray, 1); arrayMemoryView[0] = inputArray.length; return cArrayPointer; } +export function encodeUint64Array (inputArray: BigUint64Array|Array): number { + const cArrayPointer = wasm.TS_malloc(inputArray.length * 8 + 1); + const arrayLengthView = new Uint32Array(wasm.memory.buffer, cArrayPointer, 1); + arrayLengthView[0] = inputArray.length; + const arrayMemoryView = new BigUint64Array(wasm.memory.buffer, cArrayPointer + 4, inputArray.length); + arrayMemoryView.set(inputArray); + return cArrayPointer; +} -const getArrayLength = (arrayPointer: number) => { - const arraySizeViewer = new Uint32Array( - wasm.memory.buffer, // value - arrayPointer, // offset - 1 // one int - ); +export function check_arr_len(arr: Uint8Array, len: number): Uint8Array { + if (arr.length != len) { throw new Error("Expected array of length " + len + "got " + arr.length); } + return arr; +} + +export function getArrayLength(arrayPointer: number): number { + const arraySizeViewer = new Uint32Array(wasm.memory.buffer, arrayPointer, 1); return arraySizeViewer[0]; } -const decodeUint8Array = (arrayPointer: number, free = true) => { +export function decodeUint8Array (arrayPointer: number, free = true): Uint8Array { const arraySize = getArrayLength(arrayPointer); - const actualArrayViewer = new Uint8Array( - wasm.memory.buffer, // value - arrayPointer + 4, // offset (ignoring length bytes) - arraySize // uint8 count - ); + const actualArrayViewer = new Uint8Array(wasm.memory.buffer, arrayPointer + 4, arraySize); // Clone the contents, TODO: In the future we should wrap the Viewer in a class that // will free the underlying memory when it becomes unreachable instead of copying here. + // Note that doing so may have edge-case interactions with memory resizing (invalidating the buffer). const actualArray = actualArrayViewer.slice(0, arraySize); if (free) { wasm.TS_free(arrayPointer); @@ -514,6 +512,11 @@ const decodeUint32Array = (arrayPointer: number, free = true) => { return actualArray; } +export function getU32ArrayElem(arrayPointer: number, idx: number): number { + const actualArrayViewer = new Uint32Array(wasm.memory.buffer, arrayPointer + 4, idx + 1); + return actualArrayViewer[idx]; +} + export function encodeString(str: string): number { const charArray = new TextEncoder().encode(str); return encodeUint8Array(charArray); @@ -535,12 +538,43 @@ export function decodeString(stringPointer: number, free = true): string { def init_str(self): return "" + def get_java_arr_len(self, arr_name): + return "bindings.getArrayLength(" + arr_name + ")" + def get_java_arr_elem(self, elem_ty, arr_name, idx): + if elem_ty.c_ty == "uint32_t" or elem_ty.c_ty == "uintptr_t" or elem_ty.c_ty.endswith("Array"): + return "bindings.getU32ArrayElem(" + arr_name + ", " + idx + ")" + else: + assert False def constr_hu_array(self, ty_info, arr_len): return "new Array(" + arr_len + ").fill(null)" + def primitive_arr_from_hu(self, mapped_ty, fixed_len, arr_name): + inner = arr_name + if fixed_len is not None: + assert mapped_ty.c_ty == "int8_t" + inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")" + if mapped_ty.c_ty.endswith("Array"): + return ("bindings.encodeUint32Array(" + inner + ")", "") + elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t": + return ("bindings.encodeUint8Array(" + inner + ")", "") + elif mapped_ty.c_ty == "uint32_t": + return ("bindings.encodeUint32Array(" + inner + ")", "") + elif mapped_ty.c_ty == "int64_t": + return ("bindings.encodeUint64Array(" + inner + ")", "") + else: + print(mapped_ty.c_ty) + assert False + + def primitive_arr_to_hu(self, mapped_ty, fixed_len, arr_name, conv_name): + assert mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t" + return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");" + def var_decl_statement(self, ty_string, var_name, statement): return "const " + var_name + ": " + ty_string + " = " + statement + def java_arr_ty_str(self, elem_ty_str): + return "number" + def for_n_in_range(self, n, minimum, maximum): return "for (var " + n + " = " + minimum + "; " + n + " < " + maximum + "; " + n + "++) {" def for_n_in_arr(self, n, arr_name, arr_elem_ty): @@ -1073,13 +1107,9 @@ export class {human_ty} extends CommonBase {{ def fn_call_body(self, method_name, return_c_ty, return_java_ty, method_argument_string, native_call_argument_string): has_return_value = return_c_ty != 'void' - needs_decoding = return_c_ty in self.wasm_decoding_map return_statement = 'return nativeResponseValue;' if not has_return_value: return_statement = '// debug statements here' - elif needs_decoding: - converter = self.wasm_decoding_map[return_c_ty] - return_statement = f"return {converter}(nativeResponseValue);" return f"""export function {method_name}({method_argument_string}): {return_java_ty} {{ if(!isWasmInitialized) {{ @@ -1112,13 +1142,8 @@ export class {human_ty} extends CommonBase {{ out_c += (", ") if arg_conv_info.c_ty != "void": out_c += (arg_conv_info.c_ty + " " + arg_conv_info.arg_name) - needs_encoding = arg_conv_info.c_ty in self.wasm_encoding_map - native_argument = arg_conv_info.arg_name - if needs_encoding: - converter = self.wasm_encoding_map[arg_conv_info.c_ty] - native_argument = f"{converter}({arg_conv_info.arg_name})" method_argument_string += f"{arg_conv_info.arg_name}: {arg_conv_info.java_ty}" - native_call_argument_string += native_argument + native_call_argument_string += arg_conv_info.arg_name out_java = self.fn_call_body(method_name, return_type_info.c_ty, return_type_info.java_ty, method_argument_string, native_call_argument_string) out_java_struct = "" -- 2.39.5