[TS] Don't null-check array conversions that aren't nullable
[ldk-java] / csharp_strings.py
index 5361302be675bc3268437769b18c0f48a584b9c5..74830d0458ce2a56bc23a15fd3a5e76b130cf906 100644 (file)
@@ -15,6 +15,8 @@ class Consts:
             uint16_t = ['short'],
             uint32_t = ['int'],
             uint64_t = ['long'],
+            int64_t = ['long'],
+            double = ['double'],
         )
         self.java_type_map = dict(
             String = "string"
@@ -31,11 +33,42 @@ class Consts:
         self.bindings_header = """
 using org.ldk.enums;
 using org.ldk.impl;
+using System;
 using System.Runtime.InteropServices;
 
 namespace org { namespace ldk { namespace impl {
 
 internal class bindings {
+       internal class ArrayCoder : ICustomMarshaler {
+               int size = 0;
+               GCHandle pinnedArray;
+               public static ICustomMarshaler GetInstance(string pstrCookie) {
+                       return new ArrayCoder();
+               }
+
+               public Object MarshalNativeToManaged(IntPtr pNativeData) { throw new NotImplementedException(); }
+               public IntPtr MarshalManagedToNative(Object obj) {
+                       if (obj.GetType() == typeof(byte[])) {
+                               byte[] inp = (byte[])obj;
+                               IntPtr data = Marshal.AllocHGlobal(inp.Length + 8);
+                               Marshal.WriteInt64(data, inp.Length);
+                               Marshal.Copy(inp, 0, data + 8, inp.Length);
+                               this.size = inp.Length + 8;
+                               return data;
+                       } else {
+                               throw new NotImplementedException();
+                       }
+               }
+               public void CleanUpNativeData(IntPtr pNativeData) {
+                       Marshal.FreeHGlobal(pNativeData);
+               }
+               public void CleanUpManagedData(Object ManagedObj) { }
+               public int GetNativeDataSize() {
+                       // Blindly guess based on the last allocation, no idea how else to implement this.
+                       return this.size;
+               }
+       }
+
        /*static {
                init(java.lang.Enum.class, VecOrSliceDef.class);
                init_class_cache();
@@ -79,6 +112,33 @@ public class CommonBase {
 }
 """
 
+        self.txin_defn = """public class TxIn : CommonBase {
+       /** The witness in this input, in serialized form */
+       public readonly byte[] witness;
+       /** The script_sig in this input */
+       public readonly byte[] script_sig;
+       /** The transaction output's sequence number */
+       public readonly int sequence;
+       /** The txid this input is spending */
+       public readonly byte[] previous_txid;
+       /** The output index within the spent transaction of the output this input is spending */
+       public readonly int previous_vout;
+
+       internal TxIn(object _dummy, long ptr) : base(ptr) {
+               this.witness = bindings.TxIn_get_witness(ptr);
+               this.script_sig = bindings.TxIn_get_script_sig(ptr);
+               this.sequence = bindings.TxIn_get_sequence(ptr);
+               this.previous_txid = bindings.TxIn_get_previous_txid(ptr);
+               this.previous_vout = bindings.TxIn_get_previous_vout(ptr);
+       }
+       public TxIn(byte[] witness, byte[] script_sig, int sequence, byte[] previous_txid, int previous_vout)
+       : this(null, bindings.TxIn_new(witness, script_sig, sequence, previous_txid, previous_vout)) {}
+
+       ~TxIn() {
+               if (ptr != 0) { bindings.TxIn_free(ptr); }
+       }
+}"""
+
         self.txout_defn = """public class TxOut : CommonBase {
        /** The script_pubkey in this output */
        public readonly byte[] script_pubkey;
@@ -89,10 +149,7 @@ public class CommonBase {
                this.script_pubkey = bindings.TxOut_get_script_pubkey(ptr);
                this.value = bindings.TxOut_get_value(ptr);
        }
-    public TxOut(long value, byte[] script_pubkey) : base(bindings.TxOut_new(script_pubkey, value)) {
-               this.script_pubkey = bindings.TxOut_get_script_pubkey(ptr);
-               this.value = bindings.TxOut_get_value(ptr);
-       }
+    public TxOut(long value, byte[] script_pubkey) : this(null, bindings.TxOut_new(script_pubkey, value)) {}
 
        ~TxOut() {
                if (ptr != 0) { bindings.TxOut_free(ptr); }
@@ -132,7 +189,8 @@ public class CommonBase {
         self.c_file_pfx = self.c_file_pfx + "#include <stdio.h>\n#define DEBUG_PRINT(...) fprintf(stderr, __VA_ARGS__)\n"
 
         if not DEBUG or sys.platform == "darwin":
-            self.c_file_pfx = self.c_file_pfx + """#define MALLOC(a, _) malloc(a)
+            self.c_file_pfx = self.c_file_pfx + """#define do_MALLOC(a, _b, _c) malloc(a)
+#define MALLOC(a, _) malloc(a)
 #define FREE(p) if ((uint64_t)(p) > 4096) { free(p); }
 #define CHECK_ACCESS(p)
 #define CHECK_INNER_FIELD_ACCESS_OR_NULL(v)
@@ -199,11 +257,13 @@ static void new_allocation(void* res, const char* struct_name, size_t len) {
        allocation_ll = new_alloc;
        DO_ASSERT(!pthread_mutex_unlock(&allocation_mtx));
 }
-static void* MALLOC(size_t len, const char* struct_name) {
+static void* do_MALLOC(size_t len, const char* struct_name, int lineno) {
        void* res = __real_malloc(len);
-       new_allocation(res, struct_name, len);
+       new_allocation(res, struct_name, lineno);
        return res;
 }
+#define MALLOC(len, struct_name) do_MALLOC(len, struct_name, __LINE__)
+
 void __real_free(void* ptr);
 static void alloc_freed(void* ptr) {
        allocation* p = NULL;
@@ -312,38 +372,49 @@ _Static_assert(offsetof(LDKCVec_u8Z, datalen) == offsetof(LDKu8slice, datalen),
 
 _Static_assert(sizeof(void*) <= 8, "Pointers must fit into 64 bits");
 
-typedef jlongArray int64_tArray;
-typedef jbyteArray int8_tArray;
-
-static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
-       // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
-       char* conv_buf = MALLOC(len + 1, "str conv buf");
-       memcpy(conv_buf, chars, len);
-       conv_buf[len] = 0;
-       jstring ret = (*env)->NewStringUTF(env, conv_buf);
-       FREE(conv_buf);
-       return ret;
+#define DECL_ARR_TYPE(ty, name) \\
+       struct name##array { \\
+               uint64_t arr_len; /* uint32_t would suffice but we want to align uint64_ts as well */ \\
+               ty elems[]; \\
+       }; \\
+       typedef struct name##array * name##Array; \\
+       static inline name##Array init_##name##Array(size_t arr_len, int lineno) { \\
+               name##Array arr = (name##Array)do_MALLOC(arr_len * sizeof(ty) + sizeof(uint64_t), #name" array init", lineno); \\
+               arr->arr_len = arr_len; \\
+               return arr; \\
+       }
+
+DECL_ARR_TYPE(int64_t, int64_t);
+DECL_ARR_TYPE(uint64_t, uint64_t);
+DECL_ARR_TYPE(int8_t, int8_t);
+DECL_ARR_TYPE(int16_t, int16_t);
+DECL_ARR_TYPE(uint32_t, uint32_t);
+DECL_ARR_TYPE(void*, ptr);
+DECL_ARR_TYPE(char, char);
+typedef charArray jstring;
+
+static inline jstring str_ref_to_cs(const char* chars, size_t len) {
+       charArray arr = init_charArray(len, __LINE__);
+       memcpy(arr->elems, chars, len);
+       return arr;
 }
-static inline LDKStr java_to_owned_str(JNIEnv *env, jstring str) {
-       uint64_t str_len = (*env)->GetStringUTFLength(env, str);
-       char* newchars = MALLOC(str_len + 1, "String chars");
-       const char* jchars = (*env)->GetStringUTFChars(env, str, NULL);
-       memcpy(newchars, jchars, str_len);
-       newchars[str_len] = 0;
-       (*env)->ReleaseStringUTFChars(env, str, jchars);
+static inline LDKStr str_ref_to_owned_c(const jstring str) {
+       char* newchars = MALLOC(str->arr_len + 1, "String chars");
+       memcpy(newchars, str->elems, str->arr_len);
+       newchars[str->arr_len] = 0;
        LDKStr res = {
                .chars = newchars,
-               .len = str_len,
+               .len = str->arr_len,
                .chars_is_owned = true
        };
        return res;
 }
 
-const char* CS_LDK_get_ldk_c_bindings_version() {
-       return str_ref_to_java(check_get_ldk_bindings_version(), strlen(check_get_ldk_bindings_version()));
+jstring CS_LDK_get_ldk_c_bindings_version() {
+       return str_ref_to_cs(check_get_ldk_bindings_version(), strlen(check_get_ldk_bindings_version()));
 }
-const char* CS_LDK_get_ldk_version() {
-       return str_ref_to_java(check_get_ldk_version(), strlen(check_get_ldk_version()));
+jstring CS_LDK_get_ldk_version() {
+       return str_ref_to_cs(check_get_ldk_version(), strlen(check_get_ldk_version()));
 }
 #include "version.c"
 """
@@ -384,69 +455,46 @@ namespace org { namespace ldk { namespace structs {
     def c_fn_name_define_pfx(self, fn_name, have_args):
         return " CS_LDK_" + fn_name + "("
 
-    def construct_jenv(self):
-        res =  "JNIEnv *env;\n"
-        res += "jint get_jenv_res = (*j_calls->vm)->GetEnv(j_calls->vm, (void**)&env, JNI_VERSION_1_6);\n"
-        res += "if (get_jenv_res == JNI_EDETACHED) {\n"
-        res += "\tDO_ASSERT((*j_calls->vm)->AttachCurrentThread(j_calls->vm, (void**)&env, NULL) == JNI_OK);\n"
-        res += "} else {\n"
-        res += "\tDO_ASSERT(get_jenv_res == JNI_OK);\n"
-        res += "}\n"
-        return res
-    def deconstruct_jenv(self):
-        res = "if (get_jenv_res == JNI_EDETACHED) {\n"
-        res += "\tDO_ASSERT((*j_calls->vm)->DetachCurrentThread(j_calls->vm) == JNI_OK);\n"
-        res += "}\n"
-        return res
-
     def release_native_arr_ptr_call(self, ty_info, arr_var, arr_ptr_var):
-        if ty_info.subty is None or not ty_info.subty.c_ty.endswith("Array"):
-            return "(*env)->ReleasePrimitiveArrayCritical(env, " + arr_var + ", " + arr_ptr_var + ", 0)"
         return None
     def create_native_arr_call(self, arr_len, ty_info):
-        if ty_info.c_ty == "int8_tArray":
-            return "(*env)->NewByteArray(env, " + arr_len + ")"
-        elif ty_info.subty.c_ty.endswith("Array"):
-            clz_var = ty_info.java_fn_ty_arg[1:].replace("[", "arr_of_")
-            self.c_array_class_caches.add(clz_var)
-            return "(*env)->NewObjectArray(env, " + arr_len + ", " + clz_var + "_clz, NULL);\n"
-        else:
-            return "(*env)->New" + ty_info.java_ty.strip("[]").title() + "Array(env, " + arr_len + ")"
+        if ty_info.c_ty == "ptrArray":
+            assert ty_info.rust_obj == "LDKCVec_U5Z" or (ty_info.subty is not None and (ty_info.subty.c_ty.endswith("Array") or ty_info.subty.rust_obj == "LDKStr"))
+        return "init_" + ty_info.c_ty + "(" + arr_len + ", __LINE__)"
     def set_native_arr_contents(self, arr_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
-            return ("(*env)->SetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
+            return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + ")")
+        elif ty_info.c_ty == "int16_tArray":
+            return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + " * 2)")
         else:
             assert False
     def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
-        if ty_info.c_ty == "int8_tArray":
+        if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
             if copy:
-                return "(*env)->GetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
-            else:
-                return "(*env)->GetByteArrayElements (env, " + arr_name + ", NULL)"
-        elif not ty_info.java_ty[:len(ty_info.java_ty) - 2].endswith("[]"):
-            return "(*env)->Get" + ty_info.subty.java_ty.title() + "ArrayElements (env, " + arr_name + ", NULL)"
+                byte_len = arr_len
+                if ty_info.c_ty == "int16_tArray":
+                    byte_len = arr_len + " * 2"
+                return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + byte_len + "); FREE(" + arr_name + ")"
+        assert not copy
+        if ty_info.c_ty == "ptrArray":
+            return "(void*) " + arr_name + "->elems"
         else:
-            return None
+            return arr_name + "->elems"
     def get_native_arr_elem(self, arr_name, idxc, ty_info):
-        if self.get_native_arr_contents(arr_name, "", "", ty_info, False) is None:
-            return "(*env)->GetObjectArrayElement(env, " + arr_name + ", " + idxc + ")"
-        else:
-            assert False # Only called if above is None
+        assert False # Only called if above is None
     def get_native_arr_ptr_call(self, ty_info):
-        if ty_info.subty is not None and ty_info.subty.c_ty.endswith("Array"):
-            return None
-        return ("(*env)->GetPrimitiveArrayCritical(env, ", ", NULL)")
+        if ty_info.subty is not None:
+            return "(" + ty_info.subty.c_ty + "*)(((uint8_t*)", ") + 8)"
+        return "(" + ty_info.c_ty + "*)(((uint8_t*)", ") + 8)"
     def get_native_arr_entry_call(self, ty_info, arr_name, idxc, entry_access):
-        if ty_info.subty is None or not ty_info.subty.c_ty.endswith("Array"):
-            return None
-        return "(*env)->SetObjectArrayElement(env, " + arr_name + ", " + idxc + ", " + entry_access + ")"
+        return None
     def cleanup_native_arr_ref_contents(self, arr_name, dest_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
-            return "(*env)->ReleaseByteArrayElements(env, " + arr_name + ", (int8_t*)" + dest_name + ", 0);"
+            return "FREE(" + arr_name + ");"
         else:
-            return "(*env)->Release" + ty_info.java_ty.strip("[]").title() + "ArrayElements(env, " + arr_name + ", " + dest_name + ", 0)"
+            return "FREE(" + arr_name + ")"
 
-    def map_hu_array_elems(self, arr_name, conv_name, arr_ty, elem_ty):
+    def map_hu_array_elems(self, arr_name, conv_name, arr_ty, elem_ty, is_nullable):
         if elem_ty.java_hu_ty == "UInt5":
             return arr_name + " != null ? InternalUtils.convUInt5Array(" + arr_name + ") : null"
         elif elem_ty.java_hu_ty == "WitnessVersion":
@@ -455,9 +503,9 @@ namespace org { namespace ldk { namespace structs {
             return arr_name + " != null ? InternalUtils.mapArray(" + arr_name + ", " + conv_name + " => " + elem_ty.from_hu_conv[0] + ") : null"
 
     def str_ref_to_native_call(self, var_name, str_len):
-        return "str_ref_to_java(env, " + var_name + ", " + str_len + ")"
+        return "str_ref_to_cs(" + var_name + ", " + str_len + ")"
     def str_ref_to_c_call(self, var_name):
-        return "java_to_owned_str(env, " + var_name + ")"
+        return "str_ref_to_owned_c(" + var_name + ")"
     def str_to_hu_conv(self, var_name):
         return None
     def str_from_hu_conv(self, var_name):
@@ -703,17 +751,21 @@ namespace org { namespace ldk { namespace structs {
                 java_trait_constr = java_trait_constr + ", " + var.arg_name
             else:
                 java_trait_constr += ", " + var[1] + ".new_impl(" + var[1] + "_impl"
+                suptrait_constr = ""
                 for suparg in var[2]:
                     if isinstance(suparg, ConvInfo):
-                        java_trait_constr += ", " + suparg.arg_name
+                        suptrait_constr += ", " + suparg.arg_name
                     else:
-                        java_trait_constr += ", " + suparg[1]
-                java_trait_constr += ").bindings_instance"
+                        suptrait_constr += ", " + suparg[1] + "_impl"
+                java_trait_constr += suptrait_constr + ").bindings_instance"
                 for suparg in var[2]:
                     if isinstance(suparg, ConvInfo):
                         java_trait_constr += ", " + suparg.arg_name
                     else:
-                        java_trait_constr += ", " + suparg[1]
+                        java_trait_constr += ", " + suparg[1] + ".new_impl("
+                        # Blindly assume that we can just strip the first arg to build the args for the supertrait
+                        java_trait_constr += suptrait_constr.split(", ", 1)[1]
+                        java_trait_constr += ").bindings_instance"
         out_java_trait += "\t}\n" + java_trait_wrapper + "\n"
         out_java_trait += java_trait_constr + ");\n\t\treturn impl_holder.held;\n\t}\n"
 
@@ -749,9 +801,7 @@ namespace org { namespace ldk { namespace structs {
                 out_c = out_c + "static void " + struct_name + "_JCalls_free(void* this_arg) {\n"
                 out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
                 out_c = out_c + "\tif (atomic_fetch_sub_explicit(&j_calls->refcnt, 1, memory_order_acquire) == 1) {\n"
-                out_c += "\t\t" + self.construct_jenv().replace("\n", "\n\t\t").strip() + "\n"
                 out_c = out_c + "\t\t(*env)->DeleteWeakGlobalRef(env, j_calls->o);\n"
-                out_c += "\t\t" + self.deconstruct_jenv().replace("\n", "\n\t\t").strip() + "\n"
                 out_c = out_c + "\t\tFREE(j_calls);\n"
                 out_c = out_c + "\t}\n}\n"
 
@@ -769,7 +819,6 @@ namespace org { namespace ldk { namespace structs {
 
                 out_c = out_c + ") {\n"
                 out_c = out_c + "\t" + struct_name + "_JCalls *j_calls = (" + struct_name + "_JCalls*) this_arg;\n"
-                out_c += "\t" + self.construct_jenv().replace("\n", "\n\t").strip() + "\n"
 
                 for arg_info in fn_line.args_ty:
                     if arg_info.ret_conv is not None:
@@ -804,10 +853,8 @@ namespace org { namespace ldk { namespace structs {
 
                 if fn_line.ret_ty_info.arg_conv is not None:
                     out_c += "\t" + fn_line.ret_ty_info.arg_conv.replace("\n", "\n\t") + "\n"
-                    out_c += "\t" + self.deconstruct_jenv().replace("\n", "\n\t").strip() + "\n"
                     out_c += "\treturn " + fn_line.ret_ty_info.arg_conv_name + ";\n"
                 else:
-                    out_c += "\t" + self.deconstruct_jenv().replace("\n", "\n\t").strip() + "\n"
                     if not fn_line.ret_ty_info.passed_as_ptr and fn_line.ret_ty_info.c_ty != "void":
                         out_c += "\treturn ret;\n"
 
@@ -1094,6 +1141,8 @@ namespace org { namespace ldk { namespace structs {
                 out_c += (", ")
             if arg_conv_info.c_ty != "void":
                 out_c += (arg_conv_info.c_ty + " " + arg_conv_info.arg_name)
+                if "[]" in arg_conv_info.java_ty:
+                    out_java += "[MarshalAs(UnmanagedType.CustomMarshaler, MarshalType=\"org.ldk.impl.ArrayCoder\")] "
                 out_java += (arg_conv_info.java_ty + " _" + arg_conv_info.arg_name) # Add a _ to avoid using reserved words
 
         out_java_struct = ""