Cache enum fields
[ldk-java] / genbindings.py
index b4f9323db3b6b6eb44d77395aa81f7d97213e41c..ab1df82027b622c26f3f70a76e189559a80b528c 100755 (executable)
@@ -1,8 +1,8 @@
 #!/usr/bin/env python3
 import sys, re
 
-if len(sys.argv) != 5:
-    print("USAGE: /path/to/lightning.h /path/to/bindings/output.java /path/to/bindings/output.c debug")
+if len(sys.argv) != 6:
+    print("USAGE: /path/to/lightning.h /path/to/bindings/output.java /path/to/bindings/enums/ /path/to/bindings/output.c debug")
     print("debug should be true or false and indicates whether to track allocations and ensure we don't leak")
     sys.exit(1)
 
@@ -43,7 +43,7 @@ class ConvInfo:
             out_java.write(" arg")
             out_c.write(" arg")
 
-with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.argv[3], "w") as out_c:
+with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.argv[4], "w") as out_c:
     opaque_structs = set()
     trait_structs = set()
     unitary_enums = set()
@@ -105,7 +105,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
             if ma.group(1).strip() in unitary_enums:
                 java_ty = ma.group(1).strip()
                 c_ty = "jclass"
-                fn_ty_arg = "Lorg/ldk/impl/bindings$" + ma.group(1).strip() + ";"
+                fn_ty_arg = "Lorg/ldk/enums/" + ma.group(1).strip() + ";"
                 fn_arg = ma.group(2).strip()
                 rust_obj = ma.group(1).strip()
                 take_by_ptr = True
@@ -469,14 +469,6 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 dummy_line = fn_line.group(1) + struct_name + "_call_" + fn_line.group(2) + " " + struct_name + "* arg" + fn_line.group(4) + "\n"
                 map_fn(dummy_line, re.compile("([A-Za-z_0-9]*) *([A-Za-z_0-9]*) *(.*)").match(dummy_line), None, "(arg_conv->" + fn_line.group(2) + ")(arg_conv->this_arg")
 
-    out_java.write("""package org.ldk.impl;
-
-public class bindings {
-       static {
-               System.loadLibrary(\"lightningjni\");
-               init(java.lang.Enum.class);
-       }
-""")
     out_c.write("""#include \"org_ldk_impl_bindings.h\"
 #include <rust_types.h>
 #include <lightning.h>
@@ -535,9 +527,23 @@ void __attribute__((destructor)) check_leaks() {
        DO_ASSERT(allocation_ll == NULL);
 }
 """)
+    out_java.write("""package org.ldk.impl;
+import org.ldk.enums.*;
 
-    out_java.write("""
-       static native void init(java.lang.Class c);
+public class bindings {
+       public static class VecOrSliceDef {
+               public long dataptr;
+               public long datalen;
+               public long stride;
+               public VecOrSliceDef(long dataptr, long datalen, long stride) {
+                       this.dataptr = dataptr; this.datalen = datalen; this.stride = stride;
+               }
+       }
+       static {
+               System.loadLibrary(\"lightningjni\");
+               init(java.lang.Enum.class, VecOrSliceDef.class);
+       }
+       static native void init(java.lang.Class c, java.lang.Class slicedef);
 
        public static native boolean deref_bool(long ptr);
        public static native long deref_long(long ptr);
@@ -550,9 +556,15 @@ void __attribute__((destructor)) check_leaks() {
 """)
     out_c.write("""
 jmethodID ordinal_meth = NULL;
-JNIEXPORT void Java_org_ldk_impl_bindings_init(JNIEnv * env, jclass _b, jclass enum_class) {
+jmethodID slicedef_meth = NULL;
+jclass slicedef_cls = NULL;
+JNIEXPORT void Java_org_ldk_impl_bindings_init(JNIEnv * env, jclass _b, jclass enum_class, jclass slicedef_class) {
        ordinal_meth = (*env)->GetMethodID(env, enum_class, "ordinal", "()I");
        DO_ASSERT(ordinal_meth != NULL);
+       slicedef_meth = (*env)->GetMethodID(env, slicedef_class, "<init>", "(JJJ)V");
+       DO_ASSERT(slicedef_meth != NULL);
+       slicedef_cls = (*env)->NewGlobalRef(env, slicedef_class);
+       DO_ASSERT(slicedef_cls != NULL);
 }
 
 JNIEXPORT jboolean JNICALL Java_org_ldk_impl_bindings_deref_1bool (JNIEnv * env, jclass _a, jlong ptr) {
@@ -621,6 +633,7 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
     const_val_regex = re.compile("^extern const ([A-Za-z_0-9]*) ([A-Za-z_0-9]*);$")
 
     line_indicates_result_regex = re.compile("^   bool result_ok;$")
+    line_indicates_vec_regex = re.compile("^   ([A-Za-z_0-9]*) \*data;$")
     line_indicates_opaque_regex = re.compile("^   bool is_owned;$")
     line_indicates_trait_regex = re.compile("^   ([A-Za-z_0-9]* \*?)\(\*([A-Za-z_0-9]*)\)\((const )?void \*this_arg(.*)\);$")
     assert(line_indicates_trait_regex.match("   uintptr_t (*send_data)(void *this_arg, LDKu8slice data, bool resume_read);"))
@@ -645,6 +658,7 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
             if line.startswith("} "):
                 field_lines = []
                 struct_name = None
+                vec_ty = None
                 obj_lines = cur_block_obj.split("\n")
                 is_opaque = False
                 is_result = False
@@ -662,6 +676,7 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                             in_block_comment = False
                     else:
                         struct_name_match = struct_name_regex.match(struct_line)
+                        vec_ty_match = line_indicates_vec_regex.match(struct_line)
                         if struct_name_match is not None:
                             struct_name = struct_name_match.group(3)
                             if struct_name_match.group(1) == "enum":
@@ -675,6 +690,8 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                             is_opaque = True
                         elif line_indicates_result_regex.match(struct_line):
                             is_result = True
+                        elif vec_ty_match is not None and struct_name.startswith("LDKCVecTempl_"):
+                            vec_ty = vec_ty_match.group(1)
                         trait_fn_match = line_indicates_trait_regex.match(struct_line)
                         if trait_fn_match is not None:
                             trait_fn_lines.append(trait_fn_match)
@@ -684,11 +701,13 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                         field_lines.append(struct_line)
 
                 assert(struct_name is not None)
-                assert(len(trait_fn_lines) == 0 or not (is_opaque or is_unitary_enum or is_union_enum or is_union))
-                assert(not is_opaque or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_union))
-                assert(not is_unitary_enum or not (len(trait_fn_lines) != 0 or is_opaque or is_union_enum or is_union))
-                assert(not is_union_enum or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_opaque or is_union))
-                assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque))
+                assert(len(trait_fn_lines) == 0 or not (is_opaque or is_unitary_enum or is_union_enum or is_union or is_result or vec_ty is not None))
+                assert(not is_opaque or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_union or is_result or vec_ty is not None))
+                assert(not is_unitary_enum or not (len(trait_fn_lines) != 0 or is_opaque or is_union_enum or is_union or is_result or vec_ty is not None))
+                assert(not is_union_enum or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_opaque or is_union or is_result or vec_ty is not None))
+                assert(not is_union or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_result or vec_ty is not None))
+                assert(not is_result or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or vec_ty is not None))
+                assert(vec_ty is None or not (len(trait_fn_lines) != 0 or is_unitary_enum or is_union_enum or is_opaque or is_union or is_result))
                 if is_opaque:
                     opaque_structs.add(struct_name)
                     out_java.write("\tpublic static native long " + struct_name + "_optional_none();\n")
@@ -699,46 +718,65 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                     out_c.write("}\n")
                 elif is_result:
                     result_templ_structs.add(struct_name)
-                elif is_unitary_enum:
-                    unitary_enums.add(struct_name)
-                    out_c.write("static inline " + struct_name + " " + struct_name + "_from_java(JNIEnv *env, jclass val) {\n")
-                    out_c.write("\tswitch ((*env)->CallIntMethod(env, val, ordinal_meth)) {\n")
-                    ord_v = 0
-                    for idx, struct_line in enumerate(field_lines):
-                        if idx == 0:
-                            out_java.write("\tpublic enum " + struct_name + " {\n")
-                        elif idx == len(field_lines) - 3:
-                            assert(struct_line.endswith("_Sentinel,"))
-                        elif idx == len(field_lines) - 2:
-                            out_java.write("\t}\n")
-                        elif idx == len(field_lines) - 1:
-                            assert(struct_line == "")
-                        else:
-                            out_java.write("\t" + struct_line + "\n")
-                            out_c.write("\t\tcase %d: return %s;\n" % (ord_v, struct_line.strip().strip(",")))
-                            ord_v = ord_v + 1
-                    out_c.write("\t}\n")
-                    out_c.write("\tabort();\n")
+                elif vec_ty is not None:
+                    out_java.write("\tpublic static native VecOrSliceDef " + struct_name + "_arr_info(long vec_ptr);\n")
+                    out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1arr_1info(JNIEnv *env, jclass _b, jlong ptr) {\n")
+                    out_c.write("\t" + struct_name + " *vec = (" + struct_name + "*)ptr;\n")
+                    out_c.write("\treturn (*env)->NewObject(env, slicedef_cls, slicedef_meth, (long)vec->data, (long)vec->datalen, sizeof(" + vec_ty + "));\n")
                     out_c.write("}\n")
-
-                    ord_v = 0
-                    out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
-                    out_c.write("\t// TODO: This is pretty inefficient, we really need to cache the field IDs and class\n")
-                    out_c.write("\tjclass enum_class = (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
-                    out_c.write("\tDO_ASSERT(enum_class != NULL);\n")
-                    out_c.write("\tswitch (val) {\n")
-                    for idx, struct_line in enumerate(field_lines):
-                        if idx > 0 and idx < len(field_lines) - 3:
-                            variant = struct_line.strip().strip(",")
-                            out_c.write("\t\tcase " + variant + ": {\n")
-                            out_c.write("\t\t\tjfieldID field = (*env)->GetStaticFieldID(env, enum_class, \"" + variant + "\", \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
-                            out_c.write("\t\t\tDO_ASSERT(field != NULL);\n")
-                            out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, enum_class, field);\n")
-                            out_c.write("\t\t}\n")
-                            ord_v = ord_v + 1
-                    out_c.write("\t\tdefault: abort();\n")
-                    out_c.write("\t}\n")
-                    out_c.write("}\n\n")
+                elif is_unitary_enum:
+                    with open(sys.argv[3] + "/" + struct_name + ".java", "w") as out_java_enum:
+                        out_java_enum.write("package org.ldk.enums;\n\n")
+                        unitary_enums.add(struct_name)
+                        out_c.write("static inline " + struct_name + " " + struct_name + "_from_java(JNIEnv *env, jclass val) {\n")
+                        out_c.write("\tswitch ((*env)->CallIntMethod(env, val, ordinal_meth)) {\n")
+                        ord_v = 0
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx == 0:
+                                out_java_enum.write("public enum " + struct_name + " {\n")
+                            elif idx == len(field_lines) - 3:
+                                assert(struct_line.endswith("_Sentinel,"))
+                            elif idx == len(field_lines) - 2:
+                                out_java_enum.write("\t; static native void init();\n")
+                                out_java_enum.write("\tstatic { init(); }\n")
+                                out_java_enum.write("}")
+                                out_java.write("\tstatic { " + struct_name + ".values(); /* Force enum statics to run */ }\n")
+                            elif idx == len(field_lines) - 1:
+                                assert(struct_line == "")
+                            else:
+                                out_java_enum.write(struct_line + "\n")
+                                out_c.write("\t\tcase %d: return %s;\n" % (ord_v, struct_line.strip().strip(",")))
+                                ord_v = ord_v + 1
+                        out_c.write("\t}\n")
+                        out_c.write("\tabort();\n")
+                        out_c.write("}\n")
+
+                        ord_v = 0
+                        out_c.write("static jclass " + struct_name + "_class = NULL;\n")
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx > 0 and idx < len(field_lines) - 3:
+                                variant = struct_line.strip().strip(",")
+                                out_c.write("static jfieldID " + struct_name + "_" + variant + " = NULL;\n")
+                        out_c.write("JNIEXPORT void JNICALL Java_org_ldk_enums_" + struct_name.replace("_", "_1") + "_init (JNIEnv * env, jclass clz) {\n")
+                        out_c.write("\t" + struct_name + "_class = (*env)->NewGlobalRef(env, clz);\n")
+                        out_c.write("\tDO_ASSERT(" + struct_name + "_class != NULL);\n")
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx > 0 and idx < len(field_lines) - 3:
+                                variant = struct_line.strip().strip(",")
+                                out_c.write("\t" + struct_name + "_" + variant + " = (*env)->GetStaticFieldID(env, " + struct_name + "_class, \"" + variant + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n")
+                                out_c.write("\tDO_ASSERT(" + struct_name + "_" + variant + " != NULL);\n")
+                        out_c.write("}\n")
+                        out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
+                        out_c.write("\tswitch (val) {\n")
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx > 0 and idx < len(field_lines) - 3:
+                                variant = struct_line.strip().strip(",")
+                                out_c.write("\t\tcase " + variant + ": \n")
+                                out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, " + struct_name + "_class, " + struct_name + "_" + variant + ");\n")
+                                ord_v = ord_v + 1
+                        out_c.write("\t\tdefault: abort();\n")
+                        out_c.write("\t}\n")
+                        out_c.write("}\n\n")
                 elif len(trait_fn_lines) > 0:
                     trait_structs.add(struct_name)
                     map_trait(struct_name, field_var_lines, trait_fn_lines)