CHECK_ACCESS inner pointers using the new __unmangle_inner_ptr meth
[ldk-java] / java_strings.py
index 357c32eaaaf20083cb4ba4cdcc9e09d35db1f6b2..6e77fc347569a5145ef50e6e48b2575c9965b177 100644 (file)
@@ -18,12 +18,13 @@ class Consts:
         )
 
         self.to_hu_conv_templates = dict(
-            ptr = '{human_type} {var_name}_hu_conv = new {human_type}(null, {var_name});',
-            default = '{human_type} {var_name}_hu_conv = new {human_type}(null, {var_name});'
+            ptr = '{human_type} {var_name}_hu_conv = null; if ({var_name} < 0 || {var_name} > 4096) { {var_name}_hu_conv = new {human_type}(null, {var_name}); }',
+            default = '{human_type} {var_name}_hu_conv = null; if ({var_name} < 0 || {var_name} > 4096) { {var_name}_hu_conv = new {human_type}(null, {var_name}); }'
         )
 
         self.bindings_header = """package org.ldk.impl;
 import org.ldk.enums.*;
+import org.ldk.impl.version;
 import java.io.File;
 import java.io.InputStream;
 import java.io.IOException;
@@ -55,7 +56,7 @@ public class bindings {
                                Path libpath = new File(tmpdir.toPath().toString(), "liblightningjni.so").toPath();
                                Files.copy(is, libpath, StandardCopyOption.REPLACE_EXISTING);
                                Runtime.getRuntime().load(libpath.toString());
-                       } catch (IOException e) {
+                       } catch (Exception e) {
                                System.err.println("Failed to load LDK native library.");
                                System.err.println("System LDK native library load failed with: " + system_load_err);
                                System.err.println("Resource-based LDK native library load failed with: " + e);
@@ -64,7 +65,7 @@ public class bindings {
                }
                init(java.lang.Enum.class, VecOrSliceDef.class);
                init_class_cache();
-               if (!get_lib_version_string().equals(get_ldk_java_bindings_version()))
+               if (!get_lib_version_string().equals(version.get_ldk_java_bindings_version()))
                        throw new IllegalArgumentException("Compiled LDK library and LDK class failes do not match");
                // Fetching the LDK versions from C also checks that the header and binaries match
                get_ldk_c_bindings_version();
@@ -74,9 +75,6 @@ public class bindings {
        static native void init_class_cache();
        static native String get_lib_version_string();
 
-       public static String get_ldk_java_bindings_version() {
-               return "<git_version_ldk_garbagecollected>";
-       }
        public static native String get_ldk_c_bindings_version();
        public static native String get_ldk_version();
 
@@ -93,6 +91,13 @@ public class bindings {
        public static native long new_empty_slice_vec();
 
 """
+        self.bindings_version_file = """package org.ldk.impl;
+
+public class version {
+       public static String get_ldk_java_bindings_version() {
+               return "<git_version_ldk_garbagecollected>";
+       }
+}"""
 
         self.bindings_footer = "}\n"
 
@@ -110,7 +115,7 @@ import java.util.LinkedList;
 class CommonBase {
        long ptr;
        LinkedList<Object> ptrs_to = new LinkedList();
-       protected CommonBase(long ptr) { assert ptr > 1024; this.ptr = ptr; }
+       protected CommonBase(long ptr) { assert ptr < 0 || ptr > 4096; this.ptr = ptr; }
 }
 """
 
@@ -156,7 +161,9 @@ void __attribute__((constructor)) spawn_stderr_redirection() {
 
         if not DEBUG or sys.platform == "darwin":
             self.c_file_pfx = self.c_file_pfx + """#define MALLOC(a, _) malloc(a)
-#define FREE(p) if ((uint64_t)(p) > 1024) { free(p); }
+#define FREE(p) if ((uint64_t)(p) > 4096) { free(p); }
+#define CHECK_ACCESS(p)
+#define CHECK_INNER_FIELD_ACCESS_OR_NULL(v)
 """
         if not DEBUG:
             self.c_file_pfx += """#define DO_ASSERT(a) (void)(a)
@@ -277,7 +284,7 @@ static void alloc_freed(void* ptr) {
        while (it->ptr != ptr) {
                p = it; it = it->next;
                if (it == NULL) {
-                       DEBUG_PRINT("Tried to free unknown pointer %p at:\\n", ptr);
+                       DEBUG_PRINT("ERROR: Tried to free unknown pointer %p at:\\n", ptr);
                        void* bt[BT_MAX];
                        int bt_len = backtrace(bt, BT_MAX);
                        backtrace_symbols_fd(bt, bt_len, STDERR_FILENO);
@@ -292,7 +299,7 @@ static void alloc_freed(void* ptr) {
        __real_free(it);
 }
 static void FREE(void* ptr) {
-       if ((uint64_t)ptr < 1024) return; // Rust loves to create pointers to the NULL page for dummys
+       if ((uint64_t)ptr <= 4096) return; // Rust loves to create pointers to the NULL page for dummys
        alloc_freed(ptr);
        __real_free(ptr);
 }
@@ -313,6 +320,31 @@ void __wrap_free(void* ptr) {
        __real_free(ptr);
 }
 
+static void CHECK_ACCESS(const void* ptr) {
+       DO_ASSERT(!pthread_mutex_lock(&allocation_mtx));
+       allocation* it = allocation_ll;
+       while (it->ptr != ptr) {
+               it = it->next;
+               if (it == NULL) {
+                       DEBUG_PRINT("ERROR: Tried to access unknown pointer %p at:\\n", ptr);
+                       void* bt[BT_MAX];
+                       int bt_len = backtrace(bt, BT_MAX);
+                       backtrace_symbols_fd(bt, bt_len, STDERR_FILENO);
+                       DEBUG_PRINT("\\n\\n");
+                       DO_ASSERT(!pthread_mutex_unlock(&allocation_mtx));
+                       return; // addrsan should catch and print more info than we have
+               }
+       }
+       DO_ASSERT(!pthread_mutex_unlock(&allocation_mtx));
+}
+#define CHECK_INNER_FIELD_ACCESS_OR_NULL(v) \\
+       if (v.is_owned && v.inner != NULL) { \\
+               const void *p = __unmangle_inner_ptr(v.inner); \\
+               if (p != NULL) { \\
+                       CHECK_ACCESS(p); \\
+               } \\
+       }
+
 void* __real_realloc(void* ptr, size_t newlen);
 void* __wrap_realloc(void* ptr, size_t len) {
        if (ptr != NULL) alloc_freed(ptr);
@@ -461,16 +493,17 @@ static inline LDKStr java_to_owned_str(JNIEnv *env, jstring str) {
        return res;
 }
 
-JNIEXPORT jstring JNICALL Java_org_ldk_impl_bindings_get_1lib_1version_1string(JNIEnv *env, jclass _c) {
-       return str_ref_to_java(env, "<git_version_ldk_garbagecollected>", strlen("<git_version_ldk_garbagecollected>"));
-}
 JNIEXPORT jstring JNICALL Java_org_ldk_impl_bindings_get_1ldk_1c_1bindings_1version(JNIEnv *env, jclass _c) {
        return str_ref_to_java(env, check_get_ldk_bindings_version(), strlen(check_get_ldk_bindings_version()));
 }
 JNIEXPORT jstring JNICALL Java_org_ldk_impl_bindings_get_1ldk_1version(JNIEnv *env, jclass _c) {
        return str_ref_to_java(env, check_get_ldk_version(), strlen(check_get_ldk_version()));
 }
+#include "version.c"
 """
+        self.c_version_file = """JNIEXPORT jstring JNICALL Java_org_ldk_impl_bindings_get_1lib_1version_1string(JNIEnv *env, jclass _c) {
+       return str_ref_to_java(env, "<git_version_ldk_garbagecollected>", strlen("<git_version_ldk_garbagecollected>"));
+}"""
 
         self.hu_struct_file_prefix = """package org.ldk.structs;
 
@@ -589,8 +622,10 @@ import javax.annotation.Nullable;
             out_java_enum += "/**\n * " + enum_doc_comment.replace("\n", "\n * ") + "\n */\n"
         out_java_enum += "public enum " + struct_name + " {\n"
         ord_v = 0
-        for var in variants:
-            out_java_enum = out_java_enum + "\t" + var + ",\n"
+        for var, var_docs in variants:
+            if var_docs is not None:
+                out_java_enum += "\t/**\n\t * " + var_docs.replace("\n", "\n\t * ") + "\n\t */\n"
+            out_java_enum += "\t" + var + ",\n"
             out_c = out_c + "\t\tcase %d: return %s;\n" % (ord_v, var)
             ord_v = ord_v + 1
         out_java_enum = out_java_enum + "\t; static native void init();\n"
@@ -602,19 +637,19 @@ import javax.annotation.Nullable;
         out_c = out_c + "}\n"
 
         out_c = out_c + "static jclass " + struct_name + "_class = NULL;\n"
-        for var in variants:
+        for var, _ in variants:
             out_c = out_c + "static jfieldID " + struct_name + "_" + var + " = NULL;\n"
         out_c = out_c + self.c_fn_ty_pfx + "void JNICALL Java_org_ldk_enums_" + struct_name.replace("_", "_1") + "_init (" + self.c_fn_args_pfx + ") {\n"
         out_c = out_c + "\t" + struct_name + "_class = (*env)->NewGlobalRef(env, clz);\n"
         out_c = out_c + "\tCHECK(" + struct_name + "_class != NULL);\n"
-        for var in variants:
+        for var, _ in variants:
             out_c = out_c + "\t" + struct_name + "_" + var + " = (*env)->GetStaticFieldID(env, " + struct_name + "_class, \"" + var + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n"
             out_c = out_c + "\tCHECK(" + struct_name + "_" + var + " != NULL);\n"
         out_c = out_c + "}\n"
         out_c = out_c + "static inline jclass LDK" + struct_name + "_to_java(JNIEnv *env, LDK" + struct_name + " val) {\n"
         out_c = out_c + "\tswitch (val) {\n"
         ord_v = 0
-        for var in variants:
+        for var, _ in variants:
             out_c = out_c + "\t\tcase " + var + ":\n"
             out_c = out_c + "\t\t\treturn (*env)->GetStaticObjectField(env, " + struct_name + "_class, " + struct_name + "_" + var + ");\n"
             ord_v = ord_v + 1
@@ -1024,7 +1059,9 @@ import javax.annotation.Nullable;
         out_java +=  ("\t\tprivate " + struct_name + "() {}\n")
         for var in variant_list:
             out_java +=  ("\t\tpublic final static class " + var.var_name + " extends " + struct_name + " {\n")
-            java_hu_subclasses = java_hu_subclasses + "\tpublic final static class " + var.var_name + " extends " + java_hu_type + " {\n"
+            if var.var_docs is not None:
+                java_hu_subclasses += "\t/**\n\t * " + var.var_docs.replace("\n", "\n\t * ") + "\n\t */\n"
+            java_hu_subclasses += "\tpublic final static class " + var.var_name + " extends " + java_hu_type + " {\n"
             out_java_enum += ("\t\tif (raw_val.getClass() == bindings." + struct_name + "." + var.var_name + ".class) {\n")
             out_java_enum += ("\t\t\treturn new " + var.var_name + "(ptr, (bindings." + struct_name + "." + var.var_name + ")raw_val);\n")
             init_meth_jty_str = ""
@@ -1106,11 +1143,12 @@ import javax.annotation.Nullable;
         out_opaque_struct_human += self.hu_struct_file_prefix
         out_opaque_struct_human += "\n/**\n * " + struct_doc_comment.replace("\n", "\n * ") + "\n */\n"
         out_opaque_struct_human += "@SuppressWarnings(\"unchecked\") // We correctly assign various generic arrays\n"
-        out_opaque_struct_human += ("public class " + struct_name.replace("LDK","") + " extends CommonBase")
+        hu_name = struct_name.replace("LDKC2Tuple", "TwoTuple").replace("LDKC3Tuple", "ThreeTuple").replace("LDK", "")
+        out_opaque_struct_human += ("public class " + hu_name + " extends CommonBase")
         if struct_name.startswith("LDKLocked"):
             out_opaque_struct_human += (" implements AutoCloseable")
         out_opaque_struct_human += (" {\n")
-        out_opaque_struct_human += ("\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); }\n")
+        out_opaque_struct_human += ("\t" + hu_name + "(Object _dummy, long ptr) { super(ptr); }\n")
         if struct_name.startswith("LDKLocked"):
             out_opaque_struct_human += ("\t@Override public void close() {\n")
         else:
@@ -1121,8 +1159,10 @@ import javax.annotation.Nullable;
         out_opaque_struct_human += ("\t}\n\n")
         return out_opaque_struct_human
 
+    def map_tuple(self, struct_name):
+        return self.map_opaque_struct(struct_name, "A Tuple")
 
-    def map_function(self, argument_types, c_call_string, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_as_ref, args_known, type_mapping_generator, doc_comment):
+    def map_function(self, argument_types, c_call_string, method_name, meth_n, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_as_ref, args_known, type_mapping_generator, doc_comment):
         out_java = ""
         out_c = ""
         out_java_struct = None
@@ -1149,7 +1189,6 @@ import javax.annotation.Nullable;
         if not args_known:
             out_java_struct += ("\t// Skipped " + method_name + "\n")
         else:
-            meth_n = method_name[len(struct_meth) + 1 if len(struct_meth) != 0 else 0:].strip("_")
             if doc_comment is not None:
                 out_java_struct += "\t/**\n\t * " + doc_comment.replace("\n", "\n\t * ") + "\n\t */\n"
             if return_type_info.nullable:
@@ -1250,7 +1289,7 @@ import javax.annotation.Nullable;
                     out_java_struct += (info.arg_name)
             out_java_struct += (");\n")
             if return_type_info.java_ty == "long" and return_type_info.java_hu_ty != "long":
-                out_java_struct += "\t\tif (ret < 1024) { return null; }\n"
+                out_java_struct += "\t\tif (ret >= 0 && ret <= 4096) { return null; }\n"
 
             if return_type_info.to_hu_conv is not None:
                 if not takes_self: