Map enum functions into human structs and Result types sanely
[ldk-java] / genbindings.py
index 124a76bac1e2688b9fbdc3ea1d29a7db645986f8..2d6771c30d792786873456ff9fd98884d284ec84 100755 (executable)
@@ -88,6 +88,7 @@ def java_c_types(fn_arg, ret_arr_len):
         fn_arg = fn_arg[7:]
     if fn_arg.startswith("enum "):
         fn_arg = fn_arg[5:]
+    nonnull_ptr = "NONNULL_PTR" in fn_arg
     fn_arg = fn_arg.replace("NONNULL_PTR", "")
 
     is_ptr = False
@@ -172,11 +173,13 @@ def java_c_types(fn_arg, ret_arr_len):
             res.pass_by_ref = True
         if res.is_native_primitive or res.passed_as_ptr:
             return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=res.java_ty + "[]", java_hu_ty=res.java_hu_ty + "[]",
-                java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=res.c_ty + "Array", passed_as_ptr=False, is_ptr=is_ptr, is_const=is_const,
+                java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=res.c_ty + "Array", passed_as_ptr=False, is_ptr=is_ptr,
+                nonnull_ptr=nonnull_ptr, is_const=is_const,
                 var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False)
         else:
             return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=res.java_ty + "[]", java_hu_ty=res.java_hu_ty + "[]",
-                java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=consts.ptr_arr, passed_as_ptr=False, is_ptr=is_ptr, is_const=is_const,
+                java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=consts.ptr_arr, passed_as_ptr=False, is_ptr=is_ptr,
+                nonnull_ptr=nonnull_ptr, is_const=is_const,
                 var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False)
 
     is_primitive = False
@@ -296,7 +299,7 @@ def java_c_types(fn_arg, ret_arr_len):
         else:
             c_ty = consts.ptr_c_ty
             java_ty = consts.ptr_native_ty
-            java_hu_ty = ma.group(1).strip().replace("LDKCResult", "Result").replace("LDK", "")
+            java_hu_ty = ma.group(1).strip().replace("LDKCOption", "Option").replace("LDKCResult", "Result").replace("LDK", "")
             fn_ty_arg = "J"
             fn_arg = ma.group(2).strip()
             rust_obj = ma.group(1).strip()
@@ -322,14 +325,16 @@ def java_c_types(fn_arg, ret_arr_len):
         if var_is_arr is not None:
             if var_is_arr.group(1) == "":
                 return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
-                    passed_as_ptr=False, is_ptr=False, var_name="arg", arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False)
+                    passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg",
+                    arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False)
             return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
-                passed_as_ptr=False, is_ptr=False, var_name=var_is_arr.group(1), arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False)
+                passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1),
+                arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False)
 
     if java_hu_ty is None:
         java_hu_ty = java_ty
     return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr,
-        is_const=is_const, is_ptr=is_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive)
+        is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive)
 
 fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$")
 fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")
@@ -376,7 +381,10 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
         method_arguments = method_comma_separated_arguments.split(',')
 
         is_free = method_name.endswith("_free")
-        struct_meth = method_name.split("_")[0]
+        if method_name.startswith("COption"):
+            struct_meth = method_name.rsplit("Z", 1)[0][1:] + "Z"
+        else:
+            struct_meth = method_name.split("_")[0]
 
         return_type_info = type_mapping_generator.map_type(method_return_type, True, ret_arr_len, False, False)
 
@@ -432,7 +440,8 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             write_c(out_c_delta)
 
         out_java_struct = None
-        if ("LDK" + struct_meth in opaque_structs or "LDK" + struct_meth in trait_structs) and not is_free:
+        if ("LDK" + struct_meth in opaque_structs or "LDK" + struct_meth in trait_structs
+                or "LDK" + struct_meth in complex_enums or "LDKC" + struct_meth in complex_enums) and not is_free:
             out_java_struct = open(f"{sys.argv[3]}/structs/{struct_meth}{consts.file_ext}", "a")
         elif method_name.startswith("C2Tuple_") and method_name.endswith("_read"):
             struct_meth = method_name.rsplit("_", 1)[0]
@@ -458,7 +467,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             out_java.write(native_out)
 
     def map_complex_enum(struct_name, union_enum_items, enum_doc_comment):
-        java_hu_type = struct_name.replace("LDK", "")
+        java_hu_type = struct_name.replace("LDK", "").replace("COption", "Option")
         complex_enums.add(struct_name)
 
         enum_variants = []
@@ -527,7 +536,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
             # For now, just disable enabling the _call_log - we don't know how to inverse-map String
             is_log = fn_line.group(3) == "log" and struct_name == "LDKLogger"
             if fn_line.group(3) != "free" and fn_line.group(3) != "clone" and fn_line.group(3) != "eq" and not is_log:
-                dummy_line = fn_line.group(2) + struct_name.replace("LDK", "") + "_" + fn_line.group(3) + " " + struct_name + "* this_arg" + fn_line.group(5) + "\n"
+                dummy_line = fn_line.group(2) + struct_name.replace("LDK", "") + "_" + fn_line.group(3) + " " + struct_name + " *NONNULL_PTR this_arg" + fn_line.group(5) + "\n"
                 map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(this_arg_conv->" + fn_line.group(3) + ")(this_arg_conv->this_arg", fn_docs)
         for idx, var_line in enumerate(field_var_lines):
             if var_line.group(1) not in trait_structs:
@@ -536,7 +545,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
                 write_c("\t\tthis_arg->set_" + var_line.group(2) + "(this_arg);\n")
                 write_c("\treturn this_arg->" + var_line.group(2) + ";\n")
                 write_c("}\n")
-                dummy_line = var_line.group(1) + " " + struct_name.replace("LDK", "") + "_get_" + var_line.group(2) + " " + struct_name + "* this_arg" + fn_line.group(5) + "\n"
+                dummy_line = var_line.group(1) + " " + struct_name.replace("LDK", "") + "_get_" + var_line.group(2) + " " + struct_name + " *NONNULL_PTR this_arg" + fn_line.group(5) + "\n"
                 map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, struct_name + "_set_get_" + var_line.group(2) + "(this_arg_conv", fn_docs)
 
     def map_result(struct_name, res_ty, err_ty):
@@ -922,6 +931,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java:
     for struct_name in trait_structs:
         with open(f"{sys.argv[3]}/structs/{struct_name.replace('LDK', '')}{consts.file_ext}", "a") as out_java_struct:
             out_java_struct.write("}\n")
+    for struct_name in complex_enums:
+        with open(f"{sys.argv[3]}/structs/{struct_name.replace('LDK', '').replace('COption', 'Option')}{consts.file_ext}", "a") as out_java_struct:
+            out_java_struct.write("}\n")
 
 with open(sys.argv[4], "w") as out_c:
     out_c.write(consts.c_file_pfx)