]> git.bitcoin.ninja Git - ldk-java/blobdiff - genbindings.py
Drop stale jni header files
[ldk-java] / genbindings.py
index 2a5c44e2c6ea7e200d5d42dcba4b8a0218454d99..24d3df5940d47da822e69d352edf28ea879727a8 100755 (executable)
@@ -23,6 +23,8 @@ class ConvInfo:
         assert(ty_info.c_ty is not None)
         assert(ty_info.java_ty is not None)
         assert(arg_name is not None)
+        self.passed_as_ptr = ty_info.passed_as_ptr
+        self.rust_obj = ty_info.rust_obj
         self.c_ty = ty_info.c_ty
         self.java_ty = ty_info.java_ty
         self.java_fn_ty_arg = ty_info.java_fn_ty_arg
@@ -43,6 +45,22 @@ class ConvInfo:
         else:
             out_java.write(" arg")
             out_c.write(" arg")
+fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$")
+fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")
+reg_fn_regex = re.compile("([A-Za-z_0-9\* ]* \*?)([a-zA-Z_0-9]*)\((.*)\);$")
+clone_fns = set()
+with open(sys.argv[1]) as in_h:
+    for line in in_h:
+        reg_fn = reg_fn_regex.match(line)
+        if reg_fn is not None:
+            if reg_fn.group(2).endswith("_clone"):
+                clone_fns.add(reg_fn.group(2))
+            continue
+        arr_fn = fn_ret_arr_regex.match(line)
+        if arr_fn is not None:
+            if arr_fn.group(2).endswith("_clone"):
+                clone_fns.add(arr_fn.group(2))
+            continue
 
 with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.argv[4], "w") as out_c:
     opaque_structs = set()
@@ -218,6 +236,10 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n"
                 opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = (void*)(" + ty_info.var_name + " & (~1));\n"
                 opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.is_owned = (" + ty_info.var_name + " & 1) || (" + ty_info.var_name + " == 0);"
+                if (ty_info.rust_obj.replace("LDK", "") + "_clone") in clone_fns and not ty_info.is_ptr and not is_free:
+                    # TODO: This is a bit too naive, even with the checks above, we really need to know if rust wants a ref or not, not just if its pass as a ptr.
+                    opaque_arg_conv = opaque_arg_conv + "\nif (" + ty_info.var_name + "_conv.inner != NULL)\n"
+                    opaque_arg_conv = opaque_arg_conv + "\t" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
                 if not ty_info.is_ptr:
                     if ty_info.rust_obj in unitary_enums:
                         return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
@@ -510,7 +532,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         out_c.write("typedef struct " + struct_name + "_JCalls {\n")
         out_c.write("\tatomic_size_t refcnt;\n")
         out_c.write("\tJavaVM *vm;\n")
-        out_c.write("\tjobject o;\n")
+        out_c.write("\tjweak o;\n")
         for var_line in field_var_lines:
             if var_line.group(1) in trait_structs:
                 out_c.write("\t" + var_line.group(1) + "_JCalls* " + var_line.group(2) + ";\n")
@@ -561,13 +583,14 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                         out_c.write(arg_info.arg_name)
                         out_c.write(arg_info.ret_conv[1].replace('\n', '\n\t').replace("_env", "env") + "\n")
 
+                out_c.write("\tjobject obj = (*env)->NewLocalRef(env, j_calls->o);\n\tDO_ASSERT(obj != NULL);\n")
                 if ret_ty_info.c_ty.endswith("Array"):
                     assert(ret_ty_info.c_ty == "jbyteArray")
-                    out_c.write("\tjbyteArray jret = (*env)->CallObjectMethod(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
+                    out_c.write("\tjbyteArray jret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.group(2) + "_meth")
                 elif not ret_ty_info.passed_as_ptr:
-                    out_c.write("\treturn (*env)->Call" + ret_ty_info.java_ty.title() + "Method(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth")
+                    out_c.write("\treturn (*env)->Call" + ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.group(2) + "_meth")
                 else:
-                    out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*env)->CallLongMethod(env, j_calls->o, j_calls->" + fn_line.group(2) + "_meth");
+                    out_c.write("\t" + fn_line.group(1).strip() + "* ret = (" + fn_line.group(1).strip() + "*)(*env)->CallLongMethod(env, obj, j_calls->" + fn_line.group(2) + "_meth");
 
                 for arg_info in arg_names:
                     if arg_info.ret_conv is not None:
@@ -591,7 +614,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
                 out_c.write("\tif (atomic_fetch_sub_explicit(&j_calls->refcnt, 1, memory_order_acquire) == 1) {\n")
                 out_c.write("\t\tJNIEnv *env;\n")
                 out_c.write("\t\tDO_ASSERT((*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_8) == JNI_OK);\n")
-                out_c.write("\t\t(*env)->DeleteGlobalRef(env, j_calls->o);\n")
+                out_c.write("\t\t(*env)->DeleteWeakGlobalRef(env, j_calls->o);\n")
                 out_c.write("\t\tFREE(j_calls);\n")
                 out_c.write("\t}\n}\n")
 
@@ -621,7 +644,7 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
         out_c.write("\t" + struct_name + "_JCalls *calls = MALLOC(sizeof(" + struct_name + "_JCalls), \"" + struct_name + "_JCalls\");\n")
         out_c.write("\tatomic_init(&calls->refcnt, 1);\n")
         out_c.write("\tDO_ASSERT((*env)->GetJavaVM(env, &calls->vm) == 0);\n")
-        out_c.write("\tcalls->o = (*env)->NewGlobalRef(env, o);\n")
+        out_c.write("\tcalls->o = (*env)->NewWeakGlobalRef(env, o);\n")
         for (fn_line, java_meth_descr) in zip(trait_fn_lines, java_meths):
             if fn_line.group(2) != "free" and fn_line.group(2) != "clone":
                 out_c.write("\tcalls->" + fn_line.group(2) + "_meth = (*env)->GetMethodID(env, c, \"" + fn_line.group(2) + "\", \"" + java_meth_descr + "\");\n")
@@ -661,7 +684,9 @@ with open(sys.argv[1]) as in_h, open(sys.argv[2], "w") as out_java, open(sys.arg
 
         out_java.write("\tpublic static native " + struct_name + " " + struct_name + "_get_obj_from_jcalls(long val);\n")
         out_c.write("JNIEXPORT jobject JNICALL Java_org_ldk_impl_bindings_" + struct_name.replace("_", "_1") + "_1get_1obj_1from_1jcalls (JNIEnv * env, jclass _a, jlong val) {\n")
-        out_c.write("\treturn ((" + struct_name + "_JCalls*)val)->o;\n")
+        out_c.write("\tjobject ret = (*env)->NewLocalRef(env, ((" + struct_name + "_JCalls*)val)->o);\n")
+        out_c.write("\tDO_ASSERT(ret != NULL);\n")
+        out_c.write("\treturn ret;\n")
         out_c.write("}\n")
 
         for fn_line in trait_fn_lines: