[TS] Call FinalizationRegistry.register with a unregister token
[ldk-java] / typescript_strings.py
index 76e43828913cc71225f65eebfab8d336a5c1a574..fc2b7dde1261bb8cd08f882ea94859cd5d6f3599 100644 (file)
@@ -36,7 +36,7 @@ class Consts:
 
         self.bindings_header = """
 import * as version from './version.mjs';
-import { UInt5 } from './structs/CommonBase.mjs';
+import { UInt5, WitnessVersion } from './structs/CommonBase.mjs';
 
 const imports: any = {};
 imports.env = {};
@@ -46,14 +46,17 @@ var js_invoke: Function;
 var getRandomValues: Function;
 
 imports.wasi_snapshot_preview1 = {
-       "fd_write": (fd: number, iovec_array_ptr: number, iovec_array_len: number) => {
+       "fd_write": (fd: number, iovec_array_ptr: number, iovec_array_len: number, bytes_written_ptr: number) => {
                // This should generally only be used to print panic messages
-               console.log("FD_WRITE to " + fd + " in " + iovec_array_len + " chunks.");
                const ptr_len_view = new Uint32Array(wasm.memory.buffer, iovec_array_ptr, iovec_array_len * 2);
+               var bytes_written = 0;
                for (var i = 0; i < iovec_array_len; i++) {
                        const bytes_view = new Uint8Array(wasm.memory.buffer, ptr_len_view[i*2], ptr_len_view[i*2+1]);
-                       console.log(String.fromCharCode(...bytes_view));
+                       console.log("[fd " + fd + "]: " + String.fromCharCode(...bytes_view));
+                       bytes_written += ptr_len_view[i*2+1];
                }
+               const written_view = new Uint32Array(wasm.memory.buffer, bytes_written_ptr, 1);
+               written_view[0] = bytes_written;
                return 0;
        },
        "fd_close": (_fd: number) => {
@@ -73,7 +76,6 @@ imports.wasi_snapshot_preview1 = {
        },
        "environ_sizes_get": (environ_var_count_ptr: number, environ_len_ptr: number) => {
                // This is called before fd_write to format + print panic messages
-               console.log("wasi_snapshot_preview1:environ_sizes_get");
                const out_count_view = new Uint32Array(wasm.memory.buffer, environ_var_count_ptr, 1);
                out_count_view[0] = 0;
                const out_len_view = new Uint32Array(wasm.memory.buffer, environ_len_ptr, 1);
@@ -81,7 +83,8 @@ imports.wasi_snapshot_preview1 = {
                return 0;
        },
        "environ_get": (environ_ptr: number, environ_buf_ptr: number) => {
-               // This is called before fd_write to format + print panic messages
+               // This is called before fd_write to format + print panic messages,
+               // but only if we have variables in environ_sizes_get, so shouldn't ever actually happen!
                console.log("wasi_snapshot_preview1:environ_get");
                return 58; // Note supported - we said there were 0 environment entries!
        },
@@ -107,7 +110,7 @@ async function finishInitializeWasm(wasmInstance: WebAssembly.Instance) {
        }
 
        if (decodeString(wasm.TS_get_lib_version_string()) !== version.get_ldk_java_bindings_version())
-               throw new Error(\"Compiled LDK library and LDK class failes do not match\");
+               throw new Error(\"Compiled LDK library and LDK class files do not match\");
        // Fetching the LDK versions from C also checks that the header and binaries match
        const c_bindings_ver: number = wasm.TS_get_ldk_c_bindings_version();
        const ldk_ver: number = wasm.TS_get_ldk_version();
@@ -124,7 +127,8 @@ async function finishInitializeWasm(wasmInstance: WebAssembly.Instance) {
 
 /* @internal */
 export async function initializeWasmFromUint8Array(wasmBinary: Uint8Array) {
-       imports.env["js_invoke_function"] = js_invoke;
+       imports.env["js_invoke_function_u"] = js_invoke;
+       imports.env["js_invoke_function_b"] = js_invoke;
        const { instance: wasmInstance } = await WebAssembly.instantiate(wasmBinary, imports);
        await finishInitializeWasm(wasmInstance);
 }
@@ -132,7 +136,8 @@ export async function initializeWasmFromUint8Array(wasmBinary: Uint8Array) {
 /* @internal */
 export async function initializeWasmFetch(uri: string) {
        const stream = fetch(uri);
-       imports.env["js_invoke_function"] = js_invoke;
+       imports.env["js_invoke_function_u"] = js_invoke;
+       imports.env["js_invoke_function_b"] = js_invoke;
        const { instance: wasmInstance } = await WebAssembly.instantiateStreaming(stream, imports);
        await finishInitializeWasm(wasmInstance);
 }"""
@@ -149,6 +154,17 @@ export function uint5ArrToBytes(inputArray: Array<UInt5>): Uint8Array {
        return arr;
 }
 
+/* @internal */
+export function WitnessVersionArrToBytes(inputArray: Array<WitnessVersion>): Uint8Array {
+       const arr = new Uint8Array(inputArray.length);
+       for (var i = 0; i < inputArray.length; i++) {
+               arr[i] = inputArray[i].getVal();
+       }
+       return arr;
+}
+
+
+
 /* @internal */
 export function encodeUint8Array (inputArray: Uint8Array): number {
        const cArrayPointer = wasm.TS_malloc(inputArray.length + 4);
@@ -178,7 +194,7 @@ export function encodeUint64Array (inputArray: BigUint64Array|Array<bigint>): nu
 
 /* @internal */
 export function check_arr_len(arr: Uint8Array, len: number): Uint8Array {
-       if (arr.length != len) { throw new Error("Expected array of length " + len + "got " + arr.length); }
+       if (arr.length != len) { throw new Error("Expected array of length " + len + " got " + arr.length); }
        return arr;
 }
 
@@ -215,7 +231,23 @@ const decodeUint32Array = (arrayPointer: number, free = true) => {
        }
        return actualArray;
 }
-
+/* @internal */
+export function decodeUint64Array (arrayPointer: number, free = true): bigint[] {
+       const arraySize = getArrayLength(arrayPointer);
+       const actualArrayViewer = new BigUint64Array(
+               wasm.memory.buffer, // value
+               arrayPointer + 4, // offset (ignoring length bytes)
+               arraySize // uint32 count
+       );
+       // Clone the contents, TODO: In the future we should wrap the Viewer in a class that
+       // will free the underlying memory when it becomes unreachable instead of copying here.
+       const actualArray = new Array(arraySize);
+       for (var i = 0; i < arraySize; i++) actualArray[i] = actualArrayViewer[i];
+       if (free) {
+               wasm.TS_free(arrayPointer);
+       }
+       return actualArray;
+}
 
 export function freeWasmMemory(pointer: number) { wasm.TS_free(pointer); }
 
@@ -300,7 +332,7 @@ export class CommonBase {
        protected constructor(ptr: number, free_fn: (ptr: number) => void) {
                this.ptr = ptr;
                if (Number.isFinite(ptr) && ptr != 0){
-                       finalizer.register(this, get_freeer(ptr, free_fn));
+                       finalizer.register(this, get_freeer(ptr, free_fn), this);
                }
        }
        // In Java, protected means "any subclass can access fields on any other subclass'"
@@ -314,7 +346,10 @@ export class CommonBase {
        }
        protected static set_null_skip_free(o: CommonBase) {
                o.ptr = 0;
-               finalizer.unregister(o);
+               // @ts-ignore TypeScript is wrong about the returnvalue of unregister here!
+               const did_unregister: boolean = finalizer.unregister(o);
+               if (!did_unregister)
+                       throw new Error("FinalizationRegistry unregister should always unregister unless you double-free'd");
        }
 }
 
@@ -326,6 +361,19 @@ export class UInt5 {
                return this.val;
        }
 }
+
+export class WitnessVersion {
+       public constructor(private val: number) {
+               if (val > 16 || val < 0) throw new Error("WitnessVersion value is out of range");
+       }
+       public getVal(): number {
+               return this.val;
+       }
+}
+
+export class UnqualifiedError {
+       public constructor(val: number) {}
+}
 """
 
         self.txout_defn = """export class TxOut extends CommonBase {
@@ -592,7 +640,7 @@ jstring __attribute__((export_name("TS_get_ldk_version"))) get_ldk_version() {
 }"""
 
         self.hu_struct_file_prefix = """
-import { CommonBase, UInt5 } from './CommonBase.mjs';
+import { CommonBase, UInt5, WitnessVersion, UnqualifiedError } from './CommonBase.mjs';
 import * as bindings from '../bindings.mjs'
 
 """
@@ -622,11 +670,11 @@ import * as bindings from '../bindings.mjs'
         if ty_info.c_ty == "int8_tArray":
             if copy:
                 return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + arr_len + "); FREE(" + arr_name + ")"
+        assert not copy
         if ty_info.c_ty == "ptrArray":
-            return "(void*) " + arr_name + "->elems /* XXX " + arr_name + " leaks */"
+            return "(void*) " + arr_name + "->elems"
         else:
-            assert not copy
-            return arr_name + "->elems /* XXX " + arr_name + " leaks */"
+            return arr_name + "->elems"
     def get_native_arr_elem(self, arr_name, idxc, ty_info):
         assert False # Only called if above is None
     def get_native_arr_ptr_call(self, ty_info):
@@ -637,9 +685,9 @@ import * as bindings from '../bindings.mjs'
         return None
     def cleanup_native_arr_ref_contents(self, arr_name, dest_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
-            return None
+            return "FREE(" + arr_name + ");"
         else:
-            return None
+            return "FREE(" + arr_name + ")"
 
     def map_hu_array_elems(self, arr_name, conv_name, arr_ty, elem_ty):
         if elem_ty.rust_obj == "LDKu5":
@@ -694,8 +742,12 @@ import * as bindings from '../bindings.mjs'
             assert False
 
     def primitive_arr_to_hu(self, mapped_ty, fixed_len, arr_name, conv_name):
-        assert mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t"
-        return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
+        if mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
+            return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
+        elif mapped_ty.c_ty == "uint64_t" or mapped_ty.c_ty == "int64_t":
+            return "const " + conv_name + ": bigint[] = bindings.decodeUint64Array(" + arr_name + ");"
+        else:
+            assert False
 
     def var_decl_statement(self, ty_string, var_name, statement):
         return "const " + var_name + ": " + ty_string + " = " + statement
@@ -753,8 +805,12 @@ import * as bindings from '../bindings.mjs'
         out_c = out_c + "\t}\n"
         out_c = out_c + "}\n"
 
+        # Note that this is *not* marked /* @internal */ as we re-expose it directly in enums/
+        enum_comment_formatted = enum_doc_comment.replace("\n", "\n * ")
         out_typescript = f"""
-/* @internal */
+/**
+ * {enum_comment_formatted}
+ */
 export enum {struct_name} {{
        {out_typescript_enum_fields}
 }}
@@ -977,17 +1033,27 @@ export class {struct_name.replace("LDK","")} extends CommonBase {{
                         out_c = out_c + arg_info.arg_name
                         out_c = out_c + arg_info.ret_conv[1].replace('\n', '\n\t') + "\n"
 
+                fn_suffix = ""
+                if fn_line.ret_ty_info.c_ty == "uint64_t" or fn_line.ret_ty_info.c_ty == "int64_t":
+                    fn_suffix += "b_"
+                else:
+                    fn_suffix += "u_"
+                for arg in fn_line.args_ty:
+                    if arg_info.c_ty == "uint64_t" or arg_info.c_ty == "int64_t":
+                        fn_suffix += "b"
+                    else:
+                        fn_suffix += "u"
                 if fn_line.ret_ty_info.c_ty.endswith("Array"):
                     out_c += "\t" + fn_line.ret_ty_info.c_ty + " ret = (" + fn_line.ret_ty_info.c_ty + ")"
-                    out_c += "js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
+                    out_c += "js_invoke_function_" + fn_suffix + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
                 elif fn_line.ret_ty_info.java_ty == "void":
-                    out_c = out_c + "\tjs_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
+                    out_c = out_c + "\tjs_invoke_function_" + fn_suffix + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
                 elif fn_line.ret_ty_info.java_hu_ty == "string":
-                    out_c = out_c + "\tjstring ret = (jstring)js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
+                    out_c += "\tjstring ret = (jstring)js_invoke_function_" + fn_suffix + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
                 elif not fn_line.ret_ty_info.passed_as_ptr:
-                    out_c = out_c + "\treturn js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
+                    out_c += "\treturn js_invoke_function_" + fn_suffix + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
                 else:
-                    out_c = out_c + "\tuint32_t ret = js_invoke_function_" + str(len(fn_line.args_ty)) + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
+                    out_c += "\tuint32_t ret = js_invoke_function_" + fn_suffix + "(j_calls->instance_ptr, " + str(self.function_ptr_counter)
 
                 self.function_ptrs[self.function_ptr_counter] = (struct_name, fn_line.fn_name)
                 self.function_ptr_counter += 1
@@ -1085,7 +1151,10 @@ export class {struct_name.replace("LDK","")} extends CommonBase {{
         return (out_typescript_bindings, out_typescript_human, out_c)
 
     def trait_struct_inc_refcnt(self, ty_info):
-        return ""
+        base_conv = "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n"
+        base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n"
+        base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_cloned(&" + ty_info.var_name + "_conv);\n}"
+        return base_conv
 
     def map_complex_enum(self, struct_name, variant_list, camel_to_snake, enum_doc_comment):
         bindings_type = struct_name.replace("LDK", "")
@@ -1172,18 +1241,28 @@ export class {struct_name.replace("LDK","")} extends CommonBase {{
 
         hu_name = struct_name.replace("LDKC2Tuple", "TwoTuple").replace("LDKC3Tuple", "ThreeTuple").replace("LDK", "")
         out_opaque_struct_human = f"{self.hu_struct_file_prefix}"
-        if struct_name.startswith("LDKLocked"):
-            out_opaque_struct_human += "/** XXX: DO NOT USE THIS - it remains locked until the GC runs (if that ever happens */"
+        constructor_body = "super(ptr, bindings." + struct_name.replace("LDK","") + "_free);"
+        extra_docs = ""
+        extra_body = ""
+        if struct_name.startswith("LDKLocked") or struct_name.startswith("LDKReadOnly"):
+            extra_docs = "\n * This type represents a lock and MUST BE MANUALLY FREE'd!"
+            constructor_body = 'super(ptr, () => { throw new Error("Locks must be manually freed with free()"); });'
+            extra_body = f"""
+       /** Releases this lock */
+       public free() {{
+               bindings.{struct_name.replace("LDK","")}_free(this.ptr);
+               CommonBase.set_null_skip_free(this);
+       }}"""
         formatted_doc_comment = struct_doc_comment.replace("\n", "\n * ")
         out_opaque_struct_human += f"""
-/**
+/**{extra_docs}
  * {formatted_doc_comment}
  */
 export class {hu_name} extends CommonBase {implementations}{{
        /* @internal */
        public constructor(_dummy: object, ptr: number) {{
-               super(ptr, bindings.{struct_name.replace("LDK","")}_free);
-       }}
+               {constructor_body}
+       }}{extra_body}
 
 """
         self.obj_defined([hu_name], "structs")
@@ -1423,7 +1502,7 @@ export function {method_name}({method_argument_string}): {return_java_ty} {{
         with open(self.outdir + "/bindings.mts", "a") as bindings:
             bindings.write("""
 
-js_invoke = function(obj_ptr: number, fn_id: number, arg1: number, arg2: number, arg3: number, arg4: number, arg5: number, arg6: number, arg7: number, arg8: number, arg9: number, arg10: number) {
+js_invoke = function(obj_ptr: number, fn_id: number, arg1: bigint|number, arg2: bigint|number, arg3: bigint|number, arg4: bigint|number, arg5: bigint|number, arg6: bigint|number, arg7: bigint|number, arg8: bigint|number, arg9: bigint|number, arg10: bigint|number) {
        const weak: WeakRef<object> = js_objs[obj_ptr];
        if (weak == null || weak == undefined) {
                console.error("Got function call on unknown/free'd JS object!");
@@ -1448,5 +1527,7 @@ js_invoke = function(obj_ptr: number, fn_id: number, arg1: number, arg2: number,
                console.error("Got function call on incorrect JS object!");
                throw new Error("Got function call on incorrect JS object!");
        }
-       return fn.value.bind(obj)(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);
+       const ret = fn.value.bind(obj)(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);
+       if (ret === undefined || ret === null) return BigInt(0);
+       return BigInt(ret);
 }""")