Handle LDKStr in other structs and fix string conversion overread
[ldk-java] / java_strings.py
index 14c6a9f61fb8c6387f1f688a9ea79faf75b12b72..43265f449c2359ebc14d7bcbf803a0f43a0e2d0a 100644 (file)
@@ -60,6 +60,7 @@ public class bindings {
         self.util_fn_pfx = """package org.ldk.structs;
 import org.ldk.impl.bindings;
 import java.util.Arrays;
+import org.ldk.enums.*;
 
 public class UtilMethods {
 """
@@ -308,12 +309,26 @@ typedef jbyteArray int8_tArray;
 
 static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
        // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
-       char* err_buf = MALLOC(len + 1, "str conv buf");
-       memcpy(err_buf, chars, len);
-       err_buf[len] = 0;
-       jstring err_conv = (*env)->NewStringUTF(env, chars);
-       FREE(err_buf);
-       return err_conv;
+       char* conv_buf = MALLOC(len + 1, "str conv buf");
+       memcpy(conv_buf, chars, len);
+       conv_buf[len] = 0;
+       jstring ret = (*env)->NewStringUTF(env, conv_buf);
+       FREE(conv_buf);
+       return ret;
+}
+static inline LDKStr java_to_owned_str(JNIEnv *env, jstring str) {
+       uint64_t str_len = (*env)->GetStringUTFLength(env, str);
+       char* newchars = MALLOC(str_len + 1, "String chars");
+       const char* jchars = (*env)->GetStringUTFChars(env, str, NULL);
+       memcpy(newchars, jchars, str_len);
+       newchars[str_len] = 0;
+       (*env)->ReleaseStringUTFChars(env, str, jchars);
+       LDKStr res = {
+               .chars = newchars,
+               .len = str_len,
+               .chars_is_owned = true
+       };
+       return res;
 }
 """
 
@@ -381,8 +396,10 @@ import java.util.Arrays;
         else:
             return "(*env)->Release" + ty_info.java_ty.strip("[]").title() + "ArrayElements(env, " + arr_name + ", " + dest_name + ", 0)"
 
-    def str_ref_to_c_call(self, var_name, str_len):
+    def str_ref_to_native_call(self, var_name, str_len):
         return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
+    def str_ref_to_c_call(self, var_name):
+        return "java_to_owned_str(env, " + var_name + ")"
 
     def c_fn_name_define_pfx(self, fn_name, has_args):
         if has_args:
@@ -479,7 +496,7 @@ import java.util.Arrays;
             ret = ret + ", " + param
         return ret + ")"
 
-    def native_c_map_trait(self, struct_name, field_vars, field_fns, trait_doc_comment):
+    def native_c_map_trait(self, struct_name, field_vars, flattened_field_vars, field_fns, trait_doc_comment):
         out_java_trait = ""
         out_java = ""
 
@@ -492,14 +509,14 @@ import java.util.Arrays;
         out_java_trait = out_java_trait + "\tfinal bindings." + struct_name + " bindings_instance;\n"
         out_java_trait = out_java_trait + "\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); bindings_instance = null; }\n"
         out_java_trait = out_java_trait + "\tprivate " + struct_name.replace("LDK", "") + "(bindings." + struct_name + " arg"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 out_java_trait = out_java_trait + ", " + var.java_hu_ty + " " + var.arg_name
             else:
                 out_java_trait = out_java_trait + ", bindings." + var[0] + " " + var[1]
         out_java_trait = out_java_trait + ") {\n"
         out_java_trait = out_java_trait + "\t\tsuper(bindings." + struct_name + "_new(arg"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 if var.from_hu_conv is not None:
                     out_java_trait = out_java_trait + ", " + var.from_hu_conv[0]
@@ -509,7 +526,7 @@ import java.util.Arrays;
                 out_java_trait = out_java_trait + ", " + var[1]
         out_java_trait = out_java_trait + "));\n"
         out_java_trait = out_java_trait + "\t\tthis.ptrs_to.add(arg);\n"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 if var.from_hu_conv is not None and var.from_hu_conv[1] != "":
                     out_java_trait = out_java_trait + "\t\t" + var.from_hu_conv[1].replace("\n", "\n\t\t") + ";\n"
@@ -524,7 +541,7 @@ import java.util.Arrays;
 
         java_trait_constr = "\tprivate static class " + struct_name + "Holder { " + struct_name.replace("LDK", "") + " held; }\n"
         java_trait_constr = java_trait_constr + "\tpublic static " + struct_name.replace("LDK", "") + " new_impl(" + struct_name.replace("LDK", "") + "Interface arg"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 java_trait_constr = java_trait_constr + ", " + var.java_hu_ty + " " + var.arg_name
             else:
@@ -593,14 +610,25 @@ import java.util.Arrays;
             if isinstance(var, ConvInfo):
                 java_trait_constr = java_trait_constr + ", " + var.arg_name
             else:
-                java_trait_constr = java_trait_constr + ", " + var[1] + ".new_impl(" + var[1] + "_impl).bindings_instance"
+                java_trait_constr += ", " + var[1] + ".new_impl(" + var[1] + "_impl"
+                for suparg in var[2]:
+                    if isinstance(suparg, ConvInfo):
+                        java_trait_constr += ", " + suparg.arg_name
+                    else:
+                        java_trait_constr += ", " + suparg[1]
+                java_trait_constr += ").bindings_instance"
+                for suparg in var[2]:
+                    if isinstance(suparg, ConvInfo):
+                        java_trait_constr += ", " + suparg.arg_name
+                    else:
+                        java_trait_constr += ", " + suparg[1]
         out_java_trait = out_java_trait + "\t}\n"
         out_java_trait = out_java_trait + java_trait_constr + ");\n\t\treturn impl_holder.held;\n\t}\n"
 
         out_java = out_java + "\t}\n"
 
         out_java = out_java + "\tpublic static native long " + struct_name + "_new(" + struct_name + " impl"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 out_java = out_java + ", " + var.java_ty + " " + var.arg_name
             else:
@@ -612,7 +640,7 @@ import java.util.Arrays;
         out_c = out_c + "\tatomic_size_t refcnt;\n"
         out_c = out_c + "\tJavaVM *vm;\n"
         out_c = out_c + "\tjweak o;\n"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 # We're a regular ol' field
                 pass
@@ -688,7 +716,7 @@ import java.util.Arrays;
         out_c = out_c + "}\n"
 
         out_c = out_c + "static inline " + struct_name + " " + struct_name + "_init (" + self.c_fn_args_pfx + ", jobject o"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 out_c = out_c + ", " + var.c_ty + " " + var.arg_name
             else:
@@ -707,7 +735,7 @@ import java.util.Arrays;
                 out_c = out_c + "\tcalls->" + fn_name + "_meth = (*env)->GetMethodID(env, c, \"" + fn_name + "\", \"" + java_meth_descr + "\");\n"
                 out_c = out_c + "\tCHECK(calls->" + fn_name + "_meth != NULL);\n"
 
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo) and var.arg_conv is not None:
                 out_c = out_c + "\n\t" + var.arg_conv.replace("\n", "\n\t") +"\n"
         out_c = out_c + "\n\t" + struct_name + " ret = {\n"
@@ -728,16 +756,22 @@ import java.util.Arrays;
                     out_c = out_c + "\t\t." + var.var_name + " = " + var.var_name + ",\n"
                     out_c = out_c + "\t\t.set_" + var.var_name + " = NULL,\n"
             else:
-                out_c = out_c + "\t\t." + var[1] + " = " + var[0] + "_init(env, clz, " + var[1] + "),\n"
+                out_c += "\t\t." + var[1] + " = " + var[0] + "_init(env, clz, " + var[1]
+                for suparg in var[2]:
+                    if isinstance(suparg, ConvInfo):
+                        out_c = out_c + ", " + suparg.arg_name
+                    else:
+                        out_c = out_c + ", " + suparg[1]
+                out_c += "),\n"
         out_c = out_c + "\t};\n"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if not isinstance(var, ConvInfo):
                 out_c = out_c + "\tcalls->" + var[1] + " = ret." + var[1] + ".this_arg;\n"
         out_c = out_c + "\treturn ret;\n"
         out_c = out_c + "}\n"
 
         out_c = out_c + self.c_fn_ty_pfx + "int64_t " + self.c_fn_name_define_pfx(struct_name + "_new", True) + "jobject o"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 out_c = out_c + ", " + var.c_ty + " " + var.arg_name
             else:
@@ -745,7 +779,7 @@ import java.util.Arrays;
         out_c = out_c + ") {\n"
         out_c = out_c + "\t" + struct_name + " *res_ptr = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n"
         out_c = out_c + "\t*res_ptr = " + struct_name + "_init(env, clz, o"
-        for var in field_vars:
+        for var in flattened_field_vars:
             if isinstance(var, ConvInfo):
                 out_c = out_c + ", " + var.arg_name
             else:
@@ -828,7 +862,7 @@ import java.util.Arrays;
         out_c += (self.c_complex_enum_pfx(struct_name, [x.var_name for x in variant_list], init_meth_jty_strs))
 
         out_c += (self.c_fn_ty_pfx + self.c_complex_enum_pass_ty(struct_name) + " " + self.c_fn_name_define_pfx(struct_name + "_ref_from_ptr", True) + self.ptr_c_ty + " ptr) {\n")
-        out_c += ("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
+        out_c += ("\t" + struct_name + " *obj = (" + struct_name + "*)(ptr & ~1);\n")
         out_c += ("\tswitch(obj->tag) {\n")
         for var in variant_list:
             out_c += ("\t\tcase " + struct_name + "_" + var.var_name + ": {\n")
@@ -836,11 +870,17 @@ import java.util.Arrays;
             for idx, field_map in enumerate(var.fields):
                 if field_map.ret_conv is not None:
                     out_c += ("\t\t\t" + field_map.ret_conv[0].replace("\n", "\n\t\t\t"))
-                    out_c += ("obj->" + camel_to_snake(var.var_name) + "." + field_map.arg_name)
+                    if var.tuple_variant:
+                        out_c += "obj->" + camel_to_snake(var.var_name)
+                    else:
+                        out_c += "obj->" + camel_to_snake(var.var_name) + "." + field_map.arg_name
                     out_c += (field_map.ret_conv[1].replace("\n", "\n\t\t\t") + "\n")
                     c_params.append(field_map.ret_conv_name)
                 else:
-                    c_params.append("obj->" + camel_to_snake(var.var_name) + "." + field_map.arg_name)
+                    if var.tuple_variant:
+                        c_params.append("obj->" + camel_to_snake(var.var_name))
+                    else:
+                        c_params.append("obj->" + camel_to_snake(var.var_name) + "." + field_map.arg_name)
             out_c += ("\t\t\treturn " + self.c_constr_native_complex_enum(struct_name, var.var_name, c_params) + ";\n")
             out_c += ("\t\t}\n")
         out_c += ("\t\tdefault: abort();\n")