[Java] Add a reachabilityFence on underlying trait impl calls
[ldk-java] / java_strings.py
index ef4ada4782d49ccc0118236ce5ff2f251327fa88..3a8ba6cb613047a728cb77a09abd7b5557664aba 100644 (file)
@@ -16,6 +16,12 @@ class Consts:
             uint32_t = ['int'],
             uint64_t = ['long'],
         )
+        self.java_type_map = dict(
+            String = "String"
+        )
+        self.java_hu_type_map = dict(
+            String = "String"
+        )
 
         self.to_hu_conv_templates = dict(
             ptr = '{human_type} {var_name}_hu_conv = null; if ({var_name} < 0 || {var_name} > 4096) { {var_name}_hu_conv = new {human_type}(null, {var_name}); }',
@@ -106,6 +112,7 @@ import org.ldk.impl.bindings;
 import org.ldk.enums.*;
 import org.ldk.util.*;
 import java.util.Arrays;
+import java.lang.ref.Reference;
 import javax.annotation.Nullable;
 
 public class UtilMethods {
@@ -120,6 +127,30 @@ class CommonBase {
 }
 """
 
+        self.txout_defn = """public class TxOut extends CommonBase {
+       /** The script_pubkey in this output */
+       public final byte[] script_pubkey;
+       /** The value, in satoshis, of this output */
+       public final long value;
+
+       TxOut(java.lang.Object _dummy, long ptr) {
+               super(ptr);
+               this.script_pubkey = bindings.TxOut_get_script_pubkey(ptr);
+               this.value = bindings.TxOut_get_value(ptr);
+       }
+       public TxOut(long value, byte[] script_pubkey) {
+               super(bindings.TxOut_new(script_pubkey, value));
+               this.script_pubkey = bindings.TxOut_get_script_pubkey(ptr);
+               this.value = bindings.TxOut_get_value(ptr);
+       }
+
+       @Override @SuppressWarnings(\"deprecation\")
+       protected void finalize() throws Throwable {
+               super.finalize();
+               if (ptr != 0) { bindings.TxOut_free(ptr); }
+       }
+}"""
+
         self.c_file_pfx = """#include <jni.h>
 // On OSX jlong (ie long long) is not equivalent to int64_t, so we override here
 #define int64_t jlong
@@ -129,6 +160,9 @@ class CommonBase {
 #include <stdatomic.h>
 #include <stdlib.h>
 
+#define LIKELY(v) __builtin_expect(!!(v), 1)
+#define UNLIKELY(v) __builtin_expect(!!(v), 0)
+
 """
 
         if self.target == Target.ANDROID:
@@ -512,6 +546,7 @@ import org.ldk.impl.bindings;
 import org.ldk.enums.*;
 import org.ldk.util.*;
 import java.util.Arrays;
+import java.lang.ref.Reference;
 import javax.annotation.Nullable;
 
 """
@@ -590,10 +625,24 @@ import javax.annotation.Nullable;
         else:
             return "(*env)->Release" + ty_info.java_ty.strip("[]").title() + "ArrayElements(env, " + arr_name + ", " + dest_name + ", 0)"
 
+    def map_hu_array_elems(self, arr_name, conv_name, arr_ty, elem_ty):
+        if elem_ty.java_ty == "long" and elem_ty.java_hu_ty != "long":
+            return arr_name + " != null ? Arrays.stream(" + arr_name + ").mapToLong(" + conv_name + " -> " + elem_ty.from_hu_conv[0] + ").toArray() : null"
+        elif elem_ty.java_ty == "long":
+            return arr_name + " != null ? Arrays.stream(" + arr_name + ").map(" + conv_name + " -> " + elem_ty.from_hu_conv[0] + ").toArray() : null"
+        elif elem_ty.java_hu_ty == "UInt5":
+            return arr_name + " != null ? InternalUtils.convUInt5Array(" + arr_name + ") : null"
+        else:
+            return arr_name + " != null ? Arrays.stream(" + arr_name + ").map(" + conv_name + " -> " + elem_ty.from_hu_conv[0] + ").toArray(" + arr_ty.java_ty + "::new) : null"
+
     def str_ref_to_native_call(self, var_name, str_len):
         return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
     def str_ref_to_c_call(self, var_name):
         return "java_to_owned_str(env, " + var_name + ")"
+    def str_to_hu_conv(self, var_name):
+        return None
+    def str_from_hu_conv(self, var_name):
+        return None
 
     def c_fn_name_define_pfx(self, fn_name, has_args):
         if has_args:
@@ -612,12 +661,57 @@ import javax.annotation.Nullable;
         res = res + "}\n"
         return res
 
+    def var_decl_statement(self, ty_string, var_name, statement):
+        return ty_string + " " + var_name + " = " + statement
+
+    def get_java_arr_len(self, arr_name):
+        return arr_name + ".length"
+    def get_java_arr_elem(self, elem_ty, arr_name, idx):
+        return arr_name + "[" + idx + "]"
+    def constr_hu_array(self, ty_info, arr_len):
+        base_ty = ty_info.subty.java_hu_ty.split("[")[0].split("<")[0]
+        conv = "new " + base_ty + "[" + arr_len + "]"
+        if "[" in ty_info.subty.java_hu_ty.split("<")[0]:
+            # Do a bit of a dance to move any excess [] to the end
+            conv += "[" + ty_info.subty.java_hu_ty.split("<")[0].split("[")[1]
+        return conv
+    def cleanup_converted_native_array(self, ty_info, arr_name):
+        return None
+
+    def primitive_arr_from_hu(self, mapped_ty, fixed_len, arr_name):
+        if fixed_len is not None:
+            return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+        return None
+    def primitive_arr_to_hu(self, primitive_ty, fixed_len, arr_name, conv_name):
+        return None
+
+    def java_arr_ty_str(self, elem_ty_str):
+        return elem_ty_str + "[]"
+
+    def for_n_in_range(self, n, minimum, maximum):
+        return "for (int " + n + " = " + minimum + "; " + n + " < " + maximum + "; " + n + "++) {"
+    def for_n_in_arr(self, n, arr_name, arr_elem_ty):
+        return ("for (" + arr_elem_ty.java_hu_ty + " " + n + ": " + arr_name + ") { ", " }")
+
+    def get_ptr(self, var):
+        return var + ".ptr"
+    def set_null_skip_free(self, var):
+        return var + ".ptr" + " = 0;"
+
+    def add_ref(self, holder, referent):
+        return holder + ".ptrs_to.add(" + referent + ")"
+
     def native_c_unitary_enum_map(self, struct_name, variants, enum_doc_comment):
         out_java_enum = "package org.ldk.enums;\n\n"
         out_java = ""
         out_c = ""
-        out_c = out_c + "static inline LDK" + struct_name + " LDK" + struct_name + "_from_java(" + self.c_fn_args_pfx + ") {\n"
-        out_c = out_c + "\tswitch ((*env)->CallIntMethod(env, clz, ordinal_meth)) {\n"
+        out_c += "static inline LDK" + struct_name + " LDK" + struct_name + "_from_java(" + self.c_fn_args_pfx + ") {\n"
+        out_c += "\tjint ord = (*env)->CallIntMethod(env, clz, ordinal_meth);\n"
+        out_c += "\tif (UNLIKELY((*env)->ExceptionCheck(env))) {\n"
+        out_c += "\t\t(*env)->ExceptionDescribe(env);\n"
+        out_c += "\t\t(*env)->FatalError(env, \"A call to " + struct_name + ".ordinal() from rust threw an exception.\");\n"
+        out_c += "\t}\n"
+        out_c += "\tswitch (ord) {\n"
 
         if enum_doc_comment is not None:
             out_java_enum += "/**\n * " + enum_doc_comment.replace("\n", "\n * ") + "\n */\n"
@@ -633,9 +727,10 @@ import javax.annotation.Nullable;
         out_java_enum = out_java_enum + "\tstatic { init(); }\n"
         out_java_enum = out_java_enum + "}"
         out_java = out_java + "\tstatic { " + struct_name + ".values(); /* Force enum statics to run */ }\n"
-        out_c = out_c + "\t}\n"
-        out_c = out_c + "\tabort();\n"
-        out_c = out_c + "}\n"
+        out_c += "\t}\n"
+        out_c += "\t(*env)->FatalError(env, \"A call to " + struct_name + ".ordinal() from rust returned an invalid value.\");\n"
+        out_c += "\tabort(); // Unreachable, but will let the compiler know we don't return here\n"
+        out_c += "}\n"
 
         out_c = out_c + "static jclass " + struct_name + "_class = NULL;\n"
         for var, _ in variants:
@@ -672,14 +767,11 @@ import javax.annotation.Nullable;
             out_c = out_c + "static jmethodID " + struct_name + "_" + var + "_meth = NULL;\n"
         out_c = out_c + self.c_fn_ty_pfx + "void JNICALL Java_org_ldk_impl_bindings_00024" + struct_name.replace("_", "_1") + "_init (" + self.c_fn_args_pfx + ") {\n"
         for var_name in variants:
-            out_c = out_c + "\t" + struct_name + "_" + var_name + "_class =\n"
-            if self.target == Target.ANDROID:
-                out_c = out_c + "\t\t(*env)->NewGlobalRef(env, (*env)->FindClass(env, \"org/ldk/impl/bindings$" + struct_name + "$" + var_name + "\"));\n"
-            else:
-                out_c = out_c + "\t\t(*env)->NewGlobalRef(env, (*env)->FindClass(env, \"Lorg/ldk/impl/bindings$" + struct_name + "$" + var_name + ";\"));\n"
-            out_c = out_c + "\tCHECK(" + struct_name + "_" + var_name + "_class != NULL);\n"
-            out_c = out_c + "\t" + struct_name + "_" + var_name + "_meth = (*env)->GetMethodID(env, " + struct_name + "_" + var_name + "_class, \"<init>\", \"(" + init_meth_jty_strs[var_name] + ")V\");\n"
-            out_c = out_c + "\tCHECK(" + struct_name + "_" + var_name + "_meth != NULL);\n"
+            out_c += "\t" + struct_name + "_" + var_name + "_class =\n"
+            out_c += "\t\t(*env)->NewGlobalRef(env, (*env)->FindClass(env, \"org/ldk/impl/bindings$" + struct_name + "$" + var_name + "\"));\n"
+            out_c += "\tCHECK(" + struct_name + "_" + var_name + "_class != NULL);\n"
+            out_c += "\t" + struct_name + "_" + var_name + "_meth = (*env)->GetMethodID(env, " + struct_name + "_" + var_name + "_class, \"<init>\", \"(" + init_meth_jty_strs[var_name] + ")V\");\n"
+            out_c += "\tCHECK(" + struct_name + "_" + var_name + "_meth != NULL);\n"
         out_c = out_c + "}\n"
         return out_c
 
@@ -790,7 +882,8 @@ import javax.annotation.Nullable;
                     else:
                         java_trait_constr = java_trait_constr + arg_info.arg_name
 
-                java_trait_constr = java_trait_constr + ");\n"
+                java_trait_constr += ");\n"
+                java_trait_constr += "\t\t\t\tReference.reachabilityFence(arg);\n"
                 if fn_line.ret_ty_info.java_ty != "void":
                     if fn_line.ret_ty_info.from_hu_conv is not None:
                         java_trait_constr = java_trait_constr + "\t\t\t\t" + fn_line.ret_ty_info.java_ty + " result = " + fn_line.ret_ty_info.from_hu_conv[0] + ";\n"
@@ -886,7 +979,7 @@ import javax.annotation.Nullable;
                     out_c = out_c + "\t" + fn_line.ret_ty_info.c_ty + " ret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
                 elif fn_line.ret_ty_info.c_ty == "void":
                     out_c += "\t(*env)->Call" + fn_line.ret_ty_info.java_ty.title() + "Method(env, obj, j_calls->" + fn_line.fn_name + "_meth"
-                elif fn_line.ret_ty_info.java_ty == "String":
+                elif fn_line.ret_ty_info.java_hu_ty == "String":
                     # Manually write out String methods as they're just an Object
                     out_c += "\t" + fn_line.ret_ty_info.c_ty + " ret = (*env)->CallObjectMethod(env, obj, j_calls->" + fn_line.fn_name + "_meth"
                 elif not fn_line.ret_ty_info.passed_as_ptr:
@@ -901,7 +994,7 @@ import javax.annotation.Nullable;
                         out_c = out_c + ", " + arg_info.arg_name
                 out_c = out_c + ");\n"
 
-                out_c += "\tif ((*env)->ExceptionCheck(env)) {\n"
+                out_c += "\tif (UNLIKELY((*env)->ExceptionCheck(env))) {\n"
                 out_c += "\t\t(*env)->ExceptionDescribe(env);\n"
                 out_c += "\t\t(*env)->FatalError(env, \"A call to " + fn_line.fn_name + " in " + struct_name + " from rust threw an exception.\");\n"
                 out_c += "\t}\n"
@@ -1163,6 +1256,58 @@ import javax.annotation.Nullable;
     def map_tuple(self, struct_name):
         return self.map_opaque_struct(struct_name, "A Tuple")
 
+    def map_result(self, struct_name, res_map, err_map):
+        human_ty = struct_name.replace("LDKCResult", "Result")
+        java_hu_struct = ""
+        java_hu_struct += self.hu_struct_file_prefix
+        java_hu_struct += "public class " + human_ty + " extends CommonBase {\n"
+        java_hu_struct += "\tprivate " + human_ty + "(Object _dummy, long ptr) { super(ptr); }\n"
+        java_hu_struct += "\tprotected void finalize() throws Throwable {\n"
+        java_hu_struct += "\t\tif (ptr != 0) { bindings." + struct_name.replace("LDK","") + "_free(ptr); } super.finalize();\n"
+        java_hu_struct += "\t}\n\n"
+        java_hu_struct += "\tstatic " + human_ty + " constr_from_ptr(long ptr) {\n"
+        java_hu_struct += "\t\tif (bindings." + struct_name.replace("LDK", "") + "_is_ok(ptr)) {\n"
+        java_hu_struct += "\t\t\treturn new " + human_ty + "_OK(null, ptr);\n"
+        java_hu_struct += "\t\t} else {\n"
+        java_hu_struct += "\t\t\treturn new " + human_ty + "_Err(null, ptr);\n"
+        java_hu_struct += "\t\t}\n"
+        java_hu_struct += "\t}\n"
+
+        java_hu_struct += "\tpublic static final class " + human_ty + "_OK extends " + human_ty + " {\n"
+
+        if res_map.java_hu_ty != "void":
+            java_hu_struct += "\t\tpublic final " + res_map.java_hu_ty + " res;\n"
+        java_hu_struct += "\t\tprivate " + human_ty + "_OK(Object _dummy, long ptr) {\n"
+        java_hu_struct += "\t\t\tsuper(_dummy, ptr);\n"
+        if res_map.java_hu_ty == "void":
+            pass
+        elif res_map.to_hu_conv is not None:
+            java_hu_struct += "\t\t\t" + res_map.java_ty + " res = bindings." + struct_name.replace("LDK", "") + "_get_ok(ptr);\n"
+            java_hu_struct += "\t\t\t" + res_map.to_hu_conv.replace("\n", "\n\t\t\t")
+            java_hu_struct += "\n\t\t\tthis.res = " + res_map.to_hu_conv_name + ";\n"
+        else:
+            java_hu_struct += "\t\t\tthis.res = bindings." + struct_name.replace("LDK", "") + "_get_ok(ptr);\n"
+        java_hu_struct += "\t\t}\n"
+        java_hu_struct += "\t}\n\n"
+
+        java_hu_struct += "\tpublic static final class " + human_ty + "_Err extends " + human_ty + " {\n"
+        if err_map.java_hu_ty != "void":
+            java_hu_struct += "\t\tpublic final " + err_map.java_hu_ty + " err;\n"
+        java_hu_struct += "\t\tprivate " + human_ty + "_Err(Object _dummy, long ptr) {\n"
+        java_hu_struct += "\t\t\tsuper(_dummy, ptr);\n"
+        if err_map.java_hu_ty == "void":
+            pass
+        elif err_map.to_hu_conv is not None:
+            java_hu_struct += "\t\t\t" + err_map.java_ty + " err = bindings." + struct_name.replace("LDK", "") + "_get_err(ptr);\n"
+            java_hu_struct += "\t\t\t" + err_map.to_hu_conv.replace("\n", "\n\t\t\t")
+            java_hu_struct += "\n\t\t\tthis.err = " + err_map.to_hu_conv_name + ";\n"
+        else:
+            java_hu_struct += "\t\t\tthis.err = bindings." + struct_name.replace("LDK", "") + "_get_err(ptr);\n"
+        java_hu_struct += "\t\t}\n"
+
+        java_hu_struct += "\t}\n\n"
+        return java_hu_struct
+
     def map_function(self, argument_types, c_call_string, method_name, meth_n, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_as_ref, args_known, type_mapping_generator, doc_comment):
         out_java = ""
         out_c = ""
@@ -1202,7 +1347,7 @@ import javax.annotation.Nullable;
                     out_java_struct += "\tpublic static " + return_type_info.java_hu_ty + " with_default("
                 else:
                     out_java_struct += "\tpublic static " + return_type_info.java_hu_ty + " " + meth_n + "("
-            elif meth_n == "clone_ptr":
+            elif meth_n == "clone_ptr" or (struct_meth.startswith("LDKCResult") and (meth_n == "get_ok" or meth_n == "get_err")):
                 out_java_struct += ("\t" + return_type_info.java_hu_ty + " " + meth_n + "(")
             else:
                 if meth_n == "hash" and return_type_info.java_hu_ty == "long":
@@ -1301,6 +1446,45 @@ import javax.annotation.Nullable;
                 else:
                     out_java_struct += (info.arg_name)
             out_java_struct += (");\n")
+
+            # This is completely nuts. The OpenJDK JRE JIT will optimize out a object which is on
+            # the stack, calling its finalizer immediately even if member methods are *actively
+            # executing* on the same object, as long as said object is on the stack. There is no
+            # concrete specification for when the optimizer is allowed to do this, and when it is
+            # not, so there is absolutely no way to be certain that this fix suffices.
+            #
+            # Instead, the "Java Language Specification" says only that an object is reachable
+            # (i.e. will not yet be finalized) if it "can be accessed in any potential continuing
+            # computation from any live thread". To any sensible reader this would mean actively
+            # executing a member function on an object would make it not eligible for finalization.
+            # But, no, dear reader, this statement does not say that. Well, okay, it says that,
+            # very explicitly in fact, but those are just, like, words, man.
+            #
+            # In the seemingly non-normative text further down, a few examples of things the
+            # optimizer can do are given, including "if the values in an object's fields are
+            # stored in registers[, t]he may then access the registers instead of the object, and
+            # never access the object again[, implying] that the object is garbage". This appears
+            # to fully contradict both the above statement, the API documentation in java.lang.ref
+            # regarding when a reference is "strongly reachable", and basic common sense. There is
+            # no concrete set of limitations stated, however, seemingly implying the JIT could
+            # decide your code would run faster by simply garbage collecting everything
+            # immediately, ensuring your code finishes soon, just by SEGFAULT. Thus, we're really
+            # entirely flying blind here. We add some fences and hope that its sufficient, but
+            # with no specification to rely on, we cannot be certain of anything.
+            #
+            # TL;DR: The Java Language "Specification" provides no real guarantees on when an
+            # object will be considered available for garbage collection once the JIT kicks in, so
+            # we put in some fences and hope to god the JIT doesn't get smarter/more broken.
+            for idx, info in enumerate(argument_types):
+                if idx == 0 and takes_self:
+                    out_java_struct += ("\t\tReference.reachabilityFence(this);\n")
+                elif info.arg_name in default_constructor_args:
+                    for explode_idx, explode_arg in enumerate(default_constructor_args[info.arg_name]):
+                        expl_arg_name = info.arg_name + "_" + explode_arg.arg_name
+                        out_java_struct += ("\t\tReference.reachabilityFence(" + expl_arg_name + ");\n")
+                elif info.c_ty != "void":
+                    out_java_struct += ("\t\tReference.reachabilityFence(" + info.arg_name + ");\n")
+
             if return_type_info.java_ty == "long" and return_type_info.java_hu_ty != "long":
                 out_java_struct += "\t\tif (ret >= 0 && ret <= 4096) { return null; }\n"
 
@@ -1337,3 +1521,6 @@ import javax.annotation.Nullable;
             out_java_struct += ("\t}\n\n")
 
         return (out_java, out_c, out_java_struct + extra_java_struct_out)
+
+    def cleanup(self):
+        pass