Rewrite clone handling to manually-define them for types, though almost unused
authorMatt Corallo <git@bluematt.me>
Fri, 8 Jan 2021 05:22:12 +0000 (00:22 -0500)
committerMatt Corallo <git@bluematt.me>
Fri, 8 Jan 2021 05:25:17 +0000 (00:25 -0500)
genbindings.py

index 87ada732e6ddc0fd66dcc69ca4a5f525f1620f6e..fef5ef93c2e4b0b6ec8f205513e6e2ca2c7cab06 100755 (executable)
@@ -175,6 +175,7 @@ class ConvInfo:
         self.java_ty = ty_info.java_ty
         self.java_hu_ty = ty_info.java_hu_ty
         self.java_fn_ty_arg = ty_info.java_fn_ty_arg
+        self.is_native_primitive = ty_info.is_native_primitive
         self.arg_name = arg_name
         self.arg_conv = arg_conv
         self.arg_conv_name = arg_conv_name
@@ -500,6 +501,11 @@ with open(sys.argv[1]) as in_h:
                 clone_fns.add(arr_fn.group(2))
             # No object constructors return arrays, as then they wouldn't be an object constructor
             continue
+
+# Define some manual clones...
+clone_fns.add("ThirtyTwoBytes_clone")
+write_c("static inline struct LDKThirtyTwoBytes ThirtyTwoBytes_clone(const struct LDKThirtyTwoBytes *orig) { struct LDKThirtyTwoBytes ret; memcpy(ret.data, orig->data, 32); return ret; }\n")
+
 java_c_types_none_allowed = False # C structs created by cbindgen are declared in dependency order
 
 with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
@@ -687,17 +693,17 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 else:
                     opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = (" + ty_info.var_name + " & 1) || (" + ty_info.var_name + " == 0);"
                 if not is_free and (not ty_info.is_ptr or not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False:
-                    if (ty_info.java_hu_ty + "_clone") in clone_fns:
+                    if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns:
                         # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
                         opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
-                        opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.java_hu_ty + "_clone(&" + ty_info.var_name + "_conv);"
+                        opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
                     elif ty_info.passed_as_ptr:
                         opaque_arg_conv = opaque_arg_conv + "\n// Warning: we may need a move here but can't clone!"
 
                 opaque_ret_conv_suf = ";\n"
-                if not holds_ref and ty_info.is_ptr and (ty_info.java_hu_ty + "_clone") in clone_fns: # is_ptr, not holds_ref implies passing a pointed-to value to java, which needs copied
+                if not holds_ref and ty_info.is_ptr and (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns: # is_ptr, not holds_ref implies passing a pointed-to value to java, which needs copied
                     opaque_ret_conv_suf = opaque_ret_conv_suf + "if (" + ty_info.var_name + "->inner != NULL)\n"
-                    opaque_ret_conv_suf = opaque_ret_conv_suf + "\t" + ty_info.var_name + "_var = " + ty_info.java_hu_ty + "_clone(" + ty_info.var_name + ");\n"
+                    opaque_ret_conv_suf = opaque_ret_conv_suf + "\t" + ty_info.var_name + "_var = " + ty_info.rust_obj.replace("LDK", "") + "_clone(" + ty_info.var_name + ");\n"
                 elif not holds_ref and ty_info.is_ptr:
                     opaque_ret_conv_suf = opaque_ret_conv_suf + "// Warning: we may need a move here but can't clone!\n"
 
@@ -740,8 +746,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 if ty_info.rust_obj in trait_structs:
                     if not is_free:
                         needs_full_clone = not is_free and (not ty_info.is_ptr and not holds_ref or ty_info.requires_clone == True) and ty_info.requires_clone != False
-                        if needs_full_clone and (ty_info.java_hu_ty + "_clone") in clone_fns:
-                            base_conv = base_conv + "\n" + ty_info.var_name + "_conv = " + ty_info.java_hu_ty + "_clone(" + ty_info.var_name + ");"
+                        if needs_full_clone and (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns:
+                            base_conv = base_conv + "\n" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(" + ty_info.var_name + ");"
                         else:
                             base_conv = base_conv + "\nif (" + ty_info.var_name + "_conv.free == " + ty_info.rust_obj + "_JCalls_free) {\n"
                             base_conv = base_conv + "\t// If this_arg is a JCalls struct, then we need to increment the refcnt in it.\n"
@@ -767,8 +773,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     if not holds_ref:
                         ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_copy = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n", "")
                         if ty_info.requires_clone == True: # Set in object array mapping
-                            if (ty_info.java_hu_ty + "_clone") in clone_fns:
-                                ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = " + ty_info.java_hu_ty + "_clone(&", ");\n")
+                            if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns:
+                                ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&", ");\n")
                             else:
                                 ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = ", "; // XXX: We likely need to clone here, but no _clone fn is available!\n")
                         else:
@@ -780,9 +786,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         to_hu_conv = ty_info.java_hu_ty + " " + ty_info.var_name + "_hu_conv = " + ty_info.java_hu_ty + ".constr_from_ptr(" + ty_info.var_name + ");\n" + ty_info.var_name + "_hu_conv.ptrs_to.add(this);",
                         to_hu_conv_name = ty_info.var_name + "_hu_conv", from_hu_conv = (ty_info.var_name + ".ptr", ""))
                 if ty_info.rust_obj in result_types:
-                    #assert not ty_info.is_ptr and not holds_ref # Otherwise we shouldn't be MALLOC'ing
-                    # XXX ^
-                    ret_conv = (ty_info.rust_obj + "* " + ty_info.var_name + "_conv = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*" + ty_info.var_name + "_conv = ", ";")
+                    if holds_ref:
+                        # If we're trying to return a ref, we have to clone.
+                        # We just blindly assume its implemented and let the compiler fail if its not.
+                        ret_conv = (ty_info.rust_obj + "* " + ty_info.var_name + "_conv = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*" + ty_info.var_name + "_conv = ", ";")
+                        ret_conv = (ret_conv[0], ret_conv[1] + "\n*" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(" + ty_info.var_name + "_conv);")
+                    else:
+                        ret_conv = (ty_info.rust_obj + "* " + ty_info.var_name + "_conv = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*" + ty_info.var_name + "_conv = ", ";")
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                         arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                         ret_conv = ret_conv, ret_conv_name = "(long)" + ty_info.var_name + "_conv",
@@ -845,12 +855,12 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                         to_hu_conv_name = ty_info.var_name + "_hu_conv",
                         from_hu_conv = (ty_info.var_name + " == null ? 0 : " + ty_info.var_name + ".ptr & ~1", "this.ptrs_to.add(" + ty_info.var_name + ")"))
                 elif ty_info.rust_obj in trait_structs:
-                    if ty_info.java_hu_ty + "_clone" in clone_fns:
+                    if ty_info.rust_obj.replace("LDK", "") + "_clone" in clone_fns:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                             arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
                             arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                             ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_clone = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n" +
-                                "*" + ty_info.var_name + "_clone = " + ty_info.java_hu_ty + "_clone(", ");"),
+                                "*" + ty_info.var_name + "_clone = " + ty_info.rust_obj.replace("LDK", "") + "_clone(", ");"),
                             ret_conv_name = "(long)" + ty_info.var_name + "_clone",
                             to_hu_conv = ty_info.java_hu_ty + " ret_hu_conv = new " + ty_info.java_hu_ty + "(null, " + ty_info.var_name + ");\nret_hu_conv.ptrs_to.add(this);",
                             to_hu_conv_name = "ret_hu_conv",
@@ -1496,6 +1506,11 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
 
             res_map = map_type(res_ty + " res", True, None, False, True)
             err_map = map_type(err_ty + " err", True, None, False, True)
+            can_clone = True
+            if not res_map.is_native_primitive and (res_map.rust_obj.replace("LDK", "") + "_clone" not in clone_fns):
+                can_clone = False
+            if not err_map.is_native_primitive and (err_map.rust_obj.replace("LDK", "") + "_clone" not in clone_fns):
+                can_clone = False
 
             out_java.write("\tpublic static native boolean " + struct_name + "_result_ok(long arg);\n")
             write_c("JNIEXPORT jboolean JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1result_1ok (JNIEnv * env, jclass _a, jlong arg) {\n")
@@ -1577,6 +1592,34 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     out_java_struct.write("\t\t\tthis(null, bindings.C" + human_ty + "_err(err));\n")
             out_java_struct.write("\t\t}\n\t}\n}\n")
 
+            if can_clone:
+                clone_fns.add(struct_name.replace("LDK", "") + "_clone")
+                write_c("static inline " + struct_name + " " + struct_name.replace("LDK", "") + "_clone(const " + struct_name + " *orig) {\n")
+                write_c("\t" + struct_name + " res = { .result_ok = orig->result_ok };\n")
+                write_c("\tif (orig->result_ok) {\n")
+                if res_map.c_ty == "void":
+                    write_c("\t\tres.contents.result = NULL;\n")
+                else:
+                    if res_map.is_native_primitive:
+                        write_c("\t\t" + res_map.c_ty + "* contents = MALLOC(sizeof(" + res_map.c_ty + "), \"" + res_map.c_ty + " result OK clone\");\n")
+                        write_c("\t\t*contents = *orig->contents.result;\n")
+                    else:
+                        write_c("\t\t" + res_map.rust_obj + "* contents = MALLOC(sizeof(" + res_map.rust_obj + "), \"" + res_map.rust_obj + " result OK clone\");\n")
+                        write_c("\t\t*contents = " + res_map.rust_obj.replace("LDK", "") + "_clone(orig->contents.result);\n")
+                    write_c("\t\tres.contents.result = contents;\n")
+                write_c("\t} else {\n")
+                if err_map.c_ty == "void":
+                    write_c("\t\tres.contents.err = NULL;\n")
+                else:
+                    if err_map.is_native_primitive:
+                        write_c("\t\t" + err_map.c_ty + "* contents = MALLOC(sizeof(" + err_map.c_ty + "), \"" + err_map.c_ty + " result Err clone\");\n")
+                        write_c("\t\t*contents = *orig->contents.err;\n")
+                    else:
+                        write_c("\t\t" + err_map.rust_obj + "* contents = MALLOC(sizeof(" + err_map.rust_obj + "), \"" + err_map.rust_obj + " result Err clone\");\n")
+                        write_c("\t\t*contents = " + err_map.rust_obj.replace("LDK", "") + "_clone(orig->contents.err);\n")
+                    write_c("\t\tres.contents.err = contents;\n")
+                write_c("\t}\n\treturn res;\n}\n")
+
     def map_tuple(struct_name, field_lines):
         out_java.write("\tpublic static native long " + struct_name + "_new(")
         write_c("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1new(JNIEnv *_env, jclass _b")
@@ -1594,6 +1637,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
         out_java.write(");\n")
         write_c(") {\n")
         write_c("\t" + struct_name + "* ret = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
+        can_clone = True
+        clone_str = "static inline " + struct_name + " " + struct_name.replace("LDK", "") + "_clone(const " + struct_name + " *orig) {\n"
+        clone_str = clone_str + "\t" + struct_name + " ret = {\n"
         for idx, line in enumerate(field_lines):
             if idx != 0 and idx < len(field_lines) - 2:
                 ty_info = map_type(line.strip(';'), False, None, False, False)
@@ -1605,9 +1651,20 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                     write_c("\tret->" + e + " = " + e + ";\n")
                 if ty_info.arg_conv_cleanup is not None:
                     write_c("\t//TODO: Really need to call " + ty_info.arg_conv_cleanup + " here\n")
+                if not ty_info.is_native_primitive and (ty_info.rust_obj.replace("LDK", "") + "_clone") not in clone_fns:
+                    can_clone = False
+                elif can_clone and ty_info.is_native_primitive:
+                    clone_str = clone_str + "\t\t." + chr(ord('a') + idx - 1) + " = orig->" + chr(ord('a') + idx - 1) + ",\n"
+                elif can_clone:
+                    clone_str = clone_str + "\t\t." + chr(ord('a') + idx - 1) + " = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&orig->" + chr(ord('a') + idx - 1) + "),\n"
         write_c("\treturn (long)ret;\n")
         write_c("}\n")
 
+        if can_clone:
+            clone_fns.add(struct_name.replace("LDK", "") + "_clone")
+            write_c(clone_str)
+            write_c("\t};\n\treturn ret;\n}\n")
+
         for idx, ty_info in enumerate(ty_list):
             e = chr(ord('a') + idx)
             out_java.write("\tpublic static native " + ty_info.java_ty + " " + struct_name + "_get_" + e + "(long ptr);\n")
@@ -1916,6 +1973,21 @@ class CommonBase {
                         write_c("\t}\n")
                         write_c("\treturn (long)ret;\n")
                         write_c("}\n")
+
+                    if ty_info.is_native_primitive:
+                        clone_fns.add(struct_name.replace("LDK", "") + "_clone")
+                        write_c("static inline " + struct_name + " " + struct_name.replace("LDK", "") + "_clone(const " + struct_name + " *orig) {\n")
+                        write_c("\t" + struct_name + " ret = { .data = MALLOC(sizeof(" + ty_info.c_ty + ") * orig->datalen, \"" + struct_name + " clone bytes\"), .datalen = orig->datalen };\n")
+                        write_c("\tmemcpy(ret.data, orig->data, sizeof(" + ty_info.c_ty + ") * ret.datalen);\n")
+                        write_c("\treturn ret;\n}\n")
+                    elif (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns:
+                        ty_name = "CVec_" + ty_info.rust_obj.replace("LDK", "") + "Z";
+                        clone_fns.add(ty_name + "_clone")
+                        write_c("static inline " + struct_name + " " + ty_name + "_clone(const " + struct_name + " *orig) {\n")
+                        write_c("\t" + struct_name + " ret = { .data = MALLOC(sizeof(" + ty_info.rust_obj + ") * orig->datalen, \"" + struct_name + " clone bytes\"), .datalen = orig->datalen };\n")
+                        write_c("\tfor (size_t i = 0; i < ret.datalen; i++) {\n")
+                        write_c("\t\tret.data[i] = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&orig->data[i]);\n")
+                        write_c("\t}\n\treturn ret;\n}\n")
                 elif is_union_enum:
                     assert(struct_name.endswith("_Tag"))
                     struct_name = struct_name[:-4]