Skip fewer fn's, support trait-contained objects
[ldk-java] / genbindings.py
index b09d930b4e03992063e579c32f094dc7936c7b3e..79e41fd9008118fd267796f18d8c36ce9e526e5f 100755 (executable)
@@ -767,9 +767,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 arg_conv_info.print_name()
             if arg_conv_info.arg_name == "this_ptr" or arg_conv_info.arg_name == "this_arg":
                 takes_self = True
-            if arg_conv_info.passed_as_ptr and not arg_conv_info.rust_obj in opaque_structs:
-                if not arg_conv_info.rust_obj in trait_structs and not arg_conv_info.rust_obj in unitary_enums:
-                    args_known = False
             if arg_conv_info.arg_conv is not None and "Warning" in arg_conv_info.arg_conv:
                 if arg_conv_info.rust_obj in constructor_fns:
                     assert not is_free
@@ -1073,9 +1070,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_c.write("\tatomic_size_t refcnt;\n")
             out_c.write("\tJavaVM *vm;\n")
             out_c.write("\tjweak o;\n")
+            field_var_convs = []
             for var_line in field_var_lines:
                 if var_line.group(1) in trait_structs:
                     out_c.write("\t" + var_line.group(1) + "_JCalls* " + var_line.group(2) + ";\n")
+                    field_var_convs.append(None)
+                else:
+                    field_var_convs.append(map_type(var_line.group(1) + " " + var_line.group(2), False, None, False, False))
             for fn_line in trait_fn_lines:
                 if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                     out_c.write("\tjmethodID " + fn_line.group(2) + "_meth;\n")
@@ -1086,16 +1087,27 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_java_trait.write("\tfinal bindings." + struct_name + " bindings_instance;\n")
             out_java_trait.write("\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); bindings_instance = null; }\n")
             out_java_trait.write("\tprivate " + struct_name.replace("LDK", "") + "(bindings." + struct_name + " arg")
-            for var_line in field_var_lines:
+            for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     out_java_trait.write(", bindings." + var_line.group(1) + " " + var_line.group(2))
+                else:
+                    out_java_trait.write(", " + field_var_convs[idx].java_hu_ty + " " + var_line.group(2))
             out_java_trait.write(") {\n")
             out_java_trait.write("\t\tsuper(bindings." + struct_name + "_new(arg")
-            for var_line in field_var_lines:
+            for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     out_java_trait.write(", " + var_line.group(2))
+                elif field_var_convs[idx].from_hu_conv is not None:
+                    out_java_trait.write(", " + field_var_convs[idx].from_hu_conv[0])
+                else:
+                    out_java_trait.write(", " + var_line.group(2))
             out_java_trait.write("));\n")
             out_java_trait.write("\t\tthis.ptrs_to.add(arg);\n")
+            for idx, var_line in enumerate(field_var_lines):
+                if var_line.group(1) in trait_structs:
+                    out_java_trait.write("\t\tthis.ptrs_to.add(" + var_line.group(2) + ");\n")
+                elif field_var_convs[idx].from_hu_conv is not None and field_var_convs[idx].from_hu_conv[1] != "":
+                    out_java_trait.write("\t\t" + field_var_convs[idx].from_hu_conv[1] + ";\n")
             out_java_trait.write("\t\tthis.bindings_instance = arg;\n")
             out_java_trait.write("\t}\n")
             out_java_trait.write("\t@Override @SuppressWarnings(\"deprecation\")\n")
@@ -1104,11 +1116,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_java_trait.write("\t}\n\n")
 
             java_trait_constr = "\tpublic " + struct_name.replace("LDK", "") + "(" + struct_name.replace("LDK", "") + "Interface arg"
-            for var_line in field_var_lines:
+            for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     # Ideally we'd be able to take any instance of the interface, but our C code can only represent
                     # Java-implemented version, so we require users pass a Java implementation here :/
                     java_trait_constr = java_trait_constr + ", " + var_line.group(1).replace("LDK", "") + "." + var_line.group(1).replace("LDK", "") + "Interface " + var_line.group(2)
+                else:
+                    java_trait_constr = java_trait_constr + ", " + field_var_convs[idx].java_hu_ty + " " + var_line.group(2)
             java_trait_constr = java_trait_constr + ") {\n\t\tthis(new bindings." + struct_name + "() {\n"
             out_java_trait.write("\tpublic static interface " + struct_name.replace("LDK", "") + "Interface {\n")
             out_java.write("\tpublic interface " + struct_name + " {\n")
@@ -1220,6 +1234,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             for var_line in field_var_lines:
                 if var_line.group(1) in trait_structs:
                     java_trait_constr = java_trait_constr + ", new " + var_line.group(2) + "(" + var_line.group(2) + ").bindings_instance"
+                else:
+                    java_trait_constr = java_trait_constr + ", " + var_line.group(2)
             out_java_trait.write("\t}\n")
             out_java_trait.write(java_trait_constr + ");\n\t}\n")
 
@@ -1237,10 +1253,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
 
             out_java.write("\tpublic static native long " + struct_name + "_new(" + struct_name + " impl")
             out_c.write("static inline " + struct_name + " " + struct_name + "_init (JNIEnv * env, jclass _a, jobject o")
-            for var_line in field_var_lines:
+            for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     out_java.write(", " + var_line.group(1) + " " + var_line.group(2))
                     out_c.write(", jobject " + var_line.group(2))
+                else:
+                    out_java.write(", " + field_var_convs[idx].java_ty + " " + var_line.group(2))
+                    out_c.write(", " + field_var_convs[idx].c_ty + " " + var_line.group(2))
             out_java.write(");\n")
             out_c.write(") {\n")
 
@@ -1254,6 +1273,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                     out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + java_meth_descr + "\");\n")
                     out_c.write("\tCHECK(calls->" + fn_line.group(2) + "_meth != NULL);\n")
+            for idx, var_line in enumerate(field_var_lines):
+                if field_var_convs[idx] is not None and field_var_convs[idx].arg_conv is not None:
+                    out_c.write("\n\t" + field_var_convs[idx].arg_conv.replace("\n", "\n\t") +"\n")
             out_c.write("\n\t" + struct_name + " ret = {\n")
             out_c.write("\t\t.this_arg = (void*) calls,\n")
             for fn_line in trait_fn_lines:
@@ -1264,9 +1286,15 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 else:
                     clone_fns.add(struct_name + "_clone")
                     out_c.write("\t\t.clone = " + struct_name + "_JCalls_clone,\n")
-            for var_line in field_var_lines:
+            for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     out_c.write("\t\t." + var_line.group(2) + " = " + var_line.group(1) + "_init(env, _a, " + var_line.group(2) + "),\n")
+                elif field_var_convs[idx].arg_conv_name is not None:
+                    out_c.write("\t\t." + var_line.group(2) + " = " + field_var_convs[idx].arg_conv_name + ",\n")
+                    out_c.write("\t\t.set_" + var_line.group(2) + " = NULL,\n")
+                else:
+                    out_c.write("\t\t." + var_line.group(2) + " = " + var_line.group(2) + ",\n")
+                    out_c.write("\t\t.set_" + var_line.group(2) + " = NULL,\n")
             out_c.write("\t};\n")
             for var_line in field_var_lines:
                 if var_line.group(1) in trait_structs:
@@ -1275,15 +1303,16 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_c.write("}\n")
 
             out_c.write("JNIEXPORT long JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1new (JNIEnv * env, jclass _a, jobject o")
-            for var_line in field_var_lines:
+            for idx, var_line in enumerate(field_var_lines):
                 if var_line.group(1) in trait_structs:
                     out_c.write(", jobject " + var_line.group(2))
+                else:
+                    out_c.write(", " + field_var_convs[idx].c_ty + " " + var_line.group(2))
             out_c.write(") {\n")
             out_c.write("\t" + struct_name + " *res_ptr = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
             out_c.write("\t*res_ptr = " + struct_name + "_init(env, _a, o")
             for var_line in field_var_lines:
-                if var_line.group(1) in trait_structs:
-                    out_c.write(", " + var_line.group(2))
+                out_c.write(", " + var_line.group(2))
             out_c.write(");\n")
             out_c.write("\treturn (long)res_ptr;\n")
             out_c.write("}\n")
@@ -1301,6 +1330,15 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             if fn_line.group(2) != "free" and fn_line.group(2) != "clone" and fn_line.group(2) != "eq" and not is_log:
                 dummy_line = fn_line.group(1) + struct_name.replace("LDK", "") + "_" + fn_line.group(2) + " " + struct_name + "* this_arg" + fn_line.group(4) + "\n"
                 map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(this_arg_conv->" + fn_line.group(2) + ")(this_arg_conv->this_arg")
+        for idx, var_line in enumerate(field_var_lines):
+            if var_line.group(1) not in trait_structs:
+                out_c.write(var_line.group(1) + " " + struct_name + "_set_get_" + var_line.group(2) + "(" + struct_name + "* this_arg) {\n")
+                out_c.write("\tif (this_arg->set_" + var_line.group(2) + " != NULL)\n")
+                out_c.write("\t\tthis_arg->set_" + var_line.group(2) + "(this_arg);\n")
+                out_c.write("\treturn this_arg->" + var_line.group(2) + ";\n")
+                out_c.write("}\n")
+                dummy_line = var_line.group(1) + " " + struct_name.replace("LDK", "") + "_get_" + var_line.group(2) + " " + struct_name + "* this_arg" + fn_line.group(4) + "\n"
+                map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, struct_name + "_set_get_" + var_line.group(2) + "(this_arg_conv")
 
     out_c.write("""#include \"org_ldk_impl_bindings.h\"
 #include <rust_types.h>
@@ -1369,7 +1407,12 @@ static void alloc_freed(void* ptr) {
        while (it->ptr != ptr) {
                p = it; it = it->next;
                if (it == NULL) {
-                       fprintf(stderr, "Tried to free unknown pointer %p!\\n", ptr);
+                       fprintf(stderr, "Tried to free unknown pointer %p at:\\n", ptr);
+                       void* bt[BT_MAX];
+                       int bt_len = backtrace(bt, BT_MAX);
+                       backtrace_symbols_fd(bt, bt_len, STDERR_FILENO);
+                       fprintf(stderr, "\\n\\n");
+                       DO_ASSERT(mtx_unlock(&allocation_mtx) == thrd_success);
                        return; // addrsan should catch malloc-unknown and print more info than we have
                }
        }
@@ -1557,6 +1600,7 @@ class CommonBase {
     assert(line_indicates_trait_regex.match("   void *(*clone)(const void *this_arg);"))
     line_field_var_regex = re.compile("^   ([A-Za-z_0-9]*) ([A-Za-z_0-9]*);$")
     assert(line_field_var_regex.match("   LDKMessageSendEventsProvider MessageSendEventsProvider;"))
+    assert(line_field_var_regex.match("   LDKChannelPublicKeys pubkeys;"))
     struct_name_regex = re.compile("^typedef (struct|enum|union) (MUST_USE_STRUCT )?(LDK[A-Za-z_0-9]*) {$")
     assert(struct_name_regex.match("typedef struct LDKCVecTempl_u8 {"))
     assert(struct_name_regex.match("typedef enum LDKNetwork {"))