Handle C -> Java calls from Rust-spawned threads
authorMatt Corallo <git@bluematt.me>
Wed, 2 Jun 2021 19:26:42 +0000 (19:26 +0000)
committerMatt Corallo <git-ldk-build@bluematt.me>
Tue, 8 Jun 2021 20:26:10 +0000 (20:26 +0000)
Turns out you need to "attach" threads to the JVM before you can
make Java calls.

java_strings.py

index cc545e8297f97797dc53ce3b91924c12fe6cf7d9..94998db17a6dbade2c919bc505ca3a7b9821236f 100644 (file)
@@ -453,6 +453,21 @@ import java.util.Arrays;
         self.ptr_arr = "jobjectArray"
         self.get_native_arr_len_call = ("(*env)->GetArrayLength(env, ", ")")
 
+    def construct_jenv(self):
+        res =  "JNIEnv *env;\n"
+        res += "jint get_jenv_res = (*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_6);\n"
+        res += "if (get_jenv_res == JNI_EDETACHED) {\n"
+        res += "\tDO_ASSERT((*j_calls->vm)->AttachCurrentThread(j_calls->vm, (void**)&env, NULL) == JNI_OK);\n"
+        res += "} else {\n"
+        res += "\tDO_ASSERT(get_jenv_res == JNI_OK);\n"
+        res += "}\n"
+        return res
+    def deconstruct_jenv(self):
+        res = "if (get_jenv_res == JNI_EDETACHED) {\n"
+        res += "\tDO_ASSERT((*j_calls->vm)->DetachCurrentThread(j_calls->vm) == JNI_OK);\n"
+        res += "}\n"
+        return res
+
     def release_native_arr_ptr_call(self, ty_info, arr_var, arr_ptr_var):
         if ty_info.subty is None or not ty_info.subty.c_ty.endswith("Array"):
             return "(*env)->ReleasePrimitiveArrayCritical(env, " + arr_var + ", " + arr_ptr_var + ", 0)"
@@ -761,9 +776,9 @@ import java.util.Arrays;
                 out_c = out_c + "static void " + struct_name + "_JCalls_free(void* this_arg) {\n"
                 out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
                 out_c = out_c + "\tif (atomic_fetch_sub_explicit(&j_calls->refcnt, 1, memory_order_acquire) == 1) {\n"
-                out_c = out_c + "\t\tJNIEnv *env;\n"
-                out_c = out_c + "\t\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_6) == JNI_OK);\n"
+                out_c += "\t\t" + self.construct_jenv().replace("\n", "\n\t\t").strip() + "\n"
                 out_c = out_c + "\t\t(*env)->DeleteWeakGlobalRef(env, j_calls->o);\n"
+                out_c += "\t\t" + self.deconstruct_jenv().replace("\n", "\n\t\t").strip() + "\n"
                 out_c = out_c + "\t\tFREE(j_calls);\n"
                 out_c = out_c + "\t}\n}\n"
 
@@ -781,8 +796,7 @@ import java.util.Arrays;
 
                 out_c = out_c + ") {\n"
                 out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
-                out_c = out_c + "\tJNIEnv *env;\n"
-                out_c = out_c + "\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_6) == JNI_OK);\n"
+                out_c += "\t" + self.construct_jenv().replace("\n", "\n\t").strip() + "\n"
 
                 for arg_info in fn_line.args_ty:
                     if arg_info.ret_conv is not None:
@@ -793,8 +807,10 @@ import java.util.Arrays;
                 out_c = out_c + "\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tCHECK(obj != NULL);\n"
                 if fn_line.ret_ty_info.c_ty.endswith("Array"):
                     out_c = out_c + "\t" + fn_line.ret_ty_info.c_ty + " ret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
+                elif fn_line.ret_ty_info.c_ty == "void":
+                    out_c += "\t(*env)->Call" + fn_line.ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.fn_name + "_meth"
                 elif not fn_line.ret_ty_info.passed_as_ptr:
-                    out_c = out_c + "\treturn (*env)->Call" + fn_line.ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.fn_name + "_meth"
+                    out_c += "\t" + fn_line.ret_ty_info.c_ty + " ret = (*env)->Call" + fn_line.ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.fn_name + "_meth"
                 else:
                     out_c = out_c + "\t" + fn_line.ret_ty_info.rust_obj + "* ret = (" + fn_line.ret_ty_info.rust_obj + "*)(*env)->CallLongMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
 
@@ -805,7 +821,13 @@ import java.util.Arrays;
                         out_c = out_c + ", " + arg_info.arg_name
                 out_c = out_c + ");\n"
                 if fn_line.ret_ty_info.arg_conv is not None:
-                    out_c = out_c + "\t" + fn_line.ret_ty_info.arg_conv.replace("\n", "\n\t") + "\n\treturn " + fn_line.ret_ty_info.arg_conv_name + ";\n"
+                    out_c += "\t" + fn_line.ret_ty_info.arg_conv.replace("\n", "\n\t") + "\n"
+                    out_c += "\t" + self.deconstruct_jenv().replace("\n", "\n\t").strip() + "\n"
+                    out_c += "\treturn " + fn_line.ret_ty_info.arg_conv_name + ";\n"
+                else:
+                    out_c += "\t" + self.deconstruct_jenv().replace("\n", "\n\t").strip() + "\n"
+                    if not fn_line.ret_ty_info.passed_as_ptr and fn_line.ret_ty_info.c_ty != "void":
+                        out_c += "\treturn ret;\n"
 
                 out_c = out_c + "}\n"