Support mapping primitive arrays as non-arrays (eg numbers)
[ldk-java] / typescript_strings.py
index 892c63b802488a20b73b7101fa348e11b58afaee..e3b980830f6d0bce2172be184ef7d248d1abd351 100644 (file)
@@ -17,10 +17,10 @@ class Consts:
         self.function_ptr_counter = 0
         self.function_ptrs = {}
         self.c_type_map = dict(
-            uint8_t = ['number', 'Uint8Array'],
-            uint16_t = ['number', 'Uint16Array'],
-            uint32_t = ['number', 'Uint32Array'],
-            uint64_t = ['BigInt'],
+            uint8_t = ['number', 'number', 'Uint8Array'],
+            uint16_t = ['number', 'number', 'Uint16Array'],
+            uint32_t = ['number', 'number', 'Uint32Array'],
+            uint64_t = ['BigInt', 'BigInt', 'BigUint64Array'],
         )
         self.java_type_map = dict(
             String = "number"
@@ -29,14 +29,6 @@ class Consts:
             String = "string"
         )
 
-        self.wasm_decoding_map = dict(
-            int8_tArray = 'decodeUint8Array'
-        )
-
-        self.wasm_encoding_map = dict(
-            int8_tArray = 'encodeUint8Array',
-        )
-
         self.to_hu_conv_templates = dict(
             ptr = 'const {var_name}_hu_conv: {human_type} = new {human_type}(null, {var_name});',
             default = 'const {var_name}_hu_conv: {human_type} = new {human_type}(null, {var_name});',
@@ -91,11 +83,11 @@ export default class CommonBase {
        /* @internal */
        public constructor(_dummy: object, ptr: number) {
                super(ptr, bindings.TxOut_free);
-               this.script_pubkey = bindings.TxOut_get_script_pubkey(ptr);
+               this.script_pubkey = bindings.decodeUint8Array(bindings.TxOut_get_script_pubkey(ptr));
                this.value = bindings.TxOut_get_value(ptr);
        }
        public constructor_new(value: BigInt, script_pubkey: Uint8Array): TxOut {
-               return new TxOut(null, bindings.TxOut_new(script_pubkey, value));
+               return new TxOut(null, bindings.TxOut_new(bindings.encodeUint8Array(script_pubkey), value));
        }
 }"""
         self.obj_defined(["TxOut"], "structs")
@@ -353,6 +345,7 @@ import * as InternalUtils from '../InternalUtils.mjs'
             return None
 
     def map_hu_array_elems(self, arr_name, conv_name, arr_ty, elem_ty):
+        assert elem_ty.c_ty == "uint32_t" or elem_ty.c_ty.endswith("Array")
         return arr_name + " != null ? " + arr_name + ".map(" + conv_name + " => " + elem_ty.from_hu_conv[0] + ") : null"
 
     def str_ref_to_native_call(self, var_name, str_len):
@@ -458,7 +451,7 @@ const nextMultipleOfFour = (value: number) => {
        return Math.ceil(value / 4) * 4;
 }
 
-const encodeUint8Array = (inputArray: Uint8Array) => {
+export function encodeUint8Array (inputArray: Uint8Array): number {
        const cArrayPointer = wasm.TS_malloc(inputArray.length + 4);
        const arrayLengthView = new Uint32Array(wasm.memory.buffer, cArrayPointer, 1);
        arrayLengthView[0] = inputArray.length;
@@ -466,32 +459,37 @@ const encodeUint8Array = (inputArray: Uint8Array) => {
        arrayMemoryView.set(inputArray);
        return cArrayPointer;
 }
-
-const encodeUint32Array = (inputArray: Uint32Array) => {
+export function encodeUint32Array (inputArray: Uint32Array|Array<number>): number {
        const cArrayPointer = wasm.TS_malloc((inputArray.length + 1) * 4);
        const arrayMemoryView = new Uint32Array(wasm.memory.buffer, cArrayPointer, inputArray.length);
        arrayMemoryView.set(inputArray, 1);
        arrayMemoryView[0] = inputArray.length;
        return cArrayPointer;
 }
+export function encodeUint64Array (inputArray: BigUint64Array|Array<BigInt>): number {
+       const cArrayPointer = wasm.TS_malloc(inputArray.length * 8 + 1);
+       const arrayLengthView = new Uint32Array(wasm.memory.buffer, cArrayPointer, 1);
+       arrayLengthView[0] = inputArray.length;
+       const arrayMemoryView = new BigUint64Array(wasm.memory.buffer, cArrayPointer + 4, inputArray.length);
+       arrayMemoryView.set(inputArray);
+       return cArrayPointer;
+}
 
-const getArrayLength = (arrayPointer: number) => {
-       const arraySizeViewer = new Uint32Array(
-               wasm.memory.buffer, // value
-               arrayPointer, // offset
-               1 // one int
-       );
+export function check_arr_len(arr: Uint8Array, len: number): Uint8Array {
+       if (arr.length != len) { throw new Error("Expected array of length " + len + "got " + arr.length); }
+       return arr;
+}
+
+export function getArrayLength(arrayPointer: number): number {
+       const arraySizeViewer = new Uint32Array(wasm.memory.buffer, arrayPointer, 1);
        return arraySizeViewer[0];
 }
-const decodeUint8Array = (arrayPointer: number, free = true) => {
+export function decodeUint8Array (arrayPointer: number, free = true): Uint8Array {
        const arraySize = getArrayLength(arrayPointer);
-       const actualArrayViewer = new Uint8Array(
-               wasm.memory.buffer, // value
-               arrayPointer + 4, // offset (ignoring length bytes)
-               arraySize // uint8 count
-       );
+       const actualArrayViewer = new Uint8Array(wasm.memory.buffer, arrayPointer + 4, arraySize);
        // Clone the contents, TODO: In the future we should wrap the Viewer in a class that
        // will free the underlying memory when it becomes unreachable instead of copying here.
+       // Note that doing so may have edge-case interactions with memory resizing (invalidating the buffer).
        const actualArray = actualArrayViewer.slice(0, arraySize);
        if (free) {
                wasm.TS_free(arrayPointer);
@@ -514,6 +512,11 @@ const decodeUint32Array = (arrayPointer: number, free = true) => {
        return actualArray;
 }
 
+export function getU32ArrayElem(arrayPointer: number, idx: number): number {
+       const actualArrayViewer = new Uint32Array(wasm.memory.buffer, arrayPointer + 4, idx + 1);
+       return actualArrayViewer[idx];
+}
+
 export function encodeString(str: string): number {
        const charArray = new TextEncoder().encode(str);
        return encodeUint8Array(charArray);
@@ -535,12 +538,43 @@ export function decodeString(stringPointer: number, free = true): string {
     def init_str(self):
         return ""
 
+    def get_java_arr_len(self, arr_name):
+        return "bindings.getArrayLength(" + arr_name + ")"
+    def get_java_arr_elem(self, elem_ty, arr_name, idx):
+        if elem_ty.c_ty == "uint32_t" or elem_ty.c_ty == "uintptr_t" or elem_ty.c_ty.endswith("Array"):
+            return "bindings.getU32ArrayElem(" + arr_name + ", " + idx + ")"
+        else:
+            assert False
     def constr_hu_array(self, ty_info, arr_len):
         return "new Array(" + arr_len + ").fill(null)"
 
+    def primitive_arr_from_hu(self, mapped_ty, fixed_len, arr_name):
+        inner = arr_name
+        if fixed_len is not None:
+            assert mapped_ty.c_ty == "int8_t"
+            inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+        if mapped_ty.c_ty.endswith("Array"):
+            return ("bindings.encodeUint32Array(" + inner + ")", "")
+        elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
+            return ("bindings.encodeUint8Array(" + inner + ")", "")
+        elif mapped_ty.c_ty == "uint32_t":
+            return ("bindings.encodeUint32Array(" + inner + ")", "")
+        elif mapped_ty.c_ty == "int64_t":
+            return ("bindings.encodeUint64Array(" + inner + ")", "")
+        else:
+            print(mapped_ty.c_ty)
+            assert False
+
+    def primitive_arr_to_hu(self, mapped_ty, fixed_len, arr_name, conv_name):
+        assert mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t"
+        return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
+
     def var_decl_statement(self, ty_string, var_name, statement):
         return "const " + var_name + ": " + ty_string + " = " + statement
 
+    def java_arr_ty_str(self, elem_ty_str):
+        return "number"
+
     def for_n_in_range(self, n, minimum, maximum):
         return "for (var " + n + " = " + minimum + "; " + n + " < " + maximum + "; " + n + "++) {"
     def for_n_in_arr(self, n, arr_name, arr_elem_ty):
@@ -1073,13 +1107,9 @@ export class {human_ty} extends CommonBase {{
 
     def fn_call_body(self, method_name, return_c_ty, return_java_ty, method_argument_string, native_call_argument_string):
         has_return_value = return_c_ty != 'void'
-        needs_decoding = return_c_ty in self.wasm_decoding_map
         return_statement = 'return nativeResponseValue;'
         if not has_return_value:
             return_statement = '// debug statements here'
-        elif needs_decoding:
-            converter = self.wasm_decoding_map[return_c_ty]
-            return_statement = f"return {converter}(nativeResponseValue);"
 
         return f"""export function {method_name}({method_argument_string}): {return_java_ty} {{
        if(!isWasmInitialized) {{
@@ -1112,13 +1142,8 @@ export class {human_ty} extends CommonBase {{
                 out_c += (", ")
             if arg_conv_info.c_ty != "void":
                 out_c += (arg_conv_info.c_ty + " " + arg_conv_info.arg_name)
-                needs_encoding = arg_conv_info.c_ty in self.wasm_encoding_map
-                native_argument = arg_conv_info.arg_name
-                if needs_encoding:
-                    converter = self.wasm_encoding_map[arg_conv_info.c_ty]
-                    native_argument = f"{converter}({arg_conv_info.arg_name})"
                 method_argument_string += f"{arg_conv_info.arg_name}: {arg_conv_info.java_ty}"
-                native_call_argument_string += native_argument
+                native_call_argument_string += arg_conv_info.arg_name
         out_java = self.fn_call_body(method_name, return_type_info.c_ty, return_type_info.java_ty, method_argument_string, native_call_argument_string)
 
         out_java_struct = ""