Move enums into their own (non-impl) folder
[ldk-java] / genbindings.py
index a9d91be801b3a1c6bd32481ad12ed05ce22b89e7..d9999d963516d923377cbb821bc887cd1e03c78d 100755 (executable)
@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
 import sys, re
 
-if len(sys.argv) != 5:
-    print("USAGE: /path/to/lightning.h /path/to/bindings/output.java /path/to/bindings/output.c debug")
+if len(sys.argv) != 6:
+    print("USAGE: /path/to/lightning.h /path/to/bindings/output.java /path/to/bindings/enums/ /path/to/bindings/output.c debug")
     print("debug should be true or false and indicates whether to track allocations and ensure we don't leak")
     sys.exit(1)
 
@@ -43,7 +43,7 @@ class ConvInfo:
             out_java.write(" arg")
             out_c.write(" arg")
 
-with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.argv[3], "w") as out_c:
+with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.argv[4], "w") as out_c:
     opaque_structs = set()
     trait_structs = set()
     unitary_enums = set()
@@ -105,7 +105,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             if ma.group(1).strip() in unitary_enums:
                 java_ty = ma.group(1).strip()
                 c_ty = "jclass"
-                fn_ty_arg = "Lorg/ldk/impl/bindings$" + ma.group(1).strip() + ";"
+                fn_ty_arg = "Lorg/ldk/enums/" + ma.group(1).strip() + ";"
                 fn_arg = ma.group(2).strip()
                 rust_obj = ma.group(1).strip()
                 take_by_ptr = True
@@ -528,6 +528,7 @@ void __attribute__((destructor)) check_leaks() {
 }
 """)
     out_java.write("""package org.ldk.impl;
+import org.ldk.enums.*;
 
 public class bindings {
        public static class VecOrSliceDef {
@@ -700,11 +701,13 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                         field_lines.append(struct_line)
 
                 assert(struct_name is not None)
-                assert(len(trait_fn_lines) == 0 or not (is_opaque or is_unitary_enum or is_union_enum or is_union))
-                assert(not is_opaque or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_union))
-                assert(not is_unitary_enum or not (len(trait_fn_lines) != 0 or is_opaque or is_union_enum or is_union))
-                assert(not is_union_enum or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_opaque or is_union))
-                assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque))
+                assert(len(trait_fn_lines) == 0 or not (is_opaque or is_unitary_enum or is_union_enum or is_union or is_result or vec_ty is not None))
+                assert(not is_opaque or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_union or is_result or vec_ty is not None))
+                assert(not is_unitary_enum or not (len(trait_fn_lines) != 0 or is_opaque or is_union_enum or is_union or is_result or vec_ty is not None))
+                assert(not is_union_enum or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_opaque or is_union or is_result or vec_ty is not None))
+                assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_result or vec_ty is not None))
+                assert(not is_result or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or vec_ty is not None))
+                assert(vec_ty is None or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or is_result))
                 if is_opaque:
                     opaque_structs.add(struct_name)
                     out_java.write("\tpublic static native long " + struct_name + "_optional_none();\n")
@@ -722,45 +725,47 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                     out_c.write("\treturn (*env)->NewObject(env, slicedef_cls, slicedef_meth, (long)vec->data, (long)vec->datalen, sizeof(" + vec_ty + "));\n")
                     out_c.write("}\n")
                 elif is_unitary_enum:
-                    unitary_enums.add(struct_name)
-                    out_c.write("static inline " + struct_name + " " + struct_name + "_from_java(JNIEnv *env, jclass val) {\n")
-                    out_c.write("\tswitch ((*env)->CallIntMethod(env, val, ordinal_meth)) {\n")
-                    ord_v = 0
-                    for idx, struct_line in enumerate(field_lines):
-                        if idx == 0:
-                            out_java.write("\tpublic enum " + struct_name + " {\n")
-                        elif idx == len(field_lines) - 3:
-                            assert(struct_line.endswith("_Sentinel,"))
-                        elif idx == len(field_lines) - 2:
-                            out_java.write("\t}\n")
-                        elif idx == len(field_lines) - 1:
-                            assert(struct_line == "")
-                        else:
-                            out_java.write("\t" + struct_line + "\n")
-                            out_c.write("\t\tcase %d: return %s;\n" % (ord_v, struct_line.strip().strip(",")))
-                            ord_v = ord_v + 1
-                    out_c.write("\t}\n")
-                    out_c.write("\tabort();\n")
-                    out_c.write("}\n")
-
-                    ord_v = 0
-                    out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
-                    out_c.write("\t// TODO: This is pretty inefficient, we really need to cache the field IDs and class\n")
-                    out_c.write("\tjclass enum_class = (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
-                    out_c.write("\tDO_ASSERT(enum_class != NULL);\n")
-                    out_c.write("\tswitch (val) {\n")
-                    for idx, struct_line in enumerate(field_lines):
-                        if idx > 0 and idx < len(field_lines) - 3:
-                            variant = struct_line.strip().strip(",")
-                            out_c.write("\t\tcase " + variant + ": {\n")
-                            out_c.write("\t\t\tjfieldID field = (*env)->GetStaticFieldID(env, enum_class, \"" + variant + "\", \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
-                            out_c.write("\t\t\tDO_ASSERT(field != NULL);\n")
-                            out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, enum_class, field);\n")
-                            out_c.write("\t\t}\n")
-                            ord_v = ord_v + 1
-                    out_c.write("\t\tdefault: abort();\n")
-                    out_c.write("\t}\n")
-                    out_c.write("}\n\n")
+                    with open(sys.argv[3] + "/" + struct_name + ".java", "w") as out_java_enum:
+                        out_java_enum.write("package org.ldk.enums;\n\n")
+                        unitary_enums.add(struct_name)
+                        out_c.write("static inline " + struct_name + " " + struct_name + "_from_java(JNIEnv *env, jclass val) {\n")
+                        out_c.write("\tswitch ((*env)->CallIntMethod(env, val, ordinal_meth)) {\n")
+                        ord_v = 0
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx == 0:
+                                out_java_enum.write("public enum " + struct_name + " {\n")
+                            elif idx == len(field_lines) - 3:
+                                assert(struct_line.endswith("_Sentinel,"))
+                            elif idx == len(field_lines) - 2:
+                                out_java_enum.write("}")
+                            elif idx == len(field_lines) - 1:
+                                assert(struct_line == "")
+                            else:
+                                out_java_enum.write(struct_line + "\n")
+                                out_c.write("\t\tcase %d: return %s;\n" % (ord_v, struct_line.strip().strip(",")))
+                                ord_v = ord_v + 1
+                        out_c.write("\t}\n")
+                        out_c.write("\tabort();\n")
+                        out_c.write("}\n")
+
+                        ord_v = 0
+                        out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
+                        out_c.write("\t// TODO: This is pretty inefficient, we really need to cache the field IDs and class\n")
+                        out_c.write("\tjclass enum_class = (*env)->FindClass(env, \"Lorg/ldk/enums/" + struct_name + ";\");\n")
+                        out_c.write("\tDO_ASSERT(enum_class != NULL);\n")
+                        out_c.write("\tswitch (val) {\n")
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx > 0 and idx < len(field_lines) - 3:
+                                variant = struct_line.strip().strip(",")
+                                out_c.write("\t\tcase " + variant + ": {\n")
+                                out_c.write("\t\t\tjfieldID field = (*env)->GetStaticFieldID(env, enum_class, \"" + variant + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n")
+                                out_c.write("\t\t\tDO_ASSERT(field != NULL);\n")
+                                out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, enum_class, field);\n")
+                                out_c.write("\t\t}\n")
+                                ord_v = ord_v + 1
+                        out_c.write("\t\tdefault: abort();\n")
+                        out_c.write("\t}\n")
+                        out_c.write("}\n\n")
                 elif len(trait_fn_lines) > 0:
                     trait_structs.add(struct_name)
                     map_trait(struct_name, field_var_lines, trait_fn_lines)