Stop masking the owned bit on a freshly-cloned object
[ldk-java] / genbindings.py
index d96447bd37eb36a5a47692c26ced8894389f2166..56d36da4b653c2222155052bd072810a0cd417e6 100755 (executable)
@@ -115,13 +115,14 @@ void* __wrap_calloc(size_t nmemb, size_t len) {
        return res;
 }
 void __wrap_free(void* ptr) {
+       if (ptr == NULL) return;
        alloc_freed(ptr);
        __real_free(ptr);
 }
 
 void* __real_realloc(void* ptr, size_t newlen);
 void* __wrap_realloc(void* ptr, size_t len) {
-       alloc_freed(ptr);
+       if (ptr != NULL) alloc_freed(ptr);
        void* res = __real_realloc(ptr, len);
        new_allocation(res, "realloc call");
        return res;
@@ -161,6 +162,7 @@ class TypeInfo:
         self.arr_access = arr_access
         self.subty = subty
         self.pass_by_ref = is_ptr
+        self.requires_clone = None
 
 class ConvInfo:
     def __init__(self, ty_info, arg_name, arg_conv, arg_conv_name, arg_conv_cleanup, ret_conv, ret_conv_name, to_hu_conv, to_hu_conv_name, from_hu_conv):
@@ -303,6 +305,11 @@ def java_c_types(fn_arg, ret_arr_len):
             rust_obj = "LDKCVec_u8Z"
             assert var_is_arr_regex.match(fn_arg[8:])
         arr_access = "data"
+    elif fn_arg.startswith("LDKTransaction"):
+        fn_arg = "uint8_t (*" + fn_arg[15:] + ")[datalen]"
+        rust_obj = "LDKTransaction"
+        assert var_is_arr_regex.match(fn_arg[8:])
+        arr_access = "data"
     elif fn_arg.startswith("LDKCVecTempl_") or fn_arg.startswith("LDKCVec_"):
         is_ptr = False
         if "*" in fn_arg:
@@ -518,14 +525,20 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 arg_conv_cleanup = None
                 if not arr_len.isdigit():
                     arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
-                    arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = (*_env)->GetByteArrayElements (_env, " + arr_name + ", NULL);\n"
-                    arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = (*_env)->GetArrayLength (_env, " + arr_name + ");"
-                    arg_conv_cleanup = "(*_env)->ReleaseByteArrayElements(_env, " + arr_name + ", (int8_t*)" + arr_name + "_ref." + ty_info.arr_access + ", 0);"
+                    arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = (*_env)->GetArrayLength (_env, " + arr_name + ");\n"
+                    if (not ty_info.is_ptr or not holds_ref) and ty_info.rust_obj != "LDKu8slice":
+                        arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = MALLOC(" + arr_name + "_ref." + arr_len + ", \"" + ty_info.rust_obj + " Bytes\");\n"
+                        arg_conv = arg_conv + "(*_env)->GetByteArrayRegion(_env, " + arr_name + ", 0, " + arr_name + "_ref." + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
+                    else:
+                        arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = (*_env)->GetByteArrayElements (_env, " + arr_name + ", NULL);"
+                        arg_conv_cleanup = "(*_env)->ReleaseByteArrayElements(_env, " + arr_name + ", (int8_t*)" + arr_name + "_ref." + ty_info.arr_access + ", 0);"
+                    if ty_info.rust_obj == "LDKTransaction":
+                        arg_conv = arg_conv + "\n" + arr_name + "_ref.data_is_owned = " + str(holds_ref).lower() + ";"
                     ret_conv = (ty_info.rust_obj + " " + arr_name + "_var = ", "")
                     ret_conv = (ret_conv[0], ";\njbyteArray " + arr_name + "_arr = (*_env)->NewByteArray(_env, " + arr_name + "_var." + arr_len + ");\n")
                     ret_conv = (ret_conv[0], ret_conv[1] + "(*_env)->SetByteArrayRegion(_env, " + arr_name + "_arr, 0, " + arr_name + "_var." + arr_len + ", " + arr_name + "_var." + ty_info.arr_access + ");")
-                    if not holds_ref and ty_info.rust_obj == "LDKCVec_u8Z":
-                        ret_conv = (ret_conv[0], ret_conv[1] + "\nCVec_u8Z_free(" + arr_name + "_var);")
+                    if not holds_ref and ty_info.rust_obj != "LDKu8slice":
+                        ret_conv = (ret_conv[0], ret_conv[1] + "\n" + ty_info.rust_obj.replace("LDK", "") + "_free(" + arr_name + "_var);")
                 elif ty_info.rust_obj is not None:
                     arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
                     arg_conv = arg_conv + "CHECK((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
@@ -545,7 +558,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 conv_name = "arr_conv_" + str(len(ty_info.java_hu_ty))
                 idxc = chr(ord('a') + (len(ty_info.java_hu_ty) % 26))
                 ty_info.subty.var_name = conv_name
-                ty_info.subty.passed_as_ptr = False
+                ty_info.subty.requires_clone = not ty_info.is_ptr or not holds_ref
                 subty = map_type_with_info(ty_info.subty, False, None, is_free, holds_ref)
                 if arr_name == "":
                     arr_name = "arg"
@@ -670,11 +683,11 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             if ty_info.rust_obj in opaque_structs:
                 opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n"
                 opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = (void*)(" + ty_info.var_name + " & (~1));\n"
-                if holds_ref:
+                if ty_info.is_ptr and holds_ref:
                     opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = false;"
                 else:
                     opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = (" + ty_info.var_name + " & 1) || (" + ty_info.var_name + " == 0);"
-                if not ty_info.is_ptr and not is_free and not ty_info.pass_by_ref and not holds_ref:
+                if not is_free and (not ty_info.is_ptr or not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False:
                     if (ty_info.java_hu_ty + "_clone") in clone_fns:
                         # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
                         opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
@@ -691,7 +704,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
 
                 opaque_ret_conv_suf = opaque_ret_conv_suf + "CHECK((((long)" + ty_info.var_name + "_var.inner) & 1) == 0); // We rely on a free low bit, malloc guarantees this.\n"
                 opaque_ret_conv_suf = opaque_ret_conv_suf + "CHECK((((long)&" + ty_info.var_name + "_var) & 1) == 0); // We rely on a free low bit, pointer alignment guarantees this.\n"
-                if holds_ref or ty_info.is_ptr:
+                if holds_ref:
                     opaque_ret_conv_suf = opaque_ret_conv_suf + "long " + ty_info.var_name + "_ref = (long)" + ty_info.var_name + "_var.inner & ~1;"
                 else:
                     opaque_ret_conv_suf = opaque_ret_conv_suf + "long " + ty_info.var_name + "_ref = (long)" + ty_info.var_name + "_var.inner;\n"
@@ -739,7 +752,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, ret);\nret_hu_conv.ptrs_to.add(this);",
                         to_hu_conv_name = "ret_hu_conv",
                         from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr", "this.ptrs_to.add(" + ty_info.var_name + ")"))
-                if ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKTransaction":
+                if ty_info.rust_obj != "LDKu8slice":
                     # Don't bother free'ing slices passed in - Rust doesn't auto-free the
                     # underlying unlike Vecs, and it gives Java more freedom.
                     base_conv = base_conv + "\nFREE((void*)" + ty_info.var_name + ");";
@@ -747,8 +760,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";")
                     if not holds_ref:
                         ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_copy = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n", "")
-                        if not ty_info.passed_as_ptr:
-                            # We use passed_as_ptr as a flag to detect if we're copying a Vec.
+                        if ty_info.requires_clone == True: # Set in object array mapping
                             if (ty_info.java_hu_ty + "_clone") in clone_fns:
                                 ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = " + ty_info.java_hu_ty + "_clone(&", ");\n")
                             else:
@@ -805,20 +817,12 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         to_hu_conv = to_hu_conv_pfx + to_hu_conv_sfx + ");", to_hu_conv_name = ty_info.var_name + "_conv", from_hu_conv = (from_hu_conv + ")", ""))
 
                 # The manually-defined types - TxOut and Transaction
-                assert ty_info.rust_obj == "LDKTransaction" or ty_info.rust_obj == "LDKTxOut"
-                if ty_info.rust_obj == "LDKTransaction":
-                    return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                        arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
-                        ret_conv = ("LDKTransaction *" + ty_info.var_name + "_copy = MALLOC(sizeof(LDKTransaction), \"LDKTransaction\");\n*" + ty_info.var_name + "_copy = ", ";\nlong " + ty_info.var_name + "_ref = (long)" + ty_info.var_name + "_copy;"),
-                        ret_conv_name = ty_info.var_name + "_ref",
-                        to_hu_conv = ty_info.java_hu_ty + " " + ty_info.var_name + "_conv = new " +ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");",
-                        to_hu_conv_name = ty_info.var_name + "_conv", from_hu_conv = (ty_info.var_name + ".ptr", ""))
-                elif ty_info.rust_obj == "LDKTxOut":
-                    return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                        arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
-                        ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";"), ret_conv_name = ty_info.var_name + "_ref",
-                        to_hu_conv = ty_info.java_hu_ty + " " + ty_info.var_name + "_conv = new " +ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");",
-                        to_hu_conv_name = ty_info.var_name + "_conv", from_hu_conv = (ty_info.var_name + ".ptr", ""))
+                assert ty_info.rust_obj == "LDKTxOut"
+                return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
+                    arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
+                    ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";"), ret_conv_name = ty_info.var_name + "_ref",
+                    to_hu_conv = ty_info.java_hu_ty + " " + ty_info.var_name + "_conv = new " +ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");",
+                    to_hu_conv_name = ty_info.var_name + "_conv", from_hu_conv = (ty_info.var_name + ".ptr", ""))
             elif ty_info.is_ptr:
                 assert(not is_free)
                 if ty_info.rust_obj in complex_enums:
@@ -870,7 +874,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 out_java.write(", ")
             if arg != "void":
                 write_c(", ")
-            arg_conv_info = map_type(arg, False, None, is_free, False)
+            arg_conv_info = map_type(arg, False, None, is_free, True)
             if arg_conv_info.c_ty != "void":
                 arg_conv_info.print_ty()
                 arg_conv_info.print_name()
@@ -880,7 +884,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 if arg_conv_info.rust_obj in constructor_fns:
                     assert not is_free
                     for explode_arg in constructor_fns[arg_conv_info.rust_obj].split(','):
-                        explode_arg_conv = map_type(explode_arg, False, None, False, False)
+                        explode_arg_conv = map_type(explode_arg, False, None, False, True)
                         if explode_arg_conv.c_ty == "void":
                             # We actually want to handle this case, but for now its only used in NetGraphMsgHandler::new()
                             # which ends up resulting in a redundant constructor - both without arguments for the NetworkGraph.
@@ -1792,16 +1796,6 @@ class CommonBase {
                         out_java_struct.write("\tlong to_c_ptr() { return 0; }\n")
                         # TODO: TxOut body
                         out_java_struct.write("}")
-                elif struct_name == "LDKTransaction":
-                    with open(sys.argv[3] + "/structs/Transaction.java", "w") as out_java_struct:
-                        out_java_struct.write(hu_struct_file_prefix)
-                        out_java_struct.write("public class Transaction extends CommonBase{\n")
-                        out_java_struct.write("\tTransaction(java.lang.Object _dummy, long ptr) { super(ptr); }\n")
-                        out_java_struct.write("\tpublic Transaction(byte[] data) { super(bindings.new_txpointer_copy_data(data)); }\n")
-                        out_java_struct.write("\t@Override public void finalize() throws Throwable { super.finalize(); bindings.txpointer_free(ptr); }\n")
-                        out_java_struct.write("\tpublic byte[] get_contents() { return bindings.txpointer_get_buffer(ptr); }\n")
-                        # TODO: Transaction body
-                        out_java_struct.write("}")
                 else:
                     pass # Everything remaining is a byte[] or some form
                 cur_block_obj = None