Macro-ize assert to handle side-effects, fix JavaVM access, util fns
[ldk-java] / genbindings.py
index e0b2e184b03e77ec26ec9d3b2ca46fb10c6e88c5..a5e9ec41e38c62742b561606341f301aa0bd3b0c 100755 (executable)
@@ -239,7 +239,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         # any _free function.
                         # To avoid any issues, we first assert that the incoming object is non-ref.
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
-                            ret_conv = (ty_info.rust_obj + "* ret = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*ret = ", ";\nassert(ret->is_owned);\nret->is_owned = false;"),
+                            ret_conv = (ty_info.rust_obj + "* ret = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*ret = ", ";\nDO_ASSERT(ret->is_owned);\nret->is_owned = false;"),
                             ret_conv_name = "(long)ret",
                             arg_conv = None, arg_conv_name = None)
                     else:
@@ -314,7 +314,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
     def map_trait(struct_name, field_var_lines, trait_fn_lines):
         out_c.write("typedef struct " + struct_name + "_JCalls {\n")
         out_c.write("\tatomic_size_t refcnt;\n")
-        out_c.write("\tJNIEnv *env;\n")
+        out_c.write("\tJavaVM *vm;\n")
         out_c.write("\tjobject o;\n")
         for var_line in field_var_lines:
             if var_line.group(1) in trait_structs:
@@ -357,17 +357,19 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 out_java.write(");\n")
                 out_c.write(") {\n")
                 out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")
+                out_c.write("\tJNIEnv *env;\n")
+                out_c.write("\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n")
 
                 for arg_info in arg_names:
                     if arg_info.ret_conv is not None:
-                        out_c.write("\t" + arg_info.ret_conv[0].replace('\n', '\n\t').replace("_env", "j_calls->env"));
+                        out_c.write("\t" + arg_info.ret_conv[0].replace('\n', '\n\t').replace("_env", "env"));
                         out_c.write(arg_info.arg_name)
-                        out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "j_calls->env") + "\n")
+                        out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "env") + "\n")
 
                 if not ret_ty_info.passed_as_ptr:
-                    out_c.write("\treturn (*j_calls->env)->Call" + ret_ty_info.java_ty.title() + "Method(j_calls->env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
+                    out_c.write("\treturn (*env)->Call" + ret_ty_info.java_ty.title() + "Method(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
                 else:
-                    out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*j_calls->env)->CallLongMethod(j_calls->env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth");
+                    out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*env)->CallLongMethod(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth");
 
                 for arg_info in arg_names:
                     if arg_info.ret_conv is not None:
@@ -385,7 +387,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 out_c.write("static void " + struct_name + "_JCalls_free(void* this_arg) {\n")
                 out_c.write("\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n")
                 out_c.write("\tif (atomic_fetch_sub_explicit(&j_calls->refcnt, 1, memory_order_acquire) == 1) {\n")
-                out_c.write("\t\t(*j_calls->env)->DeleteGlobalRef(j_calls->env, j_calls->o);\n")
+                out_c.write("\t\tJNIEnv *env;\n")
+                out_c.write("\t\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n")
+                out_c.write("\t\t(*env)->DeleteGlobalRef(env, j_calls->o);\n")
                 out_c.write("\t\tFREE(j_calls);\n")
                 out_c.write("\t}\n}\n")
 
@@ -411,15 +415,15 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         out_c.write(") {\n")
 
         out_c.write("\tjclass c = (*env)->GetObjectClass(env, o);\n")
-        out_c.write("\tassert(c != NULL);\n")
+        out_c.write("\tDO_ASSERT(c != NULL);\n")
         out_c.write("\t" + struct_name + "_JCalls *calls = MALLOC(sizeof(" + struct_name + "_JCalls), \"" + struct_name + "_JCalls\");\n")
         out_c.write("\tatomic_init(&calls->refcnt, 1);\n")
-        out_c.write("\tcalls->env = env;\n")
+        out_c.write("\tDO_ASSERT((*env)->GetJavaVM(env, &calls->vm) == 0);\n")
         out_c.write("\tcalls->o = (*env)->NewGlobalRef(env, o);\n")
         for (fn_line, java_meth_descr) in zip(trait_fn_lines, java_meths):
             if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                 out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + java_meth_descr + "\");\n")
-                out_c.write("\tassert(calls->" + fn_line.group(2) + "_meth != NULL);\n")
+                out_c.write("\tDO_ASSERT(calls->" + fn_line.group(2) + "_meth != NULL);\n")
         out_c.write("\n\t" + struct_name + " ret = {\n")
         out_c.write("\t\t.this_arg = (void*) calls,\n")
         for fn_line in trait_fn_lines:
@@ -476,7 +480,6 @@ public class bindings {
     out_c.write("""#include \"org_ldk_impl_bindings.h\"
 #include <rust_types.h>
 #include <lightning.h>
-#include <assert.h>
 #include <string.h>
 #include <stdatomic.h>
 """)
@@ -484,13 +487,16 @@ public class bindings {
     if sys.argv[4] == "false":
         out_c.write("#define MALLOC(a, _) malloc(a)\n")
         out_c.write("#define FREE free\n")
+        out_c.write("#define DO_ASSERT(a) (void)(a)\n")
     else:
-        out_c.write("""
+        out_c.write("""#include <assert.h>
+#define DO_ASSERT(a) do { bool _assert_val = (a); assert(_assert_val); } while(0)
+
 #include <threads.h>
 static mtx_t allocation_mtx;
 
 void __attribute__((constructor)) init_mtx() {
-       assert(mtx_init(&allocation_mtx, mtx_plain) == thrd_success);
+       DO_ASSERT(mtx_init(&allocation_mtx, mtx_plain) == thrd_success);
 }
 
 typedef struct allocation {
@@ -505,28 +511,28 @@ void* MALLOC(size_t len, const char* struct_name) {
        allocation* new_alloc = malloc(sizeof(allocation));
        new_alloc->ptr = res;
        new_alloc->struct_name = struct_name;
-       assert(mtx_lock(&allocation_mtx) == thrd_success);
+       DO_ASSERT(mtx_lock(&allocation_mtx) == thrd_success);
        new_alloc->next = allocation_ll;
        allocation_ll = new_alloc;
-       assert(mtx_unlock(&allocation_mtx) == thrd_success);
+       DO_ASSERT(mtx_unlock(&allocation_mtx) == thrd_success);
        return res;
 }
 
 void FREE(void* ptr) {
        allocation* p = NULL;
-       assert(mtx_lock(&allocation_mtx) == thrd_success);
+       DO_ASSERT(mtx_lock(&allocation_mtx) == thrd_success);
        allocation* it = allocation_ll;
        while (it->ptr != ptr) { p = it; it = it->next; }
        if (p) { p->next = it->next; } else { allocation_ll = it->next; }
-       assert(mtx_unlock(&allocation_mtx) == thrd_success);
-       assert(it->ptr == ptr);
+       DO_ASSERT(mtx_unlock(&allocation_mtx) == thrd_success);
+       DO_ASSERT(it->ptr == ptr);
        free(it);
        free(ptr);
 }
 
 void __attribute__((destructor)) check_leaks() {
        for (allocation* a = allocation_ll; a != NULL; a = a->next) { fprintf(stderr, "%s %p remains\\n", a->struct_name, a->ptr); }
-       assert(allocation_ll == NULL);
+       DO_ASSERT(allocation_ll == NULL);
 }
 """)
 
@@ -536,6 +542,8 @@ void __attribute__((destructor)) check_leaks() {
        public static native boolean deref_bool(long ptr);
        public static native long deref_long(long ptr);
        public static native void free_heap_ptr(long ptr);
+       public static native byte[] get_u8_slice_bytes(long slice_ptr);
+       public static native long bytes_to_u8_vec(byte[] bytes);
        public static native long u8_vec_len(long vec);
 
 """)
@@ -543,7 +551,7 @@ void __attribute__((destructor)) check_leaks() {
 jmethodID ordinal_meth = NULL;
 JNIEXPORT void Java_org_ldk_impl_bindings_init(JNIEnv * env, jclass _b, jclass enum_class) {
        ordinal_meth = (*env)->GetMethodID(env, enum_class, "ordinal", "()I");
-       assert(ordinal_meth != NULL);
+       DO_ASSERT(ordinal_meth != NULL);
 }
 
 JNIEXPORT jboolean JNICALL Java_org_ldk_impl_bindings_deref_1bool (JNIEnv * env, jclass _a, jlong ptr) {
@@ -555,6 +563,19 @@ JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_deref_1long (JNIEnv * env, jc
 JNIEXPORT void JNICALL Java_org_ldk_impl_bindings_free_1heap_1ptr (JNIEnv * env, jclass _a, jlong ptr) {
        FREE((void*)ptr);
 }
+JNIEXPORT jbyteArray JNICALL Java_org_ldk_impl_bindings_get_1u8_1slice_1bytes (JNIEnv * _env, jclass _b, jlong slice_ptr) {
+       LDKu8slice *slice = (LDKu8slice*)slice_ptr;
+       jbyteArray ret_arr = (*_env)->NewByteArray(_env, slice->datalen);
+       (*_env)->SetByteArrayRegion(_env, ret_arr, 0, slice->datalen, slice->data);
+       return ret_arr;
+}
+JNIEXPORT long JNICALL Java_org_ldk_impl_bindings_bytes_1to_1u8_1vec (JNIEnv * _env, jclass _b, jbyteArray bytes) {
+       LDKCVec_u8Z *vec = (LDKCVec_u8Z*)MALLOC(sizeof(LDKCVec_u8Z), "LDKCVec_u8");
+       vec->datalen = (*_env)->GetArrayLength(_env, bytes);
+       vec->data = (uint8_t*)malloc(vec->datalen); // May be freed by rust, so don't track allocation
+       (*_env)->GetByteArrayRegion (_env, bytes, 0, vec->datalen, vec->data);
+       return (long)vec;
+}
 JNIEXPORT jlong JNICALL Java_org_ldk_impl_bindings_u8_1vec_1len (JNIEnv * env, jclass _a, jlong ptr) {
        LDKCVec_u8Z *vec = (LDKCVec_u8Z*)ptr;
        return (long)vec->datalen;
@@ -674,25 +695,25 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
                             out_c.write("\t\tcase %d: return %s;\n" % (ord_v, struct_line.strip().strip(",")))
                             ord_v = ord_v + 1
                     out_c.write("\t}\n")
-                    out_c.write("\tassert(false);\n")
+                    out_c.write("\tabort();\n")
                     out_c.write("}\n")
 
                     ord_v = 0
                     out_c.write("static inline jclass " + struct_name + "_to_java(JNIEnv *env, " + struct_name + " val) {\n")
                     out_c.write("\t// TODO: This is pretty inefficient, we really need to cache the field IDs and class\n")
                     out_c.write("\tjclass enum_class = (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
-                    out_c.write("\tassert(enum_class != NULL);\n")
+                    out_c.write("\tDO_ASSERT(enum_class != NULL);\n")
                     out_c.write("\tswitch (val) {\n")
                     for idx, struct_line in enumerate(field_lines):
                         if idx > 0 and idx < len(field_lines) - 3:
                             variant = struct_line.strip().strip(",")
                             out_c.write("\t\tcase " + variant + ": {\n")
                             out_c.write("\t\t\tjfieldID field = (*env)->GetStaticFieldID(env, enum_class, \"" + variant + "\", \"Lorg/ldk/impl/bindings$" + struct_name + ";\");\n")
-                            out_c.write("\t\t\tassert(field != NULL);\n")
+                            out_c.write("\t\t\tDO_ASSERT(field != NULL);\n")
                             out_c.write("\t\t\treturn (*env)->GetStaticObjectField(env, enum_class, field);\n")
                             out_c.write("\t\t}\n")
                             ord_v = ord_v + 1
-                    out_c.write("\t\tdefault: assert(false);\n")
+                    out_c.write("\t\tdefault: abort();\n")
                     out_c.write("\t}\n")
                     out_c.write("}\n\n")
                 elif len(trait_fn_lines) > 0: