Cache enum fields
[ldk-java] / genbindings.py
index d9999d963516d923377cbb821bc887cd1e03c78d..ab1df82027b622c26f3f70a76e189559a80b528c 100755 (executable)
@@ -737,7 +737,10 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                             elif idx == len(field_lines) - 3:
                                 assert(struct_line.endswith("_Sentinel,"))
                             elif idx == len(field_lines) - 2:
+                                out_java_enum.write("\t; static native void init();\n")
+                                out_java_enum.write("\tstatic { init(); }\n")
                                 out_java_enum.write("}")
+                                out_java.write("\tstatic { " + struct_name + ".values(); /* Force enum statics to run */ }\n")
                             elif idx == len(field_lines) - 1:
                                 assert(struct_line == "")
                             else:
@@ -749,19 +752,27 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                         out_c.write("}\n")
 
                         ord_v = 0
+                        out_c.write("static jclass " + struct_name + "_class = NULL;\n")
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx > 0 and idx < len(field_lines) - 3:
+                                variant = struct_line.strip().strip(",")
+                                out_c.write("static jfieldID " + struct_name + "_" + variant + " = NULL;\n")
+                        out_c.write("JNIEXPORT void JNICALL Java_org_ldk_enums_" + struct_name.replace("_", "_1") + "_init (JNIEnv * env, jclass clz) {\n")
+                        out_c.write("\t" + struct_name + "_class = (*env)->NewGlobalRef(env, clz);\n")
+                        out_c.write("\tDO_ASSERT(" + struct_name + "_class != NULL);\n")
+                        for idx, struct_line in enumerate(field_lines):
+                            if idx > 0 and idx < len(field_lines) - 3:
+                                variant = struct_line.strip().strip(",")
+                                out_c.write("\t" + struct_name + "_" + variant + " = (*env)->GetStaticFieldID(env, " + struct_name + "_class, \"" + variant + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n")
+                                out_c.write("\tDO_ASSERT(" + struct_name + "_" + variant + " != NULL);\n")
+                        out_c.write("}\n")
                         out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
-                        out_c.write("\t// TODO: This is pretty inefficient, we really need to cache the field IDs and class\n")
-                        out_c.write("\tjclass enum_class = (*env)->FindClass(env, \"Lorg/ldk/enums/" + struct_name + ";\");\n")
-                        out_c.write("\tDO_ASSERT(enum_class != NULL);\n")
                         out_c.write("\tswitch (val) {\n")
                         for idx, struct_line in enumerate(field_lines):
                             if idx > 0 and idx < len(field_lines) - 3:
                                 variant = struct_line.strip().strip(",")
-                                out_c.write("\t\tcase " + variant + ": {\n")
-                                out_c.write("\t\t\tjfieldID field = (*env)->GetStaticFieldID(env, enum_class, \"" + variant + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n")
-                                out_c.write("\t\t\tDO_ASSERT(field != NULL);\n")
-                                out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, enum_class, field);\n")
-                                out_c.write("\t\t}\n")
+                                out_c.write("\t\tcase " + variant + ": \n")
+                                out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, " + struct_name + "_class, " + struct_name + "_" + variant + ");\n")
                                 ord_v = ord_v + 1
                         out_c.write("\t\tdefault: abort();\n")
                         out_c.write("\t}\n")