Add a reference low bit for non-opaque types to tag dont-free refs
[ldk-java] / java_strings.py
index 086518ff6b2a05ed59a9baefbbe936d452a5ae1f..e1ea1e868e0b60655949ddb4b073a524aa1deed4 100644 (file)
@@ -54,7 +54,7 @@ public class bindings {
         self.common_base = """package org.ldk.structs;
 import java.util.LinkedList;
 class CommonBase {
-       long ptr;
+       final long ptr;
        LinkedList<Object> ptrs_to = new LinkedList();
        protected CommonBase(long ptr) { this.ptr = ptr; }
        public long _test_only_get_ptr() { return this.ptr; }
@@ -180,11 +180,14 @@ void __wrap_reallocarray(void* ptr, size_t new_sz) {
 }
 
 void __attribute__((destructor)) check_leaks() {
+       size_t alloc_count = 0;
        for (allocation* a = allocation_ll; a != NULL; a = a->next) {
                fprintf(stderr, "%s %p remains:\\n", a->struct_name, a->ptr);
                backtrace_symbols_fd(a->bt, a->bt_len, STDERR_FILENO);
                fprintf(stderr, "\\n\\n");
+               alloc_count++;
        }
+       fprintf(stderr, "%lu allocations remained.\\n", alloc_count);
        DO_ASSERT(allocation_ll == NULL);
 }
 """
@@ -303,7 +306,6 @@ import java.util.Arrays;
 @SuppressWarnings("unchecked") // We correctly assign various generic arrays
 """
         self.c_fn_ty_pfx = "JNIEXPORT "
-        self.c_fn_name_pfx = "JNICALL Java_org_ldk_impl_bindings_"
         self.c_fn_args_pfx = "JNIEnv *env, jclass clz"
         self.file_ext = ".java"
         self.ptr_c_ty = "int64_t"
@@ -362,6 +364,11 @@ import java.util.Arrays;
     def str_ref_to_c_call(self, var_name, str_len):
         return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
 
+    def c_fn_name_define_pfx(self, fn_name, has_args):
+        if has_args:
+            return "JNICALL Java_org_ldk_impl_bindings_" + fn_name.replace("_", "_1") + "(JNIEnv *env, jclass clz, "
+        return "JNICALL Java_org_ldk_impl_bindings_" + fn_name.replace("_", "_1") + "(JNIEnv *env, jclass clz"
+
     def init_str(self):
         res = ""
         for ty in self.c_array_class_caches:
@@ -548,9 +555,6 @@ import java.util.Arrays;
                         if fn_line.ret_ty_info.from_hu_conv[1] != "":
                             java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
                         #if fn_line.ret_ty_info.rust_obj in result_types:
-                        # XXX: We need to handle this in conversion logic so that its cross-language!
-                            # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead
-                        #    java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
                         java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
                     else:
                         java_trait_constr = java_trait_constr + "\t\t\t\treturn ret;\n"
@@ -703,7 +707,7 @@ import java.util.Arrays;
         out_c = out_c + "\treturn ret;\n"
         out_c = out_c + "}\n"
 
-        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1new (" + self.c_fn_args_pfx + ", jobject o"
+        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_define_pfx(struct_name + "_new", True) + "jobject o"
         for var in field_vars:
             if isinstance(var, ConvInfo):
                 out_c = out_c + ", " + var.c_ty + " " + var.arg_name
@@ -792,7 +796,7 @@ import java.util.Arrays;
 
         out_c += (self.c_complex_enum_pfx(struct_name, [x.var_name for x in variant_list], init_meth_jty_strs))
 
-        out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (" + self.c_fn_args_pfx + ", " + self.ptr_c_ty + " ptr) {\n")
+        out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_define_pfx(struct_name + "_ref_from_ptr", True) + self.ptr_c_ty + " ptr) {\n")
         out_c += ("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
         out_c += ("\tswitch(obj->tag) {\n")
         for var in variant_list:
@@ -832,7 +836,7 @@ import java.util.Arrays;
         return out_opaque_struct_human
 
 
-    def map_function(self, argument_types, c_call_string, is_free, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, args_known, has_out_java_struct: bool, type_mapping_generator):
+    def map_function(self, argument_types, c_call_string, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, args_known, type_mapping_generator):
         out_java = ""
         out_c = ""
         out_java_struct = None
@@ -844,44 +848,42 @@ import java.util.Arrays;
         if return_type_info.ret_conv is not None:
             ret_conv_pfx, ret_conv_sfx = return_type_info.ret_conv
         out_java += (" " + method_name + "(")
-        out_c += (" " + self.c_fn_name_pfx + method_name.replace('_', '_1') + "(" + self.c_fn_args_pfx)
+        have_args = len(argument_types) > 1 or (len(argument_types) > 0 and argument_types[0].c_ty != "void")
+        out_c += (" " + self.c_fn_name_define_pfx(method_name, have_args))
 
         for idx, arg_conv_info in enumerate(argument_types):
             if idx != 0:
                 out_java += (", ")
-            if arg_conv_info.c_ty != "void":
                 out_c += (", ")
             if arg_conv_info.c_ty != "void":
                 out_c += (arg_conv_info.c_ty + " " + arg_conv_info.arg_name)
                 out_java += (arg_conv_info.java_ty + " " + arg_conv_info.arg_name)
 
-        if has_out_java_struct:
-            out_java_struct = ""
-            if not args_known:
-                out_java_struct += ("\t// Skipped " + method_name + "\n")
-                has_out_java_struct = False
+        out_java_struct = ""
+        if not args_known:
+            out_java_struct += ("\t// Skipped " + method_name + "\n")
+        else:
+            meth_n = method_name[len(struct_meth) + 1:]
+            if not takes_self:
+                out_java_struct += (
+                    "\tpublic static " + return_type_info.java_hu_ty + " constructor_" + meth_n + "(")
             else:
-                meth_n = method_name[len(struct_meth) + 1:]
-                if not takes_self:
-                    out_java_struct += (
-                        "\tpublic static " + return_type_info.java_hu_ty + " constructor_" + meth_n + "(")
-                else:
-                    out_java_struct += ("\tpublic " + return_type_info.java_hu_ty + " " + meth_n + "(")
-                for idx, arg in enumerate(argument_types):
-                    if idx != 0:
-                        if not takes_self or idx > 1:
-                            out_java_struct += (", ")
-                    elif takes_self:
-                        continue
-                    if arg.java_ty != "void":
-                        if arg.arg_name in default_constructor_args:
-                            for explode_idx, explode_arg in enumerate(default_constructor_args[arg.arg_name]):
-                                if explode_idx != 0:
-                                    out_java_struct += (", ")
-                                out_java_struct += (
-                                    explode_arg.java_hu_ty + " " + arg.arg_name + "_" + explode_arg.arg_name)
-                        else:
-                            out_java_struct += (arg.java_hu_ty + " " + arg.arg_name)
+                out_java_struct += ("\tpublic " + return_type_info.java_hu_ty + " " + meth_n + "(")
+            for idx, arg in enumerate(argument_types):
+                if idx != 0:
+                    if not takes_self or idx > 1:
+                        out_java_struct += (", ")
+                elif takes_self:
+                    continue
+                if arg.java_ty != "void":
+                    if arg.arg_name in default_constructor_args:
+                        for explode_idx, explode_arg in enumerate(default_constructor_args[arg.arg_name]):
+                            if explode_idx != 0:
+                                out_java_struct += (", ")
+                            out_java_struct += (
+                                explode_arg.java_hu_ty + " " + arg.arg_name + "_" + explode_arg.arg_name)
+                    else:
+                        out_java_struct += (arg.java_hu_ty + " " + arg.arg_name)
         out_java += (");\n")
         out_c += (") {\n")
         if out_java_struct is not None:
@@ -920,7 +922,7 @@ import java.util.Arrays;
             out_c += ("\n\treturn ret_val;")
         out_c += ("\n}\n\n")
 
-        if has_out_java_struct:
+        if args_known:
             out_java_struct += ("\t\t")
             if return_type_info.java_ty != "void":
                 out_java_struct += (return_type_info.java_ty + " ret = ")