CHECK_ACCESS inner pointers using the new __unmangle_inner_ptr meth
[ldk-java] / java_strings.py
index 5452829d70275e4e547bcf47dfc50fc2dba4df44..6e77fc347569a5145ef50e6e48b2575c9965b177 100644 (file)
@@ -18,8 +18,8 @@ class Consts:
         )
 
         self.to_hu_conv_templates = dict(
-            ptr = '{human_type} {var_name}_hu_conv = new {human_type}(null, {var_name});',
-            default = '{human_type} {var_name}_hu_conv = new {human_type}(null, {var_name});'
+            ptr = '{human_type} {var_name}_hu_conv = null; if ({var_name} < 0 || {var_name} > 4096) { {var_name}_hu_conv = new {human_type}(null, {var_name}); }',
+            default = '{human_type} {var_name}_hu_conv = null; if ({var_name} < 0 || {var_name} > 4096) { {var_name}_hu_conv = new {human_type}(null, {var_name}); }'
         )
 
         self.bindings_header = """package org.ldk.impl;
@@ -115,7 +115,7 @@ import java.util.LinkedList;
 class CommonBase {
        long ptr;
        LinkedList<Object> ptrs_to = new LinkedList();
-       protected CommonBase(long ptr) { assert ptr > 1024; this.ptr = ptr; }
+       protected CommonBase(long ptr) { assert ptr < 0 || ptr > 4096; this.ptr = ptr; }
 }
 """
 
@@ -161,7 +161,9 @@ void __attribute__((constructor)) spawn_stderr_redirection() {
 
         if not DEBUG or sys.platform == "darwin":
             self.c_file_pfx = self.c_file_pfx + """#define MALLOC(a, _) malloc(a)
-#define FREE(p) if ((uint64_t)(p) > 1024) { free(p); }
+#define FREE(p) if ((uint64_t)(p) > 4096) { free(p); }
+#define CHECK_ACCESS(p)
+#define CHECK_INNER_FIELD_ACCESS_OR_NULL(v)
 """
         if not DEBUG:
             self.c_file_pfx += """#define DO_ASSERT(a) (void)(a)
@@ -282,7 +284,7 @@ static void alloc_freed(void* ptr) {
        while (it->ptr != ptr) {
                p = it; it = it->next;
                if (it == NULL) {
-                       DEBUG_PRINT("Tried to free unknown pointer %p at:\\n", ptr);
+                       DEBUG_PRINT("ERROR: Tried to free unknown pointer %p at:\\n", ptr);
                        void* bt[BT_MAX];
                        int bt_len = backtrace(bt, BT_MAX);
                        backtrace_symbols_fd(bt, bt_len, STDERR_FILENO);
@@ -297,7 +299,7 @@ static void alloc_freed(void* ptr) {
        __real_free(it);
 }
 static void FREE(void* ptr) {
-       if ((uint64_t)ptr < 1024) return; // Rust loves to create pointers to the NULL page for dummys
+       if ((uint64_t)ptr <= 4096) return; // Rust loves to create pointers to the NULL page for dummys
        alloc_freed(ptr);
        __real_free(ptr);
 }
@@ -318,6 +320,31 @@ void __wrap_free(void* ptr) {
        __real_free(ptr);
 }
 
+static void CHECK_ACCESS(const void* ptr) {
+       DO_ASSERT(!pthread_mutex_lock(&allocation_mtx));
+       allocation* it = allocation_ll;
+       while (it->ptr != ptr) {
+               it = it->next;
+               if (it == NULL) {
+                       DEBUG_PRINT("ERROR: Tried to access unknown pointer %p at:\\n", ptr);
+                       void* bt[BT_MAX];
+                       int bt_len = backtrace(bt, BT_MAX);
+                       backtrace_symbols_fd(bt, bt_len, STDERR_FILENO);
+                       DEBUG_PRINT("\\n\\n");
+                       DO_ASSERT(!pthread_mutex_unlock(&allocation_mtx));
+                       return; // addrsan should catch and print more info than we have
+               }
+       }
+       DO_ASSERT(!pthread_mutex_unlock(&allocation_mtx));
+}
+#define CHECK_INNER_FIELD_ACCESS_OR_NULL(v) \\
+       if (v.is_owned && v.inner != NULL) { \\
+               const void *p = __unmangle_inner_ptr(v.inner); \\
+               if (p != NULL) { \\
+                       CHECK_ACCESS(p); \\
+               } \\
+       }
+
 void* __real_realloc(void* ptr, size_t newlen);
 void* __wrap_realloc(void* ptr, size_t len) {
        if (ptr != NULL) alloc_freed(ptr);
@@ -595,8 +622,10 @@ import javax.annotation.Nullable;
             out_java_enum += "/**\n * " + enum_doc_comment.replace("\n", "\n * ") + "\n */\n"
         out_java_enum += "public enum " + struct_name + " {\n"
         ord_v = 0
-        for var in variants:
-            out_java_enum = out_java_enum + "\t" + var + ",\n"
+        for var, var_docs in variants:
+            if var_docs is not None:
+                out_java_enum += "\t/**\n\t * " + var_docs.replace("\n", "\n\t * ") + "\n\t */\n"
+            out_java_enum += "\t" + var + ",\n"
             out_c = out_c + "\t\tcase %d: return %s;\n" % (ord_v, var)
             ord_v = ord_v + 1
         out_java_enum = out_java_enum + "\t; static native void init();\n"
@@ -608,19 +637,19 @@ import javax.annotation.Nullable;
         out_c = out_c + "}\n"
 
         out_c = out_c + "static jclass " + struct_name + "_class = NULL;\n"
-        for var in variants:
+        for var, _ in variants:
             out_c = out_c + "static jfieldID " + struct_name + "_" + var + " = NULL;\n"
         out_c = out_c + self.c_fn_ty_pfx + "void JNICALL Java_org_ldk_enums_" + struct_name.replace("_", "_1") + "_init (" + self.c_fn_args_pfx + ") {\n"
         out_c = out_c + "\t" + struct_name + "_class = (*env)->NewGlobalRef(env, clz);\n"
         out_c = out_c + "\tCHECK(" + struct_name + "_class != NULL);\n"
-        for var in variants:
+        for var, _ in variants:
             out_c = out_c + "\t" + struct_name + "_" + var + " = (*env)->GetStaticFieldID(env, " + struct_name + "_class, \"" + var + "\", \"Lorg/ldk/enums/" + struct_name + ";\");\n"
             out_c = out_c + "\tCHECK(" + struct_name + "_" + var + " != NULL);\n"
         out_c = out_c + "}\n"
         out_c = out_c + "static inline jclass LDK" + struct_name + "_to_java(JNIEnv *env, LDK" + struct_name + " val) {\n"
         out_c = out_c + "\tswitch (val) {\n"
         ord_v = 0
-        for var in variants:
+        for var, _ in variants:
             out_c = out_c + "\t\tcase " + var + ":\n"
             out_c = out_c + "\t\t\treturn (*env)->GetStaticObjectField(env, " + struct_name + "_class, " + struct_name + "_" + var + ");\n"
             ord_v = ord_v + 1
@@ -1030,7 +1059,9 @@ import javax.annotation.Nullable;
         out_java +=  ("\t\tprivate " + struct_name + "() {}\n")
         for var in variant_list:
             out_java +=  ("\t\tpublic final static class " + var.var_name + " extends " + struct_name + " {\n")
-            java_hu_subclasses = java_hu_subclasses + "\tpublic final static class " + var.var_name + " extends " + java_hu_type + " {\n"
+            if var.var_docs is not None:
+                java_hu_subclasses += "\t/**\n\t * " + var.var_docs.replace("\n", "\n\t * ") + "\n\t */\n"
+            java_hu_subclasses += "\tpublic final static class " + var.var_name + " extends " + java_hu_type + " {\n"
             out_java_enum += ("\t\tif (raw_val.getClass() == bindings." + struct_name + "." + var.var_name + ".class) {\n")
             out_java_enum += ("\t\t\treturn new " + var.var_name + "(ptr, (bindings." + struct_name + "." + var.var_name + ")raw_val);\n")
             init_meth_jty_str = ""
@@ -1112,11 +1143,12 @@ import javax.annotation.Nullable;
         out_opaque_struct_human += self.hu_struct_file_prefix
         out_opaque_struct_human += "\n/**\n * " + struct_doc_comment.replace("\n", "\n * ") + "\n */\n"
         out_opaque_struct_human += "@SuppressWarnings(\"unchecked\") // We correctly assign various generic arrays\n"
-        out_opaque_struct_human += ("public class " + struct_name.replace("LDK","") + " extends CommonBase")
+        hu_name = struct_name.replace("LDKC2Tuple", "TwoTuple").replace("LDKC3Tuple", "ThreeTuple").replace("LDK", "")
+        out_opaque_struct_human += ("public class " + hu_name + " extends CommonBase")
         if struct_name.startswith("LDKLocked"):
             out_opaque_struct_human += (" implements AutoCloseable")
         out_opaque_struct_human += (" {\n")
-        out_opaque_struct_human += ("\t" + struct_name.replace("LDK", "") + "(Object _dummy, long ptr) { super(ptr); }\n")
+        out_opaque_struct_human += ("\t" + hu_name + "(Object _dummy, long ptr) { super(ptr); }\n")
         if struct_name.startswith("LDKLocked"):
             out_opaque_struct_human += ("\t@Override public void close() {\n")
         else:
@@ -1127,8 +1159,10 @@ import javax.annotation.Nullable;
         out_opaque_struct_human += ("\t}\n\n")
         return out_opaque_struct_human
 
+    def map_tuple(self, struct_name):
+        return self.map_opaque_struct(struct_name, "A Tuple")
 
-    def map_function(self, argument_types, c_call_string, method_name, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_as_ref, args_known, type_mapping_generator, doc_comment):
+    def map_function(self, argument_types, c_call_string, method_name, meth_n, return_type_info, struct_meth, default_constructor_args, takes_self, takes_self_as_ref, args_known, type_mapping_generator, doc_comment):
         out_java = ""
         out_c = ""
         out_java_struct = None
@@ -1155,7 +1189,6 @@ import javax.annotation.Nullable;
         if not args_known:
             out_java_struct += ("\t// Skipped " + method_name + "\n")
         else:
-            meth_n = method_name[len(struct_meth) + 1 if len(struct_meth) != 0 else 0:].strip("_")
             if doc_comment is not None:
                 out_java_struct += "\t/**\n\t * " + doc_comment.replace("\n", "\n\t * ") + "\n\t */\n"
             if return_type_info.nullable:
@@ -1256,7 +1289,7 @@ import javax.annotation.Nullable;
                     out_java_struct += (info.arg_name)
             out_java_struct += (");\n")
             if return_type_info.java_ty == "long" and return_type_info.java_hu_ty != "long":
-                out_java_struct += "\t\tif (ret < 1024) { return null; }\n"
+                out_java_struct += "\t\tif (ret >= 0 && ret <= 4096) { return null; }\n"
 
             if return_type_info.to_hu_conv is not None:
                 if not takes_self: