Clone objects returned from jcalls before returning to rust
[ldk-java] / java_strings.py
index d9b7b4b78a3f0a1a0c81d590caef9885fafe643d..c83ee8269257f6a5c66e9f3bbe15cdb26cc3f98b 100644 (file)
@@ -54,7 +54,7 @@ public class bindings {
         self.common_base = """package org.ldk.structs;
 import java.util.LinkedList;
 class CommonBase {
-       long ptr;
+       final long ptr;
        LinkedList<Object> ptrs_to = new LinkedList();
        protected CommonBase(long ptr) { this.ptr = ptr; }
        public long _test_only_get_ptr() { return this.ptr; }
@@ -67,6 +67,7 @@ class CommonBase {
 #include <string.h>
 #include <stdatomic.h>
 #include <stdlib.h>
+
 """
 
         if not DEBUG:
@@ -281,6 +282,15 @@ _Static_assert(sizeof(void*) <= 8, "Pointers must fit into 64 bits");
 typedef jlongArray int64_tArray;
 typedef jbyteArray int8_tArray;
 
+static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
+       // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
+       char* err_buf = MALLOC(len + 1, "str conv buf");
+       memcpy(err_buf, chars, len);
+       err_buf[len] = 0;
+       jstring err_conv = (*env)->NewStringUTF(env, chars);
+       FREE(err_buf);
+       return err_conv;
+}
 """
 
         self.hu_struct_file_prefix = """package org.ldk.structs;
@@ -293,13 +303,11 @@ import java.util.Arrays;
 @SuppressWarnings("unchecked") // We correctly assign various generic arrays
 """
         self.c_fn_ty_pfx = "JNIEXPORT "
-        self.c_fn_name_pfx = "JNICALL Java_org_ldk_impl_bindings_"
         self.c_fn_args_pfx = "JNIEnv *env, jclass clz"
         self.file_ext = ".java"
         self.ptr_c_ty = "int64_t"
         self.ptr_native_ty = "long"
         self.result_c_ty = "jclass"
-        self.owned_str_to_c_call = ("(*env)->NewStringUTF(env, ", ")")
         self.ptr_arr = "jobjectArray"
         self.get_native_arr_len_call = ("(*env)->GetArrayLength(env, ", ")")
 
@@ -350,6 +358,14 @@ import java.util.Arrays;
         else:
             return "(*env)->Release" + ty_info.java_ty.strip("[]").title() + "ArrayElements(env, " + arr_name + ", " + dest_name + ", 0)"
 
+    def str_ref_to_c_call(self, var_name, str_len):
+        return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
+
+    def c_fn_name_define_pfx(self, fn_name, has_args):
+        if has_args:
+            return "JNICALL Java_org_ldk_impl_bindings_" + fn_name.replace("_", "_1") + "(JNIEnv *env, jclass clz, "
+        return "JNICALL Java_org_ldk_impl_bindings_" + fn_name.replace("_", "_1") + "(JNIEnv *env, jclass clz"
+
     def init_str(self):
         res = ""
         for ty in self.c_array_class_caches:
@@ -536,9 +552,6 @@ import java.util.Arrays;
                         if fn_line.ret_ty_info.from_hu_conv[1] != "":
                             java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.from_hu_conv[1].replace("this", "impl_holder.held") + ";\n"
                         #if fn_line.ret_ty_info.rust_obj in result_types:
-                        # XXX: We need to handle this in conversion logic so that its cross-language!
-                            # Avoid double-free by breaking the result - we should learn to clone these and then we can be safe instead
-                        #    java_trait_constr = java_trait_constr + "\t\t\t\tret.ptr = 0;\n"
                         java_trait_constr = java_trait_constr + "\t\t\t\treturn result;\n"
                     else:
                         java_trait_constr = java_trait_constr + "\t\t\t\treturn ret;\n"
@@ -691,7 +704,7 @@ import java.util.Arrays;
         out_c = out_c + "\treturn ret;\n"
         out_c = out_c + "}\n"
 
-        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1new (" + self.c_fn_args_pfx + ", jobject o"
+        out_c = out_c + self.c_fn_ty_pfx + "long " + self.c_fn_name_define_pfx(struct_name + "_new", True) + "jobject o"
         for var in field_vars:
             if isinstance(var, ConvInfo):
                 out_c = out_c + ", " + var.c_ty + " " + var.arg_name
@@ -780,7 +793,7 @@ import java.util.Arrays;
 
         out_c += (self.c_complex_enum_pfx(struct_name, [x.var_name for x in variant_list], init_meth_jty_strs))
 
-        out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_pfx + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (" + self.c_fn_args_pfx + ", " + self.ptr_c_ty + " ptr) {\n")
+        out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_define_pfx(struct_name + "_ref_from_ptr", True) + self.ptr_c_ty + " ptr) {\n")
         out_c += ("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
         out_c += ("\tswitch(obj->tag) {\n")
         for var in variant_list:
@@ -832,12 +845,12 @@ import java.util.Arrays;
         if return_type_info.ret_conv is not None:
             ret_conv_pfx, ret_conv_sfx = return_type_info.ret_conv
         out_java += (" " + method_name + "(")
-        out_c += (" " + self.c_fn_name_pfx + method_name.replace('_', '_1') + "(" + self.c_fn_args_pfx)
+        have_args = len(argument_types) > 1 or (len(argument_types) > 0 and argument_types[0].c_ty != "void")
+        out_c += (" " + self.c_fn_name_define_pfx(method_name, have_args))
 
         for idx, arg_conv_info in enumerate(argument_types):
             if idx != 0:
                 out_java += (", ")
-            if arg_conv_info.c_ty != "void":
                 out_c += (", ")
             if arg_conv_info.c_ty != "void":
                 out_c += (arg_conv_info.c_ty + " " + arg_conv_info.arg_name)