Properly set CVec_u8Z to a byte[] which adds a ton more fn's
[ldk-java] / genbindings.py
index b1965938a6805e14d40950f50d2d4a5e48b73858..cc03661a62f4df6f045d8f17fd9ac941a31349ac 100755 (executable)
@@ -127,10 +127,15 @@ def java_c_types(fn_arg, ret_arr_len):
         assert var_is_arr_regex.match(fn_arg[8:])
         rust_obj = "LDKu8slice"
         arr_access = "data"
-    if fn_arg.startswith("LDKCVecTempl_u8"):
-        fn_arg = "uint8_t (*" + fn_arg[11:] + ")[datalen]"
-        assert var_is_arr_regex.match(fn_arg[8:])
-        rust_obj = "LDKCVecTempl_u8"
+    if fn_arg.startswith("LDKCVecTempl_u8") or fn_arg.startswith("LDKCVec_u8Z"):
+        if fn_arg.startswith("LDKCVecTempl_u8"):
+            fn_arg = "uint8_t (*" + fn_arg[16:] + ")[datalen]"
+            rust_obj = "LDKCVecTempl_u8"
+            assert var_is_arr_regex.match(fn_arg[8:])
+        else:
+            fn_arg = "uint8_t (*" + fn_arg[12:] + ")[datalen]"
+            rust_obj = "LDKCVec_u8Z"
+            assert var_is_arr_regex.match(fn_arg[8:])
         arr_access = "data"
 
     if fn_arg.startswith("void"):
@@ -423,7 +428,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                     arg_conv = None, arg_conv_name = None, arg_conv_cleanup = None,
                     ret_conv = None, ret_conv_name = None,
-                    to_hu_conv = ("TODO d", ""), from_hu_conv = None)
+                    to_hu_conv = None, from_hu_conv = None)
 
     def map_fn(line, re_match, ret_arr_len, c_call_string):
         out_java.write("\t// " + line)
@@ -702,7 +707,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 out_c.write("\t" + struct_name + "_" + var_name + "_meth = (*env)->GetMethodID(env, " + struct_name + "_" + var_name + "_class, \"<init>\", \"(" + init_meth_jty_strs[var_name] + ")V\");\n")
                 out_c.write("\tCHECK(" + struct_name + "_" + var_name + "_meth != NULL);\n")
         out_c.write("}\n")
-        out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (JNIEnv * env, jclass _c, jlong ptr) {\n")
+        out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1ref_1from_1ptr (JNIEnv * _env, jclass _c, jlong ptr) {\n")
         out_c.write("\t" + struct_name + " *obj = (" + struct_name + "*)ptr;\n")
         out_c.write("\tswitch(obj->tag) {\n")
         for idx, struct_line in enumerate(tag_field_lines):
@@ -716,13 +721,13 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         if idx != 0 and idx < len(enum_var_lines) - 2:
                             field_map = map_type(field.strip(' ;'), False, None, False)
                             if field_map.ret_conv is not None:
-                                out_c.write("\t\t\t" + field_map.ret_conv[0].replace("\n", "\n\t\t\t").replace("_env", "env"))
+                                out_c.write("\t\t\t" + field_map.ret_conv[0].replace("\n", "\n\t\t\t"))
                                 out_c.write("obj->" + camel_to_snake(var_name) + "." + field_map.arg_name)
                                 out_c.write(field_map.ret_conv[1].replace("\n", "\n\t\t\t") + "\n")
                                 c_params_text = c_params_text + ", " + field_map.ret_conv_name
                             else:
                                 c_params_text = c_params_text + ", obj->" + camel_to_snake(var_name) + "." + field_map.arg_name
-                out_c.write("\t\t\treturn (*env)->NewObject(env, " + struct_name + "_" + var_name + "_class, " + struct_name + "_" + var_name + "_meth" + c_params_text + ");\n")
+                out_c.write("\t\t\treturn (*_env)->NewObject(_env, " + struct_name + "_" + var_name + "_class, " + struct_name + "_" + var_name + "_meth" + c_params_text + ");\n")
                 out_c.write("\t\t}\n")
         out_c.write("\t\tdefault: abort();\n")
         out_c.write("\t}\n}\n")
@@ -746,7 +751,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             out_java_trait.write("import org.ldk.enums.*;\n\n")
             out_java_trait.write("public class " + struct_name.replace("LDK","") + " extends CommonBase {\n")
             out_java_trait.write("\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); }\n")
-            out_java_trait.write("\tpublic " + struct_name.replace("LDK", "") + "(bindings." + struct_name + " arg")
+            out_java_trait.write("\tpublic " + struct_name.replace("LDK", "") + "(bindings." + struct_name + " arg") # XXX: Should be priv
             for var_line in field_var_lines:
                 if var_line.group(1) in trait_structs:
                     out_java_trait.write(", bindings." + var_line.group(1) + " " + var_line.group(2))
@@ -768,7 +773,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             for fn_line in trait_fn_lines:
                 java_meth_descr = "("
                 if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
-                    ret_ty_info = java_c_types(fn_line.group(1), None)
+                    ret_ty_info = map_type(fn_line.group(1), True, None, False)
 
                     out_java.write("\t\t " + ret_ty_info.java_ty + " " + fn_line.group(2) + "(")
                     is_const = fn_line.group(3) is not None
@@ -796,23 +801,23 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                     out_java.write(");\n")
                     out_c.write(") {\n")
                     out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")
-                    out_c.write("\tJNIEnv *env;\n")
-                    out_c.write("\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n")
+                    out_c.write("\tJNIEnv *_env;\n")
+                    out_c.write("\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&_env, JNI_VERSION_1_8) == JNI_OK);\n")
 
                     for arg_info in arg_names:
                         if arg_info.ret_conv is not None:
-                            out_c.write("\t" + arg_info.ret_conv[0].replace('\n', '\n\t').replace("_env", "env"));
+                            out_c.write("\t" + arg_info.ret_conv[0].replace('\n', '\n\t'));
                             out_c.write(arg_info.arg_name)
-                            out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "env") + "\n")
+                            out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t') + "\n")
 
-                    out_c.write("\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tCHECK(obj != NULL);\n")
+                    out_c.write("\tjobject obj = (*_env)->NewLocalRef(_env, j_calls->o);\n\tCHECK(obj != NULL);\n")
                     if ret_ty_info.c_ty.endswith("Array"):
                         assert(ret_ty_info.c_ty == "jbyteArray")
-                        out_c.write("\tjbyteArray jret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.group(2) + "_meth")
+                        out_c.write("\tjbyteArray ret = (*_env)->CallObjectMethod(_env, obj, j_calls->" + fn_line.group(2) + "_meth")
                     elif not ret_ty_info.passed_as_ptr:
-                        out_c.write("\treturn (*env)->Call" + ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.group(2) + "_meth")
+                        out_c.write("\treturn (*_env)->Call" + ret_ty_info.java_ty.title() + "Method(_env, obj, j_calls->" + fn_line.group(2) + "_meth")
                     else:
-                        out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*env)->CallLongMethod(env, obj, j_calls->" + fn_line.group(2) + "_meth");
+                        out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*_env)->CallLongMethod(_env, obj, j_calls->" + fn_line.group(2) + "_meth");
 
                     for arg_info in arg_names:
                         if arg_info.ret_conv is not None:
@@ -820,11 +825,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         else:
                             out_c.write(", " + arg_info.arg_name)
                     out_c.write(");\n");
-                    if ret_ty_info.c_ty.endswith("Array"):
-                        out_c.write("\t" + ret_ty_info.rust_obj + " ret;\n")
-                        out_c.write("\tCHECK((*env)->GetArrayLength(env, jret) == " + ret_ty_info.arr_len + ");\n")
-                        out_c.write("\t(*env)->GetByteArrayRegion(env, jret, 0, " + ret_ty_info.arr_len + ", ret." + ret_ty_info.arr_access + ");\n")
-                        out_c.write("\treturn ret;\n")
+                    if ret_ty_info.arg_conv is not None:
+                        out_c.write("\t" + ret_ty_info.arg_conv.replace("\n", "\n\t").replace("arg", "ret") + "\n\treturn " + ret_ty_info.arg_conv_name.replace("arg", "ret") + ";\n")
 
                     if ret_ty_info.passed_as_ptr:
                         out_c.write("\t" + fn_line.group(1).strip() + " res = *ret;\n")
@@ -1286,7 +1288,8 @@ class CommonBase {
                                 out_c.write("\n\tret->" + e + " = " + ty_info.arg_conv_name + ";\n")
                             else:
                                 out_c.write("\tret->" + e + " = " + e + ";\n")
-                            assert ty_info.arg_conv_cleanup is None
+                            if ty_info.arg_conv_cleanup is not None:
+                                out_c.write("\t//TODO: Really need to call " + ty_info.arg_conv_cleanup + " here\n")
                     out_c.write("\treturn (long)ret;\n")
                     out_c.write("}\n")
                 elif vec_ty is not None: