Update to support None
[ldk-java] / genbindings.py
index 5ba69e6d09082e262f02ad5889b4fc96e0f234c9..d2cf50db00938836c97e5964ada92e317446e484 100755 (executable)
@@ -19,7 +19,7 @@ class TypeInfo:
         self.arr_access = arr_access
 
 class ConvInfo:
-    def __init__(self, ty_info, arg_name, arg_conv, arg_conv_name, ret_conv, ret_conv_name):
+    def __init__(self, ty_info, arg_name, arg_conv, arg_conv_name, arg_conv_cleanup, ret_conv, ret_conv_name):
         assert(ty_info.c_ty is not None)
         assert(ty_info.java_ty is not None)
         assert(arg_name is not None)
@@ -31,6 +31,7 @@ class ConvInfo:
         self.arg_name = arg_name
         self.arg_conv = arg_conv
         self.arg_conv_name = arg_conv_name
+        self.arg_conv_cleanup = arg_conv_cleanup
         self.ret_conv = ret_conv
         self.ret_conv_name = ret_conv_name
 
@@ -93,7 +94,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 lastund = True
         return (ret + lastchar.lower()).strip("_")
 
-    var_is_arr_regex = re.compile("\(\*([A-za-z0-9_]*)\)\[([0-9]*)\]")
+    var_is_arr_regex = re.compile("\(\*([A-za-z0-9_]*)\)\[([a-z0-9]*)\]")
     var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)")
     def java_c_types(fn_arg, ret_arr_len):
         fn_arg = fn_arg.strip()
@@ -118,10 +119,26 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             assert var_is_arr_regex.match(fn_arg[8:])
             rust_obj = "LDKPublicKey"
             arr_access = "compressed_form"
-        #if fn_arg.startswith("LDKSignature"):
-        #    fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
-        #    assert var_is_arr_regex.match(fn_arg[8:])
-        #    rust_obj = "LDKSignature"
+        if fn_arg.startswith("LDKSecretKey"):
+            fn_arg = "uint8_t (*" + fn_arg[13:] + ")[32]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKSecretKey"
+            arr_access = "bytes"
+        if fn_arg.startswith("LDKSignature"):
+            fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKSignature"
+            arr_access = "compact_form"
+        if fn_arg.startswith("LDKThreeBytes"):
+            fn_arg = "uint8_t (*" + fn_arg[14:] + ")[3]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKThreeBytes"
+            arr_access = "data"
+        if fn_arg.startswith("LDKu8slice"):
+            fn_arg = "uint8_t (*" + fn_arg[11:] + ")[datalen]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKu8slice"
+            arr_access = "data"
 
         if fn_arg.startswith("void"):
             java_ty = "void"
@@ -206,7 +223,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         if ty_info.c_ty == "void":
             if not print_void:
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                    arg_conv = None, arg_conv_name = None, ret_conv = None, ret_conv_name = None)
+                    arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None,
+                    ret_conv = None, ret_conv_name = None)
 
         if ty_info.c_ty.endswith("Array"):
             arr_len = ty_info.arr_len
@@ -216,19 +234,30 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 arr_name = "ret"
                 arr_len = ret_arr_len
             assert(ty_info.c_ty == "jbyteArray")
-            if ty_info.rust_obj is not None:
-                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
-                arr_access = ("", "." + ty_info.arr_access)
+            ret_conv = ("jbyteArray " + arr_name + "_arr = (*_env)->NewByteArray(_env, " + arr_len + ");\n" + "(*_env)->SetByteArrayRegion(_env, " + arr_name + "_arr, 0, " + arr_len + ", ", "")
+            arg_conv_cleanup = None
+            if not arr_len.isdigit():
+                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
+                arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = (*_env)->GetByteArrayElements (_env, " + arr_name + ", NULL);\n"
+                arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = (*_env)->GetArrayLength (_env, " + arr_name + ");"
+                arg_conv_cleanup = "(*_env)->ReleaseByteArrayElements(_env, " + arr_name + ", (int8_t*)" + arr_name + "_ref." + ty_info.arr_access + ", 0);"
+                arr_access = "." + ty_info.arr_access
+                ret_conv = (ty_info.rust_obj + " " + arr_name + "_var = ", "")
+                ret_conv = (ret_conv[0], ";\njbyteArray " + arr_name + "_arr = (*_env)->NewByteArray(_env, " + arr_name + "_var." + arr_len + ");\n")
+                ret_conv = (ret_conv[0], ret_conv[1] + "(*_env)->SetByteArrayRegion(_env, " + arr_name + "_arr, 0, " + arr_name + "_var." + arr_len + ", " + arr_name + "_var." + ty_info.arr_access + ");")
+            elif ty_info.rust_obj is not None:
+                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
+                arg_conv = arg_conv + "CHECK((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
+                arg_conv = arg_conv + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
+                ret_conv = (ret_conv[0], "." + ty_info.arr_access + ");")
             else:
-                arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
-                arr_access = ("*", "")
+                arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n"
+                arg_conv = arg_conv + "CHECK((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
+                arg_conv = arg_conv + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
+                ret_conv = (ret_conv[0] + "*", ");")
             return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                arg_conv = arg_conv,
-                arg_conv_name = arr_name + "_ref",
-                ret_conv = ("jbyteArray " + arr_name + "_arr = (*_env)->NewByteArray(_env, " + arr_len + ");\n" +
-                    "(*_env)->SetByteArrayRegion(_env, " + arr_name + "_arr, 0, " + arr_len + ", " + arr_access[0],
-                    arr_access[1] + ");"),
-                ret_conv_name = arr_name + "_arr")
+                arg_conv = arg_conv, arg_conv_name = arr_name + "_ref", arg_conv_cleanup = arg_conv_cleanup,
+                ret_conv = ret_conv, ret_conv_name = arr_name + "_arr")
         elif ty_info.var_name != "":
             # If we have a parameter name, print it (noting that it may indicate its a pointer)
             if ty_info.rust_obj is not None:
@@ -236,20 +265,24 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n"
                 opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = (void*)(" + ty_info.var_name + " & (~1));\n"
                 opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = (" + ty_info.var_name + " & 1) || (" + ty_info.var_name + " == 0);"
-                if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns and not ty_info.is_ptr and not is_free:
-                    # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
-                    opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
-                    opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
+                if not ty_info.is_ptr and not is_free:
+                    if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns:
+                        # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
+                        opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
+                        opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
+                    elif ty_info.passed_as_ptr:
+                        opaque_arg_conv = opaque_arg_conv + "\n// Warning: we may need a move here but can't clone!"
                 if not ty_info.is_ptr:
                     if ty_info.rust_obj in unitary_enums:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                             arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = " + ty_info.rust_obj + "_from_java(_env, " + ty_info.var_name + ");",
                             arg_conv_name = ty_info.var_name + "_conv",
+                            arg_conv_cleanup = None,
                             ret_conv = ("jclass " + ty_info.var_name + "_conv = " + ty_info.rust_obj + "_to_java(_env, ", ");"),
                             ret_conv_name = ty_info.var_name + "_conv")
                     if ty_info.rust_obj in opaque_structs:
-                        ret_conv_suf = ";\nDO_ASSERT((((long)" + ty_info.var_name + "_var.inner) & 1) == 0); // We rely on a free low bit, malloc guarantees this.\n"
-                        ret_conv_suf = ret_conv_suf + "DO_ASSERT((((long)&" + ty_info.var_name + "_var) & 1) == 0); // We rely on a free low bit, pointer alignment guarantees this.\n"
+                        ret_conv_suf = ";\nCHECK((((long)" + ty_info.var_name + "_var.inner) & 1) == 0); // We rely on a free low bit, malloc guarantees this.\n"
+                        ret_conv_suf = ret_conv_suf + "CHECK((((long)&" + ty_info.var_name + "_var) & 1) == 0); // We rely on a free low bit, pointer alignment guarantees this.\n"
                         ret_conv_suf = ret_conv_suf + "long " + ty_info.var_name + "_ref;\n"
                         ret_conv_suf = ret_conv_suf + "if (" + ty_info.var_name + "_var.is_owned) {\n"
                         ret_conv_suf = ret_conv_suf + "\t" + ty_info.var_name + "_ref = (long)" + ty_info.var_name + "_var.inner | 1;\n"
@@ -257,7 +290,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         ret_conv_suf = ret_conv_suf + "\t" + ty_info.var_name + "_ref = (long)&" + ty_info.var_name + "_var;\n"
                         ret_conv_suf = ret_conv_suf + "}"
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                            arg_conv = opaque_arg_conv, arg_conv_name = ty_info.var_name + "_conv",
+                            arg_conv = opaque_arg_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                             ret_conv = (ty_info.rust_obj + " " + ty_info.var_name + "_var = ", ret_conv_suf),
                             ret_conv_name = ty_info.var_name + "_ref")
                     base_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv = *(" + ty_info.rust_obj + "*)" + ty_info.var_name + ";";
@@ -269,48 +302,47 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         else:
                             base_conv = base_conv + "\n" + "FREE((void*)" + ty_info.var_name + ");"
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                            arg_conv = base_conv,
-                            arg_conv_name = ty_info.var_name + "_conv",
+                            arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                             ret_conv = ("CANT PASS TRAIT TO Java?", ""), ret_conv_name = "NO CONV POSSIBLE")
                     if ty_info.rust_obj != "LDKu8slice":
                         # Don't bother free'ing slices passed in - Rust doesn't auto-free the
                         # underlying unlike Vecs, and it gives Java more freedom.
                         base_conv = base_conv + "\nFREE((void*)" + ty_info.var_name + ");";
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                        arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv",
+                        arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                         ret_conv = ("long " + ty_info.var_name + "_ref = (long)&", ";"), ret_conv_name = ty_info.var_name + "_ref")
                 else:
                     assert(not is_free)
                     if ty_info.rust_obj in opaque_structs:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                            arg_conv = opaque_arg_conv, arg_conv_name = "&" + ty_info.var_name + "_conv",
+                            arg_conv = opaque_arg_conv, arg_conv_name = "&" + ty_info.var_name + "_conv", arg_conv_cleanup = None,
                             ret_conv = None, ret_conv_name = None) # its a pointer, no conv needed
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                         arg_conv = ty_info.rust_obj + "* " + ty_info.var_name + "_conv = (" + ty_info.rust_obj + "*)" + ty_info.var_name + ";",
-                        arg_conv_name = ty_info.var_name + "_conv",
+                        arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                         ret_conv = None, ret_conv_name = None) # its a pointer, no conv needed
             elif ty_info.is_ptr:
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                    arg_conv = None, arg_conv_name = ty_info.var_name, ret_conv = None, ret_conv_name = None)
+                    arg_conv = None, arg_conv_name = ty_info.var_name, arg_conv_cleanup = None, ret_conv = None, ret_conv_name = None)
             elif ty_info.java_ty == "String":
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                    arg_conv = None, arg_conv_name = None,
+                    arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None,
                     ret_conv = ("jstring " + ty_info.var_name + "_conv = (*_env)->NewStringUTF(_env, ", ");"), ret_conv_name = ty_info.var_name + "_conv")
             else:
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                    arg_conv = None, arg_conv_name = ty_info.var_name, ret_conv = None, ret_conv_name = None)
+                    arg_conv = None, arg_conv_name = ty_info.var_name, arg_conv_cleanup = None, ret_conv = None, ret_conv_name = None)
         elif not print_void:
             # We don't have a parameter name, and want one, just call it arg
             if ty_info.rust_obj is not None:
                 assert(not is_free or ty_info.rust_obj not in opaque_structs);
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = ty_info.rust_obj + " arg_conv = *(" + ty_info.rust_obj + "*)arg;\nFREE((void*)arg);",
-                    arg_conv_name = "arg_conv",
+                    arg_conv_name = "arg_conv", arg_conv_cleanup = None,
                     ret_conv = None, ret_conv_name = None)
             else:
                 assert(not is_free)
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                    arg_conv = None, arg_conv_name = "arg", ret_conv = None, ret_conv_name = None)
+                    arg_conv = None, arg_conv_name = "arg", arg_conv_cleanup = None, ret_conv = None, ret_conv_name = None)
         else:
             # We don't have a parameter name, and don't want one (cause we're returning)
             if ty_info.rust_obj is not None:
@@ -318,7 +350,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                     if ty_info.rust_obj in unitary_enums:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                             arg_conv = ty_info.rust_obj + " ret = " + ty_info.rust_obj + "_from_java(_env, " + ty_info.var_name + ");",
-                            arg_conv_name = "ret",
+                            arg_conv_name = "ret", arg_conv_cleanup = None,
                             ret_conv = ("jclass ret = " + ty_info.rust_obj + "_to_java(_env, ", ");"), ret_conv_name = "ret")
                     if ty_info.rust_obj in opaque_structs:
                         # If we're returning a newly-allocated struct, we don't want Rust to ever
@@ -328,19 +360,19 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                             ret_conv = (ty_info.rust_obj + " ret = ", ";"),
                             ret_conv_name = "((long)ret.inner) | (ret.is_owned ? 1 : 0)",
-                            arg_conv = None, arg_conv_name = None)
+                            arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None)
                     else:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                             ret_conv = (ty_info.rust_obj + "* ret = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*ret = ", ";"),
                             ret_conv_name = "(long)ret",
-                            arg_conv = None, arg_conv_name = None)
+                            arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None)
                 else:
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                         ret_conv = ("long ret = (long)", ";"), ret_conv_name = "ret",
-                        arg_conv = None, arg_conv_name = None)
+                        arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None)
             else:
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                    arg_conv = None, arg_conv_name = None, ret_conv = None, ret_conv_name = None)
+                    arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None, ret_conv = None, ret_conv_name = None)
 
     def map_fn(line, re_match, ret_arr_len, c_call_string):
         out_java.write("\t// " + line)
@@ -380,7 +412,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             arg_names.append(arg_conv_info)
 
         out_java_struct = None
-        if "LDK" + struct_meth in opaque_structs and not is_free:
+        if ("LDK" + struct_meth in opaque_structs or "LDK" + struct_meth in trait_structs) and not is_free:
             out_java_struct = open(sys.argv[3] + "/structs/" + struct_meth + ".java", "a")
             if not args_known:
                 out_java_struct.write("\t// Skipped " + re_match.group(2) + "\n")
@@ -416,12 +448,14 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
 
         for info in arg_names:
             if info.arg_conv is not None:
-                out_c.write("\t" + info.arg_conv.replace('\n', "\n\t") + "\n");
+                out_c.write("\t" + info.arg_conv.replace('\n', "\n\t") + "\n")
 
         if ret_info.ret_conv is not None:
-            out_c.write("\t" + ret_conv_pfx.replace('\n', '\n\t'));
+            out_c.write("\t" + ret_conv_pfx.replace('\n', '\n\t'))
+        elif ret_info.c_ty != "void":
+            out_c.write("\t" + ret_info.c_ty + " ret_val = ")
         else:
-            out_c.write("\treturn ");
+            out_c.write("\t")
 
         if c_call_string is None:
             out_c.write(re_match.group(2) + "(")
@@ -437,9 +471,15 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         out_c.write(")")
         if ret_info.ret_conv is not None:
             out_c.write(ret_conv_sfx.replace('\n', '\n\t'))
-            out_c.write("\n\treturn " + ret_info.ret_conv_name + ";")
         else:
             out_c.write(";")
+        for info in arg_names:
+            if info.arg_conv_cleanup is not None:
+                out_c.write("\n\t" + info.arg_conv_cleanup.replace("\n", "\n\t"))
+        if ret_info.ret_conv is not None:
+            out_c.write("\n\treturn " + ret_info.ret_conv_name + ";")
+        elif ret_info.c_ty != "void":
+            out_c.write("\n\treturn ret_val;")
         out_c.write("\n}\n\n")
         if out_java_struct is not None:
             out_java_struct.write("\t\t")
@@ -458,9 +498,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 if info.arg_name == "this_arg":
                     out_java_struct.write("this.ptr")
                 elif info.passed_as_ptr and info.rust_obj in opaque_structs:
-                    out_java_struct.write(info.arg_name + ".ptr & ~1")
+                    out_java_struct.write(info.arg_name + " == null ? 0 : " + info.arg_name + ".ptr & ~1")
                 elif info.passed_as_ptr and info.rust_obj in trait_structs:
-                    out_java_struct.write(info.arg_name + ".ptr")
+                    out_java_struct.write(info.arg_name + " == null ? 0 : " + info.arg_name + ".ptr")
                 else:
                     out_java_struct.write(info.arg_name)
             out_java_struct.write(")")
@@ -519,12 +559,12 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                     out_c.write("static jfieldID " + struct_name + "_" + variant + " = NULL;\n")
             out_c.write("JNIEXPORT void JNICALL Java_org_ldk_enums_" + struct_name.replace("_", "_1") + "_init (JNIEnv * env, jclass clz) {\n")
             out_c.write("\t" + struct_name + "_class = (*env)->NewGlobalRef(env, clz);\n")
-            out_c.write("\tDO_ASSERT(" + struct_name + "_class != NULL);\n")
+            out_c.write("\tCHECK(" + struct_name + "_class != NULL);\n")
             for idx, struct_line in enumerate(field_lines):
                 if idx > 0 and idx < len(field_lines) - 3:
                     variant = struct_line.strip().strip(",")
                     out_c.write("\t" + struct_name + "_" + variant + " = (*env)->GetStaticFieldID(env, " + struct_name + "_class, \"" + variant + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n")
-                    out_c.write("\tDO_ASSERT(" + struct_name + "_" + variant + " != NULL);\n")
+                    out_c.write("\tCHECK(" + struct_name + "_" + variant + " != NULL);\n")
             out_c.write("}\n")
             out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
             out_c.write("\tswitch (val) {\n")
@@ -585,9 +625,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 var_name = struct_line.strip(' ,')[len(struct_name) + 1:]
                 out_c.write("\t" + struct_name + "_" + var_name + "_class =\n")
                 out_c.write("\t\t(*env)->NewGlobalRef(env, (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + "$" + var_name + ";\"));\n")
-                out_c.write("\tDO_ASSERT(" + struct_name + "_" + var_name + "_class != NULL);\n")
+                out_c.write("\tCHECK(" + struct_name + "_" + var_name + "_class != NULL);\n")
                 out_c.write("\t" + struct_name + "_" + var_name + "_meth = (*env)->GetMethodID(env, " + struct_name + "_" + var_name + "_class, \"<init>\", \"(" + init_meth_jty_strs[var_name] + ")V\");\n")
-                out_c.write("\tDO_ASSERT(" + struct_name + "_" + var_name + "_meth != NULL);\n")
+                out_c.write("\tCHECK(" + struct_name + "_" + var_name + "_meth != NULL);\n")
         out_c.write("}\n")
         out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (JNIEnv * env, jclass _c, jlong ptr) {\n")
         out_c.write("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
@@ -630,6 +670,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
 
             out_java_trait.write("package org.ldk.structs;\n\n")
             out_java_trait.write("import org.ldk.impl.bindings;\n\n")
+            out_java_trait.write("import org.ldk.enums.*;\n\n")
             out_java_trait.write("public class " + struct_name.replace("LDK","") + " extends CommonBase {\n")
             out_java_trait.write("\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); }\n")
             out_java_trait.write("\tpublic " + struct_name.replace("LDK", "") + "(bindings." + struct_name + " arg")
@@ -648,7 +689,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_java_trait.write("\tprotected void finalize() throws Throwable {\n")
             out_java_trait.write("\t\tbindings." + struct_name.replace("LDK","") + "_free(ptr); super.finalize();\n")
             out_java_trait.write("\t}\n\n")
-            out_java_trait.write("}\n")
 
             out_java.write("\tpublic interface " + struct_name + " {\n")
             java_meths = []
@@ -692,7 +732,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                             out_c.write(arg_info.arg_name)
                             out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "env") + "\n")
 
-                    out_c.write("\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tDO_ASSERT(obj != NULL);\n")
+                    out_c.write("\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tCHECK(obj != NULL);\n")
                     if ret_ty_info.c_ty.endswith("Array"):
                         assert(ret_ty_info.c_ty == "jbyteArray")
                         out_c.write("\tjbyteArray jret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.group(2) + "_meth")
@@ -709,6 +749,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                     out_c.write(");\n");
                     if ret_ty_info.c_ty.endswith("Array"):
                         out_c.write("\t" + ret_ty_info.rust_obj + " ret;\n")
+                        out_c.write("\tCHECK((*env)->GetArrayLength(env, jret) == " + ret_ty_info.arr_len + ");\n")
                         out_c.write("\t(*env)->GetByteArrayRegion(env, jret, 0, " + ret_ty_info.arr_len + ", ret." + ret_ty_info.arr_access + ");\n")
                         out_c.write("\treturn ret;\n")
 
@@ -749,7 +790,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_c.write(") {\n")
 
             out_c.write("\tjclass c = (*env)->GetObjectClass(env, o);\n")
-            out_c.write("\tDO_ASSERT(c != NULL);\n")
+            out_c.write("\tCHECK(c != NULL);\n")
             out_c.write("\t" + struct_name + "_JCalls *calls = MALLOC(sizeof(" + struct_name + "_JCalls), \"" + struct_name + "_JCalls\");\n")
             out_c.write("\tatomic_init(&calls->refcnt, 1);\n")
             out_c.write("\tDO_ASSERT((*env)->GetJavaVM(env, &calls->vm) == 0);\n")
@@ -757,7 +798,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             for (fn_line, java_meth_descr) in zip(trait_fn_lines, java_meths):
                 if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                     out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + java_meth_descr + "\");\n")
-                    out_c.write("\tDO_ASSERT(calls->" + fn_line.group(2) + "_meth != NULL);\n")
+                    out_c.write("\tCHECK(calls->" + fn_line.group(2) + "_meth != NULL);\n")
             out_c.write("\n\t" + struct_name + " ret = {\n")
             out_c.write("\t\t.this_arg = (void*) calls,\n")
             for fn_line in trait_fn_lines:
@@ -794,16 +835,16 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_java.write("\tpublic static native " + struct_name + " " + struct_name + "_get_obj_from_jcalls(long val);\n")
             out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1get_1obj_1from_1jcalls (JNIEnv * env, jclass _a, jlong val) {\n")
             out_c.write("\tjobject ret = (*env)->NewLocalRef(env, ((" + struct_name + "_JCalls*)val)->o);\n")
-            out_c.write("\tDO_ASSERT(ret != NULL);\n")
+            out_c.write("\tCHECK(ret != NULL);\n")
             out_c.write("\treturn ret;\n")
             out_c.write("}\n")
 
-            for fn_line in trait_fn_lines:
-                # For now, just disable enabling the _call_log - we don't know how to inverse-map String
-                is_log = fn_line.group(2) == "log" and struct_name == "LDKLogger"
-                if fn_line.group(2) != "free" and fn_line.group(2) != "clone" and fn_line.group(2) != "eq" and not is_log:
-                    dummy_line = fn_line.group(1) + struct_name + "_call_" + fn_line.group(2) + " " + struct_name + "* arg" + fn_line.group(4) + "\n"
-                    map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(arg_conv->" + fn_line.group(2) + ")(arg_conv->this_arg")
+        for fn_line in trait_fn_lines:
+            # For now, just disable enabling the _call_log - we don't know how to inverse-map String
+            is_log = fn_line.group(2) == "log" and struct_name == "LDKLogger"
+            if fn_line.group(2) != "free" and fn_line.group(2) != "clone" and fn_line.group(2) != "eq" and not is_log:
+                dummy_line = fn_line.group(1) + struct_name.replace("LDK", "") + "_call_" + fn_line.group(2) + " " + struct_name + "* this_arg" + fn_line.group(4) + "\n"
+                map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(this_arg_conv->" + fn_line.group(2) + ")(this_arg_conv->this_arg")
 
     out_c.write("""#include \"org_ldk_impl_bindings.h\"
 #include <rust_types.h>
@@ -816,9 +857,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         out_c.write("#define MALLOC(a, _) malloc(a)\n")
         out_c.write("#define FREE free\n")
         out_c.write("#define DO_ASSERT(a) (void)(a)\n")
+        out_c.write("#define CHECK(a)\n")
     else:
         out_c.write("""#include <assert.h>
+// Always run a, then assert it is true:
 #define DO_ASSERT(a) do { bool _assert_val = (a); assert(_assert_val); } while(0)
+// Assert a is true or do nothing
+#define CHECK(a) DO_ASSERT(a)
 
 // Running a leak check across all the allocations and frees of the JDK is a mess,
 // so instead we implement our own naive leak checker here, relying on the -wrap
@@ -947,11 +992,11 @@ static jmethodID slicedef_meth = NULL;
 static jclass slicedef_cls = NULL;
 JNIEXPORT void Java_org_ldk_impl_bindings_init(JNIEnv * env, jclass _b, jclass enum_class, jclass slicedef_class) {
        ordinal_meth = (*env)->GetMethodID(env, enum_class, "ordinal", "()I");
-       DO_ASSERT(ordinal_meth != NULL);
+       CHECK(ordinal_meth != NULL);
        slicedef_meth = (*env)->GetMethodID(env, slicedef_class, "<init>", "(JJJ)V");
-       DO_ASSERT(slicedef_meth != NULL);
+       CHECK(slicedef_meth != NULL);
        slicedef_cls = (*env)->NewGlobalRef(env, slicedef_class);
-       DO_ASSERT(slicedef_cls != NULL);
+       CHECK(slicedef_cls != NULL);
 }
 
 JNIEXPORT jboolean JNICALL Java_org_ldk_impl_bindings_deref_1bool (JNIEnv * env, jclass _a, jlong ptr) {
@@ -1028,13 +1073,6 @@ class CommonBase {
 }
 """)
 
-    # XXX: Temporarily write out a manual SecretKey_new() for testing, we should auto-gen this kind of thing
-    out_java.write("\tpublic static native long LDKSecretKey_new();\n\n") # TODO: rm me
-    out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_LDKSecretKey_1new(JNIEnv * _env, jclass _b) {\n") # TODO: rm me
-    out_c.write("\tLDKSecretKey* key = (LDKSecretKey*)MALLOC(sizeof(LDKSecretKey), \"LDKSecretKey\");\n") # TODO: rm me
-    out_c.write("\treturn (long)key;\n") # TODO: rm me
-    out_c.write("}\n") # TODO: rm me
-
     in_block_comment = False
     cur_block_obj = None
 
@@ -1167,6 +1205,7 @@ class CommonBase {
                                 out_c.write("\n\tret->" + e + " = " + ty_info.arg_conv_name + ";\n")
                             else:
                                 out_c.write("\tret->" + e + " = " + e + ";\n")
+                            assert ty_info.arg_conv_cleanup is None
                     out_c.write("\treturn (long)ret;\n")
                     out_c.write("}\n")
                 elif vec_ty is not None:
@@ -1181,7 +1220,7 @@ class CommonBase {
                         out_c.write("\tjlongArray ret = (*env)->NewLongArray(env, vec->datalen);\n")
                         out_c.write("\tjlong *ret_elems = (*env)->GetPrimitiveArrayCritical(env, ret, NULL);\n")
                         out_c.write("\tfor (size_t i = 0; i < vec->datalen; i++) {\n")
-                        out_c.write("\t\tDO_ASSERT((((long)vec->data[i].inner) & 1) == 0);\n")
+                        out_c.write("\t\tCHECK((((long)vec->data[i].inner) & 1) == 0);\n")
                         out_c.write("\t\tret_elems[i] = (long)vec->data[i].inner | (vec->data[i].is_owned ? 1 : 0);\n")
                         out_c.write("\t}\n")
                         out_c.write("\t(*env)->ReleasePrimitiveArrayCritical(env, ret, ret_elems, 0);\n")
@@ -1206,6 +1245,7 @@ class CommonBase {
                             out_c.write("\t\t\t" + ty_info.c_ty + " arr_elem = java_elems[i];\n")
                             out_c.write("\t\t\t" + ty_info.arg_conv.replace("\n", "\n\t\t\t") + "\n")
                             out_c.write("\t\t\tret->data[i] = " + ty_info.arg_conv_name + ";\n")
+                            assert ty_info.arg_conv_cleanup is None
                         else:
                             out_c.write("\t\t\tret->data[i] = java_elems[i];\n")
                         out_c.write("\t\t}\n")
@@ -1286,3 +1326,6 @@ class CommonBase {
     for struct_name in opaque_structs:
         with open(sys.argv[3] + "/structs/" + struct_name.replace("LDK","") + ".java", "a") as out_java_struct:
             out_java_struct.write("}\n")
+    for struct_name in trait_structs:
+        with open(sys.argv[3] + "/structs/" + struct_name.replace("LDK","") + ".java", "a") as out_java_struct:
+            out_java_struct.write("}\n")