Update HumanObjectPeerTest for upstream changes, to reload nodes
[ldk-java] / java_strings.py
index d9b7b4b78a3f0a1a0c81d590caef9885fafe643d..bc3c73cf05f2ea841bd68123264e8287c0d967fb 100644 (file)
@@ -51,6 +51,13 @@ public class bindings {
 
         self.bindings_footer = "}\n"
 
+        self.util_fn_pfx = """package org.ldk.structs;
+import org.ldk.impl.bindings;
+import java.util.Arrays;
+
+public class UtilMethods {
+"""
+        self.util_fn_sfx = "}"
         self.common_base = """package org.ldk.structs;
 import java.util.LinkedList;
 class CommonBase {
@@ -67,6 +74,7 @@ class CommonBase {
 #include <string.h>
 #include <stdatomic.h>
 #include <stdlib.h>
+
 """
 
         if not DEBUG:
@@ -102,16 +110,18 @@ typedef struct allocation {
        const char* struct_name;
        void* bt[BT_MAX];
        int bt_len;
+       size_t alloc_len;
 } allocation;
 static allocation* allocation_ll = NULL;
 
 void* __real_malloc(size_t len);
 void* __real_calloc(size_t nmemb, size_t len);
-static void new_allocation(void* res, const char* struct_name) {
+static void new_allocation(void* res, const char* struct_name, size_t len) {
        allocation* new_alloc = __real_malloc(sizeof(allocation));
        new_alloc->ptr = res;
        new_alloc->struct_name = struct_name;
        new_alloc->bt_len = backtrace(new_alloc->bt, BT_MAX);
+       new_alloc->alloc_len = len;
        DO_ASSERT(mtx_lock(&allocation_mtx) == thrd_success);
        new_alloc->next = allocation_ll;
        allocation_ll = new_alloc;
@@ -119,7 +129,7 @@ static void new_allocation(void* res, const char* struct_name) {
 }
 static void* MALLOC(size_t len, const char* struct_name) {
        void* res = __real_malloc(len);
-       new_allocation(res, struct_name);
+       new_allocation(res, struct_name, len);
        return res;
 }
 void __real_free(void* ptr);
@@ -152,12 +162,12 @@ static void FREE(void* ptr) {
 
 void* __wrap_malloc(size_t len) {
        void* res = __real_malloc(len);
-       new_allocation(res, "malloc call");
+       new_allocation(res, "malloc call", len);
        return res;
 }
 void* __wrap_calloc(size_t nmemb, size_t len) {
        void* res = __real_calloc(nmemb, len);
-       new_allocation(res, "calloc call");
+       new_allocation(res, "calloc call", len);
        return res;
 }
 void __wrap_free(void* ptr) {
@@ -170,7 +180,7 @@ void* __real_realloc(void* ptr, size_t newlen);
 void* __wrap_realloc(void* ptr, size_t len) {
        if (ptr != NULL) alloc_freed(ptr);
        void* res = __real_realloc(ptr, len);
-       new_allocation(res, "realloc call");
+       new_allocation(res, "realloc call", len);
        return res;
 }
 void __wrap_reallocarray(void* ptr, size_t new_sz) {
@@ -179,11 +189,16 @@ void __wrap_reallocarray(void* ptr, size_t new_sz) {
 }
 
 void __attribute__((destructor)) check_leaks() {
+       size_t alloc_count = 0;
+       size_t alloc_size = 0;
        for (allocation* a = allocation_ll; a != NULL; a = a->next) {
                fprintf(stderr, "%s %p remains:\\n", a->struct_name, a->ptr);
                backtrace_symbols_fd(a->bt, a->bt_len, STDERR_FILENO);
                fprintf(stderr, "\\n\\n");
+               alloc_count++;
+               alloc_size += a->alloc_len;
        }
+       fprintf(stderr, "%lu allocations remained for %lu bytes.\\n", alloc_count, alloc_size);
        DO_ASSERT(allocation_ll == NULL);
 }
 """
@@ -281,6 +296,15 @@ _Static_assert(sizeof(void*) <= 8, "Pointers must fit into 64 bits");
 typedef jlongArray int64_tArray;
 typedef jbyteArray int8_tArray;
 
+static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
+       // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
+       char* err_buf = MALLOC(len + 1, "str conv buf");
+       memcpy(err_buf, chars, len);
+       err_buf[len] = 0;
+       jstring err_conv = (*env)->NewStringUTF(env, chars);
+       FREE(err_buf);
+       return err_conv;
+}
 """
 
         self.hu_struct_file_prefix = """package org.ldk.structs;
@@ -293,13 +317,11 @@ import java.util.Arrays;
 @SuppressWarnings("unchecked") // We correctly assign various generic arrays
 """
         self.c_fn_ty_pfx = "JNIEXPORT "
-        self.c_fn_name_pfx = "JNICALL Java_org_ldk_impl_bindings_"
         self.c_fn_args_pfx = "JNIEnv *env, jclass clz"
         self.file_ext = ".java"
         self.ptr_c_ty = "int64_t"
         self.ptr_native_ty = "long"
         self.result_c_ty = "jclass"
-        self.owned_str_to_c_call = ("(*env)->NewStringUTF(env, ", ")")
         self.ptr_arr = "jobjectArray"
         self.get_native_arr_len_call = ("(*env)->GetArrayLength(env, ", ")")
 
@@ -350,6 +372,14 @@ import java.util.Arrays;
         else:
             return "(*env)->Release" + ty_info.java_ty.strip("[]").title() + "ArrayElements(env, " + arr_name + ", " + dest_name + ", 0)"
 
+    def str_ref_to_c_call(self, var_name, str_len):
+        return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
+
+    def c_fn_name_define_pfx(self, fn_name, has_args):
+        if has_args:
+            return "JNICALL Java_org_ldk_impl_bindings_" + fn_name.replace("_", "_1") + "(JNIEnv *env, jclass clz, "
+        return "JNICALL Java_org_ldk_impl_bindings_" + fn_name.replace("_", "_1") + "(JNIEnv *env, jclass clz"
+
     def init_str(self):
         res = ""
         for ty in self.c_array_class_caches:
@@ -465,7 +495,7 @@ import java.util.Arrays;
         for var in field_vars:
             if isinstance(var, ConvInfo):
                 if var.from_hu_conv is not None and var.from_hu_conv[1] != "":
-                    out_java_trait = out_java_trait + "\t\t" + var.from_hu_conv[1] + ";\n"
+                    out_java_trait = out_java_trait + "\t\t" + var.from_hu_conv[1].replace("\n", "\n\t\t") + ";\n"
             else:
                 out_java_trait = out_java_trait + "\t\tthis.ptrs_to.add(" + var[1] + ");\n"
         out_java_trait = out_java_trait + "\t\tthis.bindings_instance = arg;\n"
@@ -536,9 +566,6 @@ import java.util.Arrays;
                         if fn_line.ret_ty_info.from_hu_conv[1] != "":
                             java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
                         #if fn_line.ret_ty_info.rust_obj in result_types:
-                        # XXX: We need to handle this in conversion logic so that its cross-language!
-                            # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead
-                        #    java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
                         java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
                     else:
                         java_trait_constr = java_trait_constr + "\t\t\t\treturn ret;\n"
@@ -691,7 +718,7 @@ import java.util.Arrays;
         out_c = out_c + "\treturn ret;\n"
         out_c = out_c + "}\n"
 
-        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1new (" + self.c_fn_args_pfx + ", jobject o"
+        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_define_pfx(struct_name + "_new", True) + "jobject o"
         for var in field_vars:
             if isinstance(var, ConvInfo):
                 out_c = out_c + ", " + var.c_ty + " " + var.arg_name
@@ -780,7 +807,7 @@ import java.util.Arrays;
 
         out_c += (self.c_complex_enum_pfx(struct_name, [x.var_name for x in variant_list], init_meth_jty_strs))
 
-        out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (" + self.c_fn_args_pfx + ", " + self.ptr_c_ty + " ptr) {\n")
+        out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_define_pfx(struct_name + "_ref_from_ptr", True) + self.ptr_c_ty + " ptr) {\n")
         out_c += ("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
         out_c += ("\tswitch(obj->tag) {\n")
         for var in variant_list:
@@ -820,7 +847,7 @@ import java.util.Arrays;
         return out_opaque_struct_human
 
 
-    def map_function(self, argument_types, c_call_string, is_free, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, args_known, has_out_java_struct: bool, type_mapping_generator):
+    def map_function(self, argument_types, c_call_string, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, args_known, type_mapping_generator):
         out_java = ""
         out_c = ""
         out_java_struct = None
@@ -832,44 +859,42 @@ import java.util.Arrays;
         if return_type_info.ret_conv is not None:
             ret_conv_pfx, ret_conv_sfx = return_type_info.ret_conv
         out_java += (" " + method_name + "(")
-        out_c += (" " + self.c_fn_name_pfx + method_name.replace('_', '_1') + "(" + self.c_fn_args_pfx)
+        have_args = len(argument_types) > 1 or (len(argument_types) > 0 and argument_types[0].c_ty != "void")
+        out_c += (" " + self.c_fn_name_define_pfx(method_name, have_args))
 
         for idx, arg_conv_info in enumerate(argument_types):
             if idx != 0:
                 out_java += (", ")
-            if arg_conv_info.c_ty != "void":
                 out_c += (", ")
             if arg_conv_info.c_ty != "void":
                 out_c += (arg_conv_info.c_ty + " " + arg_conv_info.arg_name)
                 out_java += (arg_conv_info.java_ty + " " + arg_conv_info.arg_name)
 
-        if has_out_java_struct:
-            out_java_struct = ""
-            if not args_known:
-                out_java_struct += ("\t// Skipped " + method_name + "\n")
-                has_out_java_struct = False
+        out_java_struct = ""
+        if not args_known:
+            out_java_struct += ("\t// Skipped " + method_name + "\n")
+        else:
+            meth_n = method_name[len(struct_meth) + 1:]
+            if not takes_self:
+                out_java_struct += (
+                    "\tpublic static " + return_type_info.java_hu_ty + " constructor_" + meth_n + "(")
             else:
-                meth_n = method_name[len(struct_meth) + 1:]
-                if not takes_self:
-                    out_java_struct += (
-                        "\tpublic static " + return_type_info.java_hu_ty + " constructor_" + meth_n + "(")
-                else:
-                    out_java_struct += ("\tpublic " + return_type_info.java_hu_ty + " " + meth_n + "(")
-                for idx, arg in enumerate(argument_types):
-                    if idx != 0:
-                        if not takes_self or idx > 1:
-                            out_java_struct += (", ")
-                    elif takes_self:
-                        continue
-                    if arg.java_ty != "void":
-                        if arg.arg_name in default_constructor_args:
-                            for explode_idx, explode_arg in enumerate(default_constructor_args[arg.arg_name]):
-                                if explode_idx != 0:
-                                    out_java_struct += (", ")
-                                out_java_struct += (
-                                    explode_arg.java_hu_ty + " " + arg.arg_name + "_" + explode_arg.arg_name)
-                        else:
-                            out_java_struct += (arg.java_hu_ty + " " + arg.arg_name)
+                out_java_struct += ("\tpublic " + return_type_info.java_hu_ty + " " + meth_n + "(")
+            for idx, arg in enumerate(argument_types):
+                if idx != 0:
+                    if not takes_self or idx > 1:
+                        out_java_struct += (", ")
+                elif takes_self:
+                    continue
+                if arg.java_ty != "void":
+                    if arg.arg_name in default_constructor_args:
+                        for explode_idx, explode_arg in enumerate(default_constructor_args[arg.arg_name]):
+                            if explode_idx != 0:
+                                out_java_struct += (", ")
+                            out_java_struct += (
+                                explode_arg.java_hu_ty + " " + arg.arg_name + "_" + explode_arg.arg_name)
+                    else:
+                        out_java_struct += (arg.java_hu_ty + " " + arg.arg_name)
         out_java += (");\n")
         out_c += (") {\n")
         if out_java_struct is not None:
@@ -908,7 +933,7 @@ import java.util.Arrays;
             out_c += ("\n\treturn ret_val;")
         out_c += ("\n}\n\n")
 
-        if has_out_java_struct:
+        if args_known:
             out_java_struct += ("\t\t")
             if return_type_info.java_ty != "void":
                 out_java_struct += (return_type_info.java_ty + " ret = ")
@@ -955,9 +980,9 @@ import java.util.Arrays;
                 elif info.from_hu_conv is not None and info.from_hu_conv[1] != "":
                     if not takes_self and return_type_info.to_hu_conv_name is not None:
                         out_java_struct += (
-                            "\t\t" + info.from_hu_conv[1].replace("this", return_type_info.to_hu_conv_name) + ";\n")
+                            "\t\t" + info.from_hu_conv[1].replace("this", return_type_info.to_hu_conv_name).replace("\n", "\n\t\t") + ";\n")
                     else:
-                        out_java_struct += ("\t\t" + info.from_hu_conv[1] + ";\n")
+                        out_java_struct += ("\t\t" + info.from_hu_conv[1].replace("\n", "\n\t\t") + ";\n")
 
             if return_type_info.to_hu_conv_name is not None:
                 out_java_struct += ("\t\treturn " + return_type_info.to_hu_conv_name + ";\n")