Redo tuple mapping to be explicit and not generic
[ldk-java] / genbindings.py
index 3733d540bc06f95e65e2221a9b458e83be29a862..1aa2c8ee3729b50f06c154cf4430c7e23779cc77 100755 (executable)
@@ -100,7 +100,7 @@ trait_structs = {}
 result_types = set()
 tuple_types = {}
 
-var_is_arr_regex = re.compile("\(\*([A-za-z0-9_]*)\)\[([a-z0-9]*)\]")
+var_is_arr_regex = re.compile("\(\* ?([A-za-z0-9_]*)\)\[([a-z0-9]*)\]")
 var_ty_regex = re.compile("([A-za-z_0-9]*)(.*)")
 java_c_types_none_allowed = True # Unset when we do the real pass that populates the above sets
 def java_c_types(fn_arg, ret_arr_len):
@@ -304,54 +304,15 @@ def java_c_types(fn_arg, ret_arr_len):
             fn_ty_arg = "Lorg/ldk/enums/" + java_ty + ";"
             fn_arg = ma.group(2).strip()
             rust_obj = ma.group(1).strip()
-        elif ma.group(1).strip().startswith("LDKC2Tuple"):
-            c_ty = consts.ptr_c_ty
-            java_ty = consts.ptr_native_ty
-            java_hu_ty = "TwoTuple<"
-            if not ma.group(1).strip() in tuple_types:
-                assert java_c_types_none_allowed
-                return None
-            for idx, ty_info in enumerate(tuple_types[ma.group(1).strip()][0]):
-                if idx != 0:
-                    java_hu_ty = java_hu_ty + ", "
-                if ty_info.is_native_primitive:
-                    if ty_info.java_hu_ty == "int":
-                        java_hu_ty = java_hu_ty + "Integer" # Java concrete integer type is Integer, not Int
-                    else:
-                        java_hu_ty = java_hu_ty + ty_info.java_hu_ty.title() # If we're a primitive, capitalize the first letter
-                else:
-                    java_hu_ty = java_hu_ty + ty_info.java_hu_ty
-            java_hu_ty = java_hu_ty + ">"
-            fn_ty_arg = "J"
-            fn_arg = ma.group(2).strip()
-            rust_obj = ma.group(1).strip()
-            take_by_ptr = True
-        elif ma.group(1).strip().startswith("LDKC3Tuple"):
-            c_ty = consts.ptr_c_ty
-            java_ty = consts.ptr_native_ty
-            java_hu_ty = "ThreeTuple<"
-            if not ma.group(1).strip() in tuple_types:
-                assert java_c_types_none_allowed
-                return None
-            for idx, ty_info in enumerate(tuple_types[ma.group(1).strip()][0]):
-                if idx != 0:
-                    java_hu_ty = java_hu_ty + ", "
-                if ty_info.is_native_primitive:
-                    if ty_info.java_hu_ty == "int":
-                        java_hu_ty = java_hu_ty + "Integer" # Java concrete integer type is Integer, not Int
-                    else:
-                        java_hu_ty = java_hu_ty + ty_info.java_hu_ty.title() # If we're a primitive, capitalize the first letter
-                else:
-                    java_hu_ty = java_hu_ty + ty_info.java_hu_ty
-            java_hu_ty = java_hu_ty + ">"
-            fn_ty_arg = "J"
-            fn_arg = ma.group(2).strip()
-            rust_obj = ma.group(1).strip()
-            take_by_ptr = True
         else:
             c_ty = consts.ptr_c_ty
             java_ty = consts.ptr_native_ty
-            java_hu_ty = ma.group(1).strip().replace("LDKCOption", "Option").replace("LDKCResult", "Result").replace("LDK", "")
+            java_hu_ty = ma.group(1).strip()
+            java_hu_ty = java_hu_ty.replace("LDKCOption", "Option")
+            java_hu_ty = java_hu_ty.replace("LDKCResult", "Result")
+            java_hu_ty = java_hu_ty.replace("LDKC2Tuple", "TwoTuple")
+            java_hu_ty = java_hu_ty.replace("LDKC3Tuple", "ThreeTuple")
+            java_hu_ty = java_hu_ty.replace("LDK", "")
             fn_ty_arg = "J"
             fn_arg = ma.group(2).strip()
             rust_obj = ma.group(1).strip()
@@ -427,19 +388,29 @@ with open(f"{sys.argv[3]}/structs/UtilMethods{consts.file_ext}", "a") as util:
 with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}", "w") as out_java:
     # Map a top-level function
     def map_fn(line, re_match, ret_arr_len, c_call_string, doc_comment):
+        map_fn_with_ref_option(line, re_match, ret_arr_len, c_call_string, doc_comment, False)
+    def map_fn_with_ref_option(line, re_match, ret_arr_len, c_call_string, doc_comment, force_holds_ref):
         method_return_type = re_match.group(1)
         method_name = re_match.group(2)
-        orig_method_name = str(method_name)
         method_comma_separated_arguments = re_match.group(3)
         method_arguments = method_comma_separated_arguments.split(',')
 
         is_free = method_name.endswith("_free")
         if method_name.startswith("COption") or method_name.startswith("CResult"):
             struct_meth = method_name.rsplit("Z", 1)[0][1:] + "Z"
+            expected_struct = "LDKC" + struct_meth
+            struct_meth_name = method_name[len(struct_meth) + 1:].strip("_")
+        elif method_name.startswith("C2Tuple"):
+            tuple_name = method_name.rsplit("Z", 1)[0][2:] + "Z"
+            struct_meth = "Two" + tuple_name
+            expected_struct = "LDKC2" + tuple_name
+            struct_meth_name = method_name[len(tuple_name) + 2:].strip("_")
         else:
             struct_meth = method_name.split("_")[0]
+            expected_struct = "LDK" + struct_meth
+            struct_meth_name = method_name[len(struct_meth) + 1 if len(struct_meth) != 0 else 0:].strip("_")
 
-        return_type_info = type_mapping_generator.map_type(method_return_type.strip() + " ret", True, ret_arr_len, False, False)
+        return_type_info = type_mapping_generator.map_type(method_return_type.strip() + " ret", True, ret_arr_len, False, force_holds_ref)
 
         (params_nullable, ret_nullable) = doc_to_params_ret_nullable(doc_comment)
         if ret_nullable:
@@ -491,12 +462,13 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
             argument_types.append(argument_conversion_info)
         if not takes_self and return_type_info.java_hu_ty != struct_meth:
             if not return_type_info.java_hu_ty.startswith("Result_" + struct_meth):
-                method_name = orig_method_name
+                struct_meth_name = method_name
                 struct_meth = ""
+                expected_struct = ""
 
         out_java.write("\t// " + line)
         (out_java_delta, out_c_delta, out_java_struct_delta) = \
-            consts.map_function(argument_types, c_call_string, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_ptr, args_known, type_mapping_generator, doc_comment)
+            consts.map_function(argument_types, c_call_string, method_name, struct_meth_name, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_ptr, args_known, type_mapping_generator, doc_comment)
         out_java.write(out_java_delta)
 
         if is_free:
@@ -522,11 +494,9 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
             write_c(out_c_delta)
 
         out_java_struct = None
-        expected_struct = "LDK" + struct_meth
-        expected_cstruct = "LDKC" + struct_meth
         if (expected_struct in opaque_structs or expected_struct in trait_structs
-                or expected_struct in complex_enums or expected_cstruct in complex_enums
-                or expected_cstruct in result_types) and not is_free:
+                or expected_struct in complex_enums or expected_struct in complex_enums
+                or expected_struct in result_types or expected_struct in tuple_types) and not is_free:
             out_java_struct = open(f"{sys.argv[3]}/structs/{struct_meth}{consts.file_ext}", "a")
             out_java_struct.write(out_java_struct_delta)
         elif (not is_free and not method_name.endswith("_clone") and
@@ -756,47 +726,64 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
     def map_tuple(struct_name, field_lines):
         out_java.write("\tpublic static native long " + struct_name + "_new(")
         write_c(consts.c_fn_ty_pfx + consts.ptr_c_ty + " " + consts.c_fn_name_define_pfx(struct_name + "_new", len(field_lines) > 3))
-        ty_list = []
-        for idx, (line, _) in enumerate(field_lines):
-            if idx != 0 and idx < len(field_lines) - 2:
-                ty_info = java_c_types(line.strip(';'), None)
-                if idx != 1:
-                    out_java.write(", ")
-                    write_c(", ")
-                e = chr(ord('a') + idx - 1)
-                out_java.write(ty_info.java_ty + " " + e)
-                write_c(ty_info.c_ty + " " + e)
-                ty_list.append(ty_info)
-        tuple_types[struct_name] = (ty_list, struct_name)
-        out_java.write(");\n")
-        write_c(") {\n")
-        write_c("\t" + struct_name + "* ret = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
+        human_ty = struct_name.replace("LDKC2Tuple", "TwoTuple").replace("LDKC3Tuple", "ThreeTuple")
+        with open(f"{sys.argv[3]}/structs/{human_ty}{consts.file_ext}", "w") as out_java_struct:
+            out_java_struct.write(consts.map_tuple(struct_name))
+            ty_list = []
+            for idx, (line, _) in enumerate(field_lines):
+                if idx != 0 and idx < len(field_lines) - 2:
+                    ty_info = java_c_types(line.strip(';'), None)
+                    if idx != 1:
+                        out_java.write(", ")
+                        write_c(", ")
+                    e = chr(ord('a') + idx - 1)
+                    out_java.write(ty_info.java_ty + " " + e)
+                    write_c(ty_info.c_ty + " " + e)
+                    ty_list.append(ty_info)
+            tuple_types[struct_name] = (ty_list, struct_name)
+            out_java.write(");\n")
+            write_c(") {\n")
+            write_c("\t" + struct_name + "* ret = MALLOC(sizeof(" + struct_name + "), \"" + struct_name + "\");\n")
+            for idx, (line, _) in enumerate(field_lines):
+                if idx != 0 and idx < len(field_lines) - 2:
+                    ty_info = type_mapping_generator.map_type(line.strip(';'), False, None, False, False)
+                    e = chr(ord('a') + idx - 1)
+                    if ty_info.arg_conv is not None:
+                        write_c("\t" + ty_info.arg_conv.replace("\n", "\n\t"))
+                        write_c("\n\tret->" + e + " = " + ty_info.arg_conv_name + ";\n")
+                    else:
+                        write_c("\tret->" + e + " = " + e + ";\n")
+                    if ty_info.arg_conv_cleanup is not None:
+                        write_c("\t//TODO: Really need to call " + ty_info.arg_conv_cleanup + " here\n")
+            write_c("\treturn (uint64_t)ret;\n")
+            write_c("}\n")
+
+        # Map virtual getter functions
         for idx, (line, _) in enumerate(field_lines):
             if idx != 0 and idx < len(field_lines) - 2:
-                ty_info = type_mapping_generator.map_type(line.strip(';'), False, None, False, False)
-                e = chr(ord('a') + idx - 1)
-                if ty_info.arg_conv is not None:
-                    write_c("\t" + ty_info.arg_conv.replace("\n", "\n\t"))
-                    write_c("\n\tret->" + e + " = " + ty_info.arg_conv_name + ";\n")
+                field_name = chr(ord('a') + idx - 1)
+                assert line.endswith(" " + field_name + ";")
+                field_ty = java_c_types(line[:-1], None)
+                ptr_fn_defn = line[:-3].strip() + " *" + struct_name.replace("LDK", "") + "_get_" + field_name + "(" + struct_name + " *NONNULL_PTR tuple)"
+                owned_fn_defn = line[:-3].strip() + " " + struct_name.replace("LDK", "") + "_get_" + field_name + "(" + struct_name + " *NONNULL_PTR tuple)"
+
+                holds_ref = False
+                if field_ty.rust_obj is not None and field_ty.rust_obj.replace("LDK", "") + "_clone" in clone_fns:
+                    fn_defn = owned_fn_defn
+                    write_c("static inline " + fn_defn + "{\n")
+                    write_c("\treturn " + field_ty.rust_obj.replace("LDK", "") + "_clone(&tuple->" + field_name + ");\n")
+                elif field_ty.arr_len is not None or field_ty.is_native_primitive:
+                    fn_defn = owned_fn_defn
+                    write_c("static inline " + fn_defn + "{\n")
+                    write_c("\treturn tuple->" + field_name + ";\n")
                 else:
-                    write_c("\tret->" + e + " = " + e + ";\n")
-                if ty_info.arg_conv_cleanup is not None:
-                    write_c("\t//TODO: Really need to call " + ty_info.arg_conv_cleanup + " here\n")
-        write_c("\treturn (uint64_t)ret;\n")
-        write_c("}\n")
-
-        for idx, ty_info in enumerate(ty_list):
-            e = chr(ord('a') + idx)
-            out_java.write("\tpublic static native " + ty_info.java_ty + " " + struct_name + "_get_" + e + "(long ptr);\n")
-            write_c(consts.c_fn_ty_pfx + ty_info.c_ty + " " + consts.c_fn_name_define_pfx(struct_name + "_get_" + e, True) + consts.ptr_c_ty + " ptr) {\n")
-            write_c("\t" + struct_name + " *tuple = (" + struct_name + "*)(ptr & ~1);\n")
-            conv_info = type_mapping_generator.map_type_with_info(ty_info, False, None, False, True)
-            if conv_info.ret_conv is not None:
-                write_c("\t" + conv_info.ret_conv[0].replace("\n", "\n\t") + "tuple->" + e + conv_info.ret_conv[1].replace("\n", "\n\t") + "\n")
-                write_c("\treturn " + conv_info.ret_conv_name + ";\n")
-            else:
-                write_c("\treturn tuple->" + e + ";\n")
-            write_c("}\n")
+                    fn_defn = ptr_fn_defn
+                    write_c("static inline " + fn_defn + "{\n")
+                    write_c("\treturn &tuple->" + field_name + ";\n")
+                    holds_ref = True
+                write_c("}\n")
+                dummy_line = fn_defn + ";\n"
+                map_fn_with_ref_option(dummy_line, reg_fn_regex.match(dummy_line), None, None, "", holds_ref)
 
     out_java.write(consts.bindings_header)
     with open(f"{sys.argv[2]}/version{consts.file_ext}", "w") as out_java_version:
@@ -1087,6 +1074,10 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
     for struct_name in result_types:
         with open(f"{sys.argv[3]}/structs/{struct_name.replace('LDKCResult', 'Result')}{consts.file_ext}", "a") as out_java_struct:
             out_java_struct.write("}\n")
+    for struct_name in tuple_types:
+        struct_hu_name = struct_name.replace("LDKC2Tuple", "TwoTuple").replace("LDKC3Tuple", "ThreeTuple")
+        with open(f"{sys.argv[3]}/structs/{struct_hu_name}{consts.file_ext}", "a") as out_java_struct:
+            out_java_struct.write("}\n")
 
 with open(f"{sys.argv[4]}/bindings.c.body", "w") as out_c:
     out_c.write(consts.c_file_pfx)