]> git.bitcoin.ninja Git - ldk-java/blobdiff - genbindings.py
Expose signatures as byte[], check array lengths in C.
[ldk-java] / genbindings.py
index 5ba69e6d09082e262f02ad5889b4fc96e0f234c9..6fc8bf202d7d562ed0520f3675e8368dbfacdad1 100755 (executable)
@@ -118,10 +118,16 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             assert var_is_arr_regex.match(fn_arg[8:])
             rust_obj = "LDKPublicKey"
             arr_access = "compressed_form"
-        #if fn_arg.startswith("LDKSignature"):
-        #    fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
-        #    assert var_is_arr_regex.match(fn_arg[8:])
-        #    rust_obj = "LDKSignature"
+        if fn_arg.startswith("LDKSecretKey"):
+            fn_arg = "uint8_t (*" + fn_arg[13:] + ")[32]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKSecretKey"
+            arr_access = "bytes"
+        if fn_arg.startswith("LDKSignature"):
+            fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKSignature"
+            arr_access = "compact_form"
 
         if fn_arg.startswith("void"):
             java_ty = "void"
@@ -217,10 +223,14 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 arr_len = ret_arr_len
             assert(ty_info.c_ty == "jbyteArray")
             if ty_info.rust_obj is not None:
-                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
+                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
+                arg_conv = arg_conv + "DO_ASSERT((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
+                arg_conv = arg_conv + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
                 arr_access = ("", "." + ty_info.arr_access)
             else:
-                arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
+                arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n"
+                arg_conv = arg_conv + "DO_ASSERT((*_env)->GetArrayLength (_env, " + arr_name + ") == " + arr_len + ");\n"
+                arg_conv = arg_conv + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
                 arr_access = ("*", "")
             return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                 arg_conv = arg_conv,
@@ -380,7 +390,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             arg_names.append(arg_conv_info)
 
         out_java_struct = None
-        if "LDK" + struct_meth in opaque_structs and not is_free:
+        if ("LDK" + struct_meth in opaque_structs or "LDK" + struct_meth in trait_structs) and not is_free:
             out_java_struct = open(sys.argv[3] + "/structs/" + struct_meth + ".java", "a")
             if not args_known:
                 out_java_struct.write("\t// Skipped " + re_match.group(2) + "\n")
@@ -630,6 +640,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
 
             out_java_trait.write("package org.ldk.structs;\n\n")
             out_java_trait.write("import org.ldk.impl.bindings;\n\n")
+            out_java_trait.write("import org.ldk.enums.*;\n\n")
             out_java_trait.write("public class " + struct_name.replace("LDK","") + " extends CommonBase {\n")
             out_java_trait.write("\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); }\n")
             out_java_trait.write("\tpublic " + struct_name.replace("LDK", "") + "(bindings." + struct_name + " arg")
@@ -648,7 +659,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_java_trait.write("\tprotected void finalize() throws Throwable {\n")
             out_java_trait.write("\t\tbindings." + struct_name.replace("LDK","") + "_free(ptr); super.finalize();\n")
             out_java_trait.write("\t}\n\n")
-            out_java_trait.write("}\n")
 
             out_java.write("\tpublic interface " + struct_name + " {\n")
             java_meths = []
@@ -709,6 +719,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                     out_c.write(");\n");
                     if ret_ty_info.c_ty.endswith("Array"):
                         out_c.write("\t" + ret_ty_info.rust_obj + " ret;\n")
+                        out_c.write("\tDO_ASSERT((*env)->GetArrayLength(env, jret) == " + ret_ty_info.arr_len + ");\n")
                         out_c.write("\t(*env)->GetByteArrayRegion(env, jret, 0, " + ret_ty_info.arr_len + ", ret." + ret_ty_info.arr_access + ");\n")
                         out_c.write("\treturn ret;\n")
 
@@ -798,12 +809,12 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_c.write("\treturn ret;\n")
             out_c.write("}\n")
 
-            for fn_line in trait_fn_lines:
-                # For now, just disable enabling the _call_log - we don't know how to inverse-map String
-                is_log = fn_line.group(2) == "log" and struct_name == "LDKLogger"
-                if fn_line.group(2) != "free" and fn_line.group(2) != "clone" and fn_line.group(2) != "eq" and not is_log:
-                    dummy_line = fn_line.group(1) + struct_name + "_call_" + fn_line.group(2) + " " + struct_name + "* arg" + fn_line.group(4) + "\n"
-                    map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(arg_conv->" + fn_line.group(2) + ")(arg_conv->this_arg")
+        for fn_line in trait_fn_lines:
+            # For now, just disable enabling the _call_log - we don't know how to inverse-map String
+            is_log = fn_line.group(2) == "log" and struct_name == "LDKLogger"
+            if fn_line.group(2) != "free" and fn_line.group(2) != "clone" and fn_line.group(2) != "eq" and not is_log:
+                dummy_line = fn_line.group(1) + struct_name.replace("LDK", "") + "_call_" + fn_line.group(2) + " " + struct_name + "* this_arg" + fn_line.group(4) + "\n"
+                map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(this_arg_conv->" + fn_line.group(2) + ")(this_arg_conv->this_arg")
 
     out_c.write("""#include \"org_ldk_impl_bindings.h\"
 #include <rust_types.h>
@@ -1028,13 +1039,6 @@ class CommonBase {
 }
 """)
 
-    # XXX: Temporarily write out a manual SecretKey_new() for testing, we should auto-gen this kind of thing
-    out_java.write("\tpublic static native long LDKSecretKey_new();\n\n") # TODO: rm me
-    out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_LDKSecretKey_1new(JNIEnv * _env, jclass _b) {\n") # TODO: rm me
-    out_c.write("\tLDKSecretKey* key = (LDKSecretKey*)MALLOC(sizeof(LDKSecretKey), \"LDKSecretKey\");\n") # TODO: rm me
-    out_c.write("\treturn (long)key;\n") # TODO: rm me
-    out_c.write("}\n") # TODO: rm me
-
     in_block_comment = False
     cur_block_obj = None
 
@@ -1286,3 +1290,6 @@ class CommonBase {
     for struct_name in opaque_structs:
         with open(sys.argv[3] + "/structs/" + struct_name.replace("LDK","") + ".java", "a") as out_java_struct:
             out_java_struct.write("}\n")
+    for struct_name in trait_structs:
+        with open(sys.argv[3] + "/structs/" + struct_name.replace("LDK","") + ".java", "a") as out_java_struct:
+            out_java_struct.write("}\n")