Fix write method generation w/ upstream changes
[ldk-java] / genbindings.py
index 56d36da4b653c2222155052bd072810a0cd417e6..1ed228a56eb6b8f2e4192ff35d44e65e8240321b 100755 (executable)
@@ -230,9 +230,6 @@ trait_structs = set()
 result_types = set()
 tuple_types = {}
 
-def is_common_base_ext(struct_name):
-    return struct_name in complex_enums or struct_name in opaque_structs or struct_name in trait_structs or struct_name in result_types
-
 var_is_arr_regex = re.compile("\(\*([A-za-z0-9_]*)\)\[([a-z0-9]*)\]")
 var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)")
 java_c_types_none_allowed = True # Unset when we do the real pass that populates the above sets
@@ -737,19 +734,26 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         arg_conv_cleanup = None,
                         ret_conv = ("jclass " + ty_info.var_name + "_conv = " + ty_info.rust_obj + "_to_java(_env, ", ");"),
                         ret_conv_name = ty_info.var_name + "_conv", to_hu_conv = None, to_hu_conv_name = None, from_hu_conv = None)
-                base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";";
+                base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";"
                 if ty_info.rust_obj in trait_structs:
                     if not is_free:
-                        base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n"
-                        base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n"
-                        base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_clone(" + ty_info.var_name + "_conv.this_arg);\n}"
+                        needs_full_clone = not is_free and (not ty_info.is_ptr and not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False
+                        if needs_full_clone and (ty_info.java_hu_ty + "_clone") in clone_fns:
+                            base_conv = base_conv + "\n" + ty_info.var_name + "_conv = " + ty_info.java_hu_ty + "_clone(" + ty_info.var_name + ");"
+                        else:
+                            base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n"
+                            base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n"
+                            base_conv = base_conv + "\t" + ty_info.rust_obj + "_JCalls_clone(" + ty_info.var_name + "_conv.this_arg);\n}"
+                            if needs_full_clone:
+                                base_conv = base_conv + "// Warning: we may need a move here but can't do a full clone!\n"
+
                     else:
                         base_conv = base_conv + "\n" + "FREE((void*)" + ty_info.var_name + ");"
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                         arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                         ret_conv = (ty_info.rust_obj + "* ret = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*ret = ", ";"),
                         ret_conv_name = "(long)ret",
-                        to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, ret);\nret_hu_conv.ptrs_to.add(this);",
+                        to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");\nret_hu_conv.ptrs_to.add(this);",
                         to_hu_conv_name = "ret_hu_conv",
                         from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr", "this.ptrs_to.add(" + ty_info.var_name + ")"))
                 if ty_info.rust_obj != "LDKu8slice":
@@ -834,13 +838,24 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         to_hu_conv_name = ty_info.var_name + "_hu_conv",
                         from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr & ~1", "this.ptrs_to.add(" + ty_info.var_name + ")"))
                 elif ty_info.rust_obj in trait_structs:
-                    return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                        arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
-                        arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
-                        ret_conv = ("long ret_" + ty_info.var_name + " = (long)", ";"), ret_conv_name = "ret_" + ty_info.var_name,
-                        to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, ret);\nret_hu_conv.ptrs_to.add(this);",
-                        to_hu_conv_name = "ret_hu_conv",
-                        from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr", "this.ptrs_to.add(" + ty_info.var_name + ")"))
+                    if ty_info.java_hu_ty + "_clone" in clone_fns:
+                        return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
+                            arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
+                            arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
+                            ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_clone = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n" +
+                                "*" + ty_info.var_name + "_clone = " + ty_info.java_hu_ty + "_clone(", ");"),
+                            ret_conv_name = "(long)" + ty_info.var_name + "_clone",
+                            to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");\nret_hu_conv.ptrs_to.add(this);",
+                            to_hu_conv_name = "ret_hu_conv",
+                            from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr", "this.ptrs_to.add(" + ty_info.var_name + ")"))
+                    else:
+                        return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
+                            arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
+                            arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
+                            ret_conv = ("long ret_" + ty_info.var_name + " = (long)", ";"), ret_conv_name = "ret_" + ty_info.var_name,
+                            to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");\nret_hu_conv.ptrs_to.add(this);",
+                            to_hu_conv_name = "ret_hu_conv",
+                            from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr", "this.ptrs_to.add(" + ty_info.var_name + ")"))
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
                     arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
@@ -992,7 +1007,10 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     out_java_struct.write(info.arg_name)
             out_java_struct.write(");\n")
             if ret_info.to_hu_conv is not None:
-                out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t") + "\n")
+                if ret_info.rust_obj == "LDK" + struct_meth:
+                    out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t").replace("this", ret_info.to_hu_conv_name) + "\n")
+                else:
+                    out_java_struct.write("\t\t" + ret_info.to_hu_conv.replace("\n", "\n\t\t") + "\n")
 
             for info in arg_names:
                 if info.arg_name == "this_ptr" or info.arg_name == "this_arg":
@@ -1326,7 +1344,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                             java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.java_ty + " result = " + ret_ty_info.from_hu_conv[0] + ";\n"
                             if ret_ty_info.from_hu_conv[1] != "":
                                 java_trait_constr = java_trait_constr + "\t\t\t\t" + ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
-                            if is_common_base_ext(ret_ty_info.rust_obj):
+                            if ret_ty_info.rust_obj in result_types:
+                                # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead
                                 java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
                             java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
                         else:
@@ -1395,7 +1414,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 elif fn_line.group(2) == "free":
                     write_c("\t\t.free = " + struct_name + "_JCalls_free,\n")
                 else:
-                    clone_fns.add(struct_name + "_clone")
                     write_c("\t\t.clone = " + struct_name + "_JCalls_clone,\n")
             for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs: