Convert pubkeys to byte arrays, fix assertions, fix result inner fetch, fix java...
[ldk-java] / genbindings.py
index 5f4edfc71a94489908a46b89e1d40ecfa48d5697..2a5c44e2c6ea7e200d5d42dcba4b8a0218454d99 100755 (executable)
@@ -7,7 +7,7 @@ if len(sys.argv) != 6:
     sys.exit(1)
 
 class TypeInfo:
-    def __init__(self, rust_obj, java_ty, java_fn_ty_arg, c_ty, passed_as_ptr, is_ptr, var_name, arr_len):
+    def __init__(self, rust_obj, java_ty, java_fn_ty_arg, c_ty, passed_as_ptr, is_ptr, var_name, arr_len, arr_access):
         self.rust_obj = rust_obj
         self.java_ty = java_ty
         self.java_fn_ty_arg = java_fn_ty_arg
@@ -16,6 +16,7 @@ class TypeInfo:
         self.is_ptr = is_ptr
         self.var_name = var_name
         self.arr_len = arr_len
+        self.arr_access = arr_access
 
 class ConvInfo:
     def __init__(self, ty_info, arg_name, arg_conv, arg_conv_name, ret_conv, ret_conv_name):
@@ -88,10 +89,21 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         is_ptr = False
         take_by_ptr = False
         rust_obj = None
+        arr_access = None
         if fn_arg.startswith("LDKThirtyTwoBytes"):
             fn_arg = "uint8_t (*" + fn_arg[18:] + ")[32]"
             assert var_is_arr_regex.match(fn_arg[8:])
             rust_obj = "LDKThirtyTwoBytes"
+            arr_access = "data"
+        if fn_arg.startswith("LDKPublicKey"):
+            fn_arg = "uint8_t (*" + fn_arg[13:] + ")[33]"
+            assert var_is_arr_regex.match(fn_arg[8:])
+            rust_obj = "LDKPublicKey"
+            arr_access = "compressed_form"
+        #if fn_arg.startswith("LDKSignature"):
+        #    fn_arg = "uint8_t (*" + fn_arg[13:] + ")[64]"
+        #    assert var_is_arr_regex.match(fn_arg[8:])
+        #    rust_obj = "LDKSignature"
 
         if fn_arg.startswith("void"):
             java_ty = "void"
@@ -162,10 +174,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             java_ty = java_ty + "[]"
             c_ty = c_ty + "Array"
             if var_is_arr is not None:
+                if var_is_arr.group(1) == "":
+                    return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty,
+                        passed_as_ptr=False, is_ptr=False, var_name="arg", arr_len=var_is_arr.group(2), arr_access=arr_access)
                 return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty,
-                    passed_as_ptr=False, is_ptr=False, var_name=var_is_arr.group(1), arr_len=var_is_arr.group(2))
+                    passed_as_ptr=False, is_ptr=False, var_name=var_is_arr.group(1), arr_len=var_is_arr.group(2), arr_access=arr_access)
         return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr,
-            is_ptr=is_ptr, var_name=fn_arg, arr_len=None)
+            is_ptr=is_ptr, var_name=fn_arg, arr_len=None, arr_access=None)
 
     def map_type(fn_arg, print_void, ret_arr_len, is_free):
         ty_info = java_c_types(fn_arg, ret_arr_len)
@@ -184,8 +199,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 arr_len = ret_arr_len
             assert(ty_info.c_ty == "jbyteArray")
             if ty_info.rust_obj is not None:
-                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref.data);"
-                arr_access = ("", ".data")
+                arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_ref." + ty_info.arr_access + ");"
+                arr_access = ("", "." + ty_info.arr_access)
             else:
                 arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n" + "(*_env)->GetByteArrayRegion (_env, " + arr_name + ", 0, " + arr_len + ", " + arr_name + "_arr);\n" + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
                 arr_access = ("*", "")
@@ -289,8 +304,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         # any _free function.
                         # To avoid any issues, we first assert that the incoming object is non-ref.
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                            ret_conv = (ty_info.rust_obj + " ret = ", ";\nDO_ASSERT(ret.is_owned);"),
-                            ret_conv_name = "((long)ret.inner) | 1",
+                            ret_conv = (ty_info.rust_obj + " ret = ", ";"),
+                            ret_conv_name = "((long)ret.inner) | (ret.is_owned ? 1 : 0)",
                             arg_conv = None, arg_conv_name = None)
                     else:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
@@ -562,7 +577,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 out_c.write(");\n");
                 if ret_ty_info.c_ty.endswith("Array"):
                     out_c.write("\t" + ret_ty_info.rust_obj + " ret;\n")
-                    out_c.write("\t(*env)->GetByteArrayRegion(env, jret, 0, " + ret_ty_info.arr_len + ", ret.data);\n")
+                    out_c.write("\t(*env)->GetByteArrayRegion(env, jret, 0, " + ret_ty_info.arr_len + ", ret." + ret_ty_info.arr_access + ");\n")
                     out_c.write("\treturn ret;\n")
 
                 if ret_ty_info.passed_as_ptr:
@@ -883,7 +898,7 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
     reg_fn_regex = re.compile("([A-Za-z_0-9\* ]* \*?)([a-zA-Z_0-9]*)\((.*)\);$")
     const_val_regex = re.compile("^extern const ([A-Za-z_0-9]*) ([A-Za-z_0-9]*);$")
 
-    line_indicates_result_regex = re.compile("^   bool result_ok;$")
+    line_indicates_result_regex = re.compile("^   (LDKCResultPtr_[A-Za-z_0-9]*) contents;$")
     line_indicates_vec_regex = re.compile("^   ([A-Za-z_0-9]*) \*data;$")
     line_indicates_opaque_regex = re.compile("^   bool is_owned;$")
     line_indicates_trait_regex = re.compile("^   ([A-Za-z_0-9]* \*?)\(\*([A-Za-z_0-9]*)\)\((const )?void \*this_arg(.*)\);$")
@@ -900,6 +915,7 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
 
     result_templ_structs = set()
     union_enum_items = {}
+    result_ptr_struct_items = {}
     for line in in_h:
         if in_block_comment:
             if line.endswith("*/\n"):
@@ -912,7 +928,7 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                 vec_ty = None
                 obj_lines = cur_block_obj.split("\n")
                 is_opaque = False
-                is_result = False
+                result_contents = None
                 is_unitary_enum = False
                 is_union_enum = False
                 is_union = False
@@ -939,8 +955,9 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                                 is_union = True
                         if line_indicates_opaque_regex.match(struct_line):
                             is_opaque = True
-                        elif line_indicates_result_regex.match(struct_line):
-                            is_result = True
+                        result_match = line_indicates_result_regex.match(struct_line)
+                        if result_match is not None:
+                            result_contents = result_match.group(1)
                         vec_ty_match = line_indicates_vec_regex.match(struct_line)
                         if vec_ty_match is not None and struct_name.startswith("LDKCVecTempl_"):
                             vec_ty = vec_ty_match.group(1)
@@ -955,18 +972,26 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                         field_lines.append(struct_line)
 
                 assert(struct_name is not None)
-                assert(len(trait_fn_lines) == 0 or not (is_opaque or is_unitary_enum or is_union_enum or is_union or is_result or vec_ty is not None))
-                assert(not is_opaque or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_union or is_result or vec_ty is not None))
-                assert(not is_unitary_enum or not (len(trait_fn_lines) != 0 or is_opaque or is_union_enum or is_union or is_result or vec_ty is not None))
-                assert(not is_union_enum or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_opaque or is_union or is_result or vec_ty is not None))
-                assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_result or vec_ty is not None))
-                assert(not is_result or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or vec_ty is not None))
-                assert(vec_ty is None or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or is_result))
+                assert(len(trait_fn_lines) == 0 or not (is_opaque or is_unitary_enum or is_union_enum or is_union or result_contents is not None or vec_ty is not None))
+                assert(not is_opaque or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_union or result_contents is not None or vec_ty is not None))
+                assert(not is_unitary_enum or not (len(trait_fn_lines) != 0 or is_opaque or is_union_enum or is_union or result_contents is not None or vec_ty is not None))
+                assert(not is_union_enum or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_opaque or is_union or result_contents is not None or vec_ty is not None))
+                assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or result_contents is not None or vec_ty is not None))
+                assert(result_contents is None or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or vec_ty is not None))
+                assert(vec_ty is None or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or result_contents is not None))
 
                 if is_opaque:
                     opaque_structs.add(struct_name)
-                elif is_result:
+                elif result_contents is not None:
                     result_templ_structs.add(struct_name)
+                    assert result_contents in result_ptr_struct_items
+                elif struct_name.startswith("LDKCResultPtr_"):
+                    for line in field_lines:
+                        if line.endswith("*result;"):
+                            res_ty = line[:-8].strip()
+                        elif line.endswith("*err;"):
+                            err_ty = line[:-5].strip()
+                    result_ptr_struct_items[struct_name] = (res_ty, err_ty)
                 elif is_tuple:
                     out_java.write("\tpublic static native long " + struct_name + "_new(")
                     out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1new(JNIEnv *_env, jclass _b")
@@ -1014,28 +1039,28 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                     out_c.write("}\n")
 
                     ty_info = map_type(vec_ty + " arr_elem", False, None, False)
-                    out_java.write("\tpublic static native long " + struct_name + "_new(" + ty_info.java_ty + "[] elems);\n")
-                    out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1new(JNIEnv *env, jclass _b, j" + ty_info.java_ty + "Array elems){\n")
-                    out_c.write("\t" + struct_name + " *ret = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
-                    out_c.write("\tret->datalen = (*env)->GetArrayLength(env, elems);\n")
-                    out_c.write("\tif (ret->datalen == 0) {\n")
-                    out_c.write("\t\tret->data = NULL;\n")
-                    out_c.write("\t} else {\n")
-                    out_c.write("\t\tret->data = MALLOC(sizeof(" + vec_ty + ") * ret->datalen, \"" + struct_name + " Data\");\n")
-                    assert len(ty_info.java_fn_ty_arg) == 1 # ie we're a primitive of some form
-                    out_c.write("\t\t" + ty_info.c_ty + " *java_elems = (*env)->GetPrimitiveArrayCritical(env, elems, NULL);\n")
-                    out_c.write("\t\tfor (size_t i = 0; i < ret->datalen; i++) {\n")
-                    if ty_info.arg_conv is not None:
-                        out_c.write("\t\t\t" + ty_info.c_ty + " arr_elem = java_elems[i];\n")
-                        out_c.write("\t\t\t" + ty_info.arg_conv.replace("\n", "\n\t\t\t") + "\n")
-                        out_c.write("\t\t\tret->data[i] = " + ty_info.arg_conv_name + ";\n")
-                    else:
-                        out_c.write("\t\t\tret->data[i] = java_elems[i];\n")
-                    out_c.write("\t\t}\n")
-                    out_c.write("\t\t(*env)->ReleasePrimitiveArrayCritical(env, elems, java_elems, 0);\n")
-                    out_c.write("\t}\n")
-                    out_c.write("\treturn (long)ret;\n")
-                    out_c.write("}\n")
+                    if len(ty_info.java_fn_ty_arg) == 1: # ie we're a primitive of some form
+                        out_java.write("\tpublic static native long " + struct_name + "_new(" + ty_info.java_ty + "[] elems);\n")
+                        out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1new(JNIEnv *env, jclass _b, j" + ty_info.java_ty + "Array elems){\n")
+                        out_c.write("\t" + struct_name + " *ret = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
+                        out_c.write("\tret->datalen = (*env)->GetArrayLength(env, elems);\n")
+                        out_c.write("\tif (ret->datalen == 0) {\n")
+                        out_c.write("\t\tret->data = NULL;\n")
+                        out_c.write("\t} else {\n")
+                        out_c.write("\t\tret->data = MALLOC(sizeof(" + vec_ty + ") * ret->datalen, \"" + struct_name + " Data\");\n")
+                        out_c.write("\t\t" + ty_info.c_ty + " *java_elems = (*env)->GetPrimitiveArrayCritical(env, elems, NULL);\n")
+                        out_c.write("\t\tfor (size_t i = 0; i < ret->datalen; i++) {\n")
+                        if ty_info.arg_conv is not None:
+                            out_c.write("\t\t\t" + ty_info.c_ty + " arr_elem = java_elems[i];\n")
+                            out_c.write("\t\t\t" + ty_info.arg_conv.replace("\n", "\n\t\t\t") + "\n")
+                            out_c.write("\t\t\tret->data[i] = " + ty_info.arg_conv_name + ";\n")
+                        else:
+                            out_c.write("\t\t\tret->data[i] = java_elems[i];\n")
+                        out_c.write("\t\t}\n")
+                        out_c.write("\t\t(*env)->ReleasePrimitiveArrayCritical(env, elems, java_elems, 0);\n")
+                        out_c.write("\t}\n")
+                        out_c.write("\treturn (long)ret;\n")
+                        out_c.write("}\n")
                 elif is_union_enum:
                     assert(struct_name.endswith("_Tag"))
                     struct_name = struct_name[:-4]
@@ -1078,10 +1103,19 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                     out_c.write("\treturn ((" + alias_match.group(2) + "*)arg)->result_ok;\n")
                     out_c.write("}\n")
                     out_c.write("JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_" + alias_match.group(2).replace("_", "_1") + "_1get_1inner (JNIEnv * env, jclass _a, jlong arg) {\n")
-                    out_c.write("\tif (((" + alias_match.group(2) + "*)arg)->result_ok) {\n")
-                    out_c.write("\t\treturn (long)((" + alias_match.group(2) + "*)arg)->contents.result;\n")
+                    contents_ty = alias_match.group(1).replace("LDKCResultTempl", "LDKCResultPtr")
+                    res_ty, err_ty = result_ptr_struct_items[contents_ty]
+                    out_c.write("\t" + alias_match.group(2) + " *val = (" + alias_match.group(2) + "*)arg;\n")
+                    out_c.write("\tif (val->result_ok) {\n")
+                    if res_ty not in opaque_structs:
+                        out_c.write("\t\treturn (long)val->contents.result;\n")
+                    else:
+                        out_c.write("\t\treturn (long)(val->contents.result->inner) | (val->contents.result->is_owned ? 1 : 0);\n")
                     out_c.write("\t} else {\n")
-                    out_c.write("\t\treturn (long)((" + alias_match.group(2) + "*)arg)->contents.err;\n")
+                    if err_ty not in opaque_structs:
+                        out_c.write("\t\treturn (long)val->contents.err;\n")
+                    else:
+                        out_c.write("\t\treturn (long)(val->contents.err->inner) | (val->contents.err->is_owned ? 1 : 0);\n")
                     out_c.write("\t}\n}\n")
                 pass
             elif fn_ptr is not None: