Rework holds_ref and clone logic somewhat
[ldk-java] / genbindings.py
index 3c9235118f1e3fe4217563c70aabad99f5709aed..0356da087de2853810189b8ea2a20aa8d56dbb3d 100755 (executable)
@@ -115,13 +115,14 @@ void* __wrap_calloc(size_t nmemb, size_t len) {
        return res;
 }
 void __wrap_free(void* ptr) {
+       if (ptr == NULL) return;
        alloc_freed(ptr);
        __real_free(ptr);
 }
 
 void* __real_realloc(void* ptr, size_t newlen);
 void* __wrap_realloc(void* ptr, size_t len) {
-       alloc_freed(ptr);
+       if (ptr != NULL) alloc_freed(ptr);
        void* res = __real_realloc(ptr, len);
        new_allocation(res, "realloc call");
        return res;
@@ -161,6 +162,7 @@ class TypeInfo:
         self.arr_access = arr_access
         self.subty = subty
         self.pass_by_ref = is_ptr
+        self.requires_clone = None
 
 class ConvInfo:
     def __init__(self, ty_info, arg_name, arg_conv, arg_conv_name, arg_conv_cleanup, ret_conv, ret_conv_name, to_hu_conv, to_hu_conv_name, from_hu_conv):
@@ -545,7 +547,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 conv_name = "arr_conv_" + str(len(ty_info.java_hu_ty))
                 idxc = chr(ord('a') + (len(ty_info.java_hu_ty) % 26))
                 ty_info.subty.var_name = conv_name
-                ty_info.subty.passed_as_ptr = False
+                ty_info.subty.requires_clone = not ty_info.is_ptr or not holds_ref
                 subty = map_type_with_info(ty_info.subty, False, None, is_free, holds_ref)
                 if arr_name == "":
                     arr_name = "arg"
@@ -670,11 +672,11 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             if ty_info.rust_obj in opaque_structs:
                 opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n"
                 opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = (void*)(" + ty_info.var_name + " & (~1));\n"
-                if holds_ref:
+                if ty_info.is_ptr and holds_ref:
                     opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = false;"
                 else:
                     opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = (" + ty_info.var_name + " & 1) || (" + ty_info.var_name + " == 0);"
-                if not ty_info.is_ptr and not is_free and not ty_info.pass_by_ref and not holds_ref:
+                if not is_free and (not ty_info.is_ptr or not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False:
                     if (ty_info.java_hu_ty + "_clone") in clone_fns:
                         # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
                         opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
@@ -747,8 +749,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";")
                     if not holds_ref:
                         ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_copy = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n", "")
-                        if not ty_info.passed_as_ptr:
-                            # We use passed_as_ptr as a flag to detect if we're copying a Vec.
+                        if ty_info.requires_clone == True: # Set in object array mapping
                             if (ty_info.java_hu_ty + "_clone") in clone_fns:
                                 ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = " + ty_info.java_hu_ty + "_clone(&", ");\n")
                             else:
@@ -870,7 +871,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 out_java.write(", ")
             if arg != "void":
                 write_c(", ")
-            arg_conv_info = map_type(arg, False, None, is_free, False)
+            arg_conv_info = map_type(arg, False, None, is_free, True)
             if arg_conv_info.c_ty != "void":
                 arg_conv_info.print_ty()
                 arg_conv_info.print_name()
@@ -880,7 +881,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 if arg_conv_info.rust_obj in constructor_fns:
                     assert not is_free
                     for explode_arg in constructor_fns[arg_conv_info.rust_obj].split(','):
-                        explode_arg_conv = map_type(explode_arg, False, None, False, False)
+                        explode_arg_conv = map_type(explode_arg, False, None, False, True)
                         if explode_arg_conv.c_ty == "void":
                             # We actually want to handle this case, but for now its only used in NetGraphMsgHandler::new()
                             # which ends up resulting in a redundant constructor - both without arguments for the NetworkGraph.
@@ -1224,15 +1225,17 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             out_java_trait.write("\t\tif (ptr != 0) { bindings." + struct_name.replace("LDK","") + "_free(ptr); } super.finalize();\n")
             out_java_trait.write("\t}\n\n")
 
-            java_trait_constr = "\tpublic " + struct_name.replace("LDK", "") + "(" + struct_name.replace("LDK", "") + "Interface arg"
+            java_trait_constr = "\tprivate static class " + struct_name + "Holder { " + struct_name.replace("LDK", "") + " held; }\n"
+            java_trait_constr = java_trait_constr + "\tpublic static " + struct_name.replace("LDK", "") + " new_impl(" + struct_name.replace("LDK", "") + "Interface arg"
             for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     # Ideally we'd be able to take any instance of the interface, but our C code can only represent
                     # Java-implemented version, so we require users pass a Java implementation here :/
-                    java_trait_constr = java_trait_constr + ", " + var_line.group(1).replace("LDK", "") + "." + var_line.group(1).replace("LDK", "") + "Interface " + var_line.group(2)
+                    java_trait_constr = java_trait_constr + ", " + var_line.group(1).replace("LDK", "") + "." + var_line.group(1).replace("LDK", "") + "Interface " + var_line.group(2) + "_impl"
                 else:
                     java_trait_constr = java_trait_constr + ", " + field_var_convs[idx].java_hu_ty + " " + var_line.group(2)
-            java_trait_constr = java_trait_constr + ") {\n\t\tthis(new bindings." + struct_name + "() {\n"
+            java_trait_constr = java_trait_constr + ") {\n\t\tfinal " + struct_name + "Holder impl_holder = new " + struct_name + "Holder();\n"
+            java_trait_constr = java_trait_constr + "\t\timpl_holder.held = new " + struct_name.replace("LDK", "") + "(new bindings." + struct_name + "() {\n"
             out_java_trait.write("\tpublic static interface " + struct_name.replace("LDK", "") + "Interface {\n")
             out_java.write("\tpublic interface " + struct_name + " {\n")
             java_meths = []
@@ -1319,7 +1322,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         if ret_ty_info.from_hu_conv is not None:
                             java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.java_ty + " result = " + ret_ty_info.from_hu_conv[0] + ";\n"
                             if ret_ty_info.from_hu_conv[1] != "":
-                                java_trait_constr = java_trait_constr + "\t\t\t\t//TODO: May need to call: " + ret_ty_info.from_hu_conv[1] + ";\n"
+                                java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
                             if is_common_base_ext(ret_ty_info.rust_obj):
                                 java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
                             java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
@@ -1338,11 +1341,11 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             java_trait_constr = java_trait_constr + "\t\t}"
             for var_line in field_var_lines:
                 if var_line.group(1) in trait_structs:
-                    java_trait_constr = java_trait_constr + ", new " + var_line.group(2) + "(" + var_line.group(2) + ").bindings_instance"
+                    java_trait_constr = java_trait_constr + ", " + var_line.group(2) + ".new_impl(" + var_line.group(2) + "_impl).bindings_instance"
                 else:
                     java_trait_constr = java_trait_constr + ", " + var_line.group(2)
             out_java_trait.write("\t}\n")
-            out_java_trait.write(java_trait_constr + ");\n\t}\n")
+            out_java_trait.write(java_trait_constr + ");\n\t\treturn impl_holder.held;\n\t}\n")
 
             # Write out a clone function whether we need one or not, as we use them in moving to rust
             write_c("static void* " + struct_name + "_JCalls_clone(const void* this_arg) {\n")