Support (fixed-length) arrays of 16-bit integers
authorMatt Corallo <git@bluematt.me>
Fri, 3 Mar 2023 01:12:36 +0000 (01:12 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 8 Mar 2023 05:12:58 +0000 (05:12 +0000)
gen_type_mapping.py
genbindings.py
java_strings.py
src/main/java/org/ldk/util/InternalUtils.java
typescript_strings.py

index d24ed230208d33a3742bc151007ebca890b8cfea..aebb290d03862d0d2e3181b3e64af0443ad98f18 100644 (file)
@@ -43,14 +43,14 @@ class TypeMappingGenerator:
             else:
                 arr_name = "ret"
                 arr_len = ret_arr_len
-            if ty_info.c_ty == "int8_tArray":
+            if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
                 (set_pfx, set_sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_len, ty_info)
-                ret_conv = ("int8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "")
+                ret_conv = (ty_info.c_ty + " " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "")
                 arg_conv_cleanup = None
                 if not arr_len.isdigit():
                     arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
                     arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = " +  self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + ";\n"
-                    if (not ty_info.is_ptr or not holds_ref) and ty_info.rust_obj != "LDKu8slice":
+                    if (not ty_info.is_ptr or not holds_ref) and (ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKu16slice"):
                         arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = MALLOC(" + arr_name + "_ref." + arr_len + ", \"" + ty_info.rust_obj + " Bytes\");\n"
                         arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_ref." + ty_info.arr_access, arr_name + "_ref." + arr_len, ty_info, True) + ";"
                     else:
@@ -59,10 +59,10 @@ class TypeMappingGenerator:
                     if ty_info.rust_obj == "LDKTransaction" or ty_info.rust_obj == "LDKWitness":
                         arg_conv = arg_conv + "\n" + arr_name + "_ref.data_is_owned = " + str(holds_ref).lower() + ";"
                     ret_conv = (ty_info.rust_obj + " " + arr_name + "_var = ", "")
-                    ret_conv = (ret_conv[0], ";\nint8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_name + "_var." + arr_len, ty_info) + ";\n")
+                    ret_conv = (ret_conv[0], ";\n" + ty_info.c_ty + " " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_name + "_var." + arr_len, ty_info) + ";\n")
                     (pfx, sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_name + "_var." + arr_len, ty_info)
                     ret_conv = (ret_conv[0], ret_conv[1] + pfx + arr_name + "_var." + ty_info.arr_access + sfx + ";")
-                    if not holds_ref and ty_info.rust_obj != "LDKu8slice":
+                    if not holds_ref and (ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKu16slice"):
                         ret_conv = (ret_conv[0], ret_conv[1] + "\n" + ty_info.rust_obj.replace("LDK", "") + "_free(" + arr_name + "_var);")
                     from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, None, arr_name)
                     to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv")
@@ -74,10 +74,11 @@ class TypeMappingGenerator:
                     from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, arr_len, arr_name)
                     to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv")
                 else:
-                    arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n"
+                    # Note that we just blindly assume we should be using unsigned integers here.
+                    arg_conv = "u" + ty_info.subty.c_ty + " " + arr_name + "_arr[" + arr_len + "];\n"
                     arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n"
                     arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_arr", arr_len, ty_info, True) + ";\n"
-                    arg_conv = arg_conv + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
+                    arg_conv = arg_conv + "u" + ty_info.subty.c_ty + " (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
                     ret_conv = (ret_conv[0] + "*", set_sfx + ";")
                     from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, arr_len, arr_name)
                     to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv")
index c9ae45ca252575c7cf0a42e7774584a0d907c098..77f9b472f7864383c2780524fb722f16982415cf 100755 (executable)
@@ -143,6 +143,11 @@ def java_c_types(fn_arg, ret_arr_len):
         assert var_is_arr_regex.match(fn_arg[8:])
         rust_obj = "LDKThirtyTwoBytes"
         arr_access = "data"
+    elif fn_arg.startswith("LDKEightU16s"):
+        fn_arg = "uint16_t (*" + fn_arg[13:] + ")[8]"
+        assert var_is_arr_regex.match(fn_arg[9:])
+        rust_obj = "LDKEightU16s"
+        arr_access = "data"
     elif fn_arg.startswith("LDKU128"):
         if fn_arg == "LDKU128":
             fn_arg = "LDKU128 arg"
index d956e0e11678ba66e3db3bbe4b2f7a941a8263cd..650e36bf92bbd996e4649a3ce5cacbb289853fc6 100644 (file)
@@ -426,6 +426,7 @@ _Static_assert(sizeof(void*) <= 8, "Pointers must fit into 64 bits");
 
 typedef jlongArray int64_tArray;
 typedef jbyteArray int8_tArray;
+typedef jshortArray int16_tArray;
 
 static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
        // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
@@ -525,14 +526,17 @@ import javax.annotation.Nullable;
     def set_native_arr_contents(self, arr_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
             return ("(*env)->SetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
+        elif ty_info.c_ty == "int16_tArray":
+            return ("(*env)->SetShortArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
         else:
             assert False
     def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
-        if ty_info.c_ty == "int8_tArray":
+        if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
+            fn_ty = "Byte" if ty_info.c_ty == "int8_tArray" else "Short"
             if copy:
-                return "(*env)->GetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
+                return "(*env)->Get" + fn_ty + "ArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
             else:
-                return "(*env)->GetByteArrayElements (env, " + arr_name + ", NULL)"
+                return "(*env)->Get" + fn_ty + "ArrayElements (env, " + arr_name + ", NULL)"
         elif not ty_info.java_ty[:len(ty_info.java_ty) - 2].endswith("[]"):
             return "(*env)->Get" + ty_info.subty.java_ty.title() + "ArrayElements (env, " + arr_name + ", NULL)"
         else:
@@ -616,7 +620,13 @@ import javax.annotation.Nullable;
         if arr_ty.rust_obj == "LDKU128":
             return ("" + arr_name + ".getLEBytes()", "")
         if fixed_len is not None:
-            return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+            if mapped_ty.c_ty == "int8_t" or mapped_ty.c_ty == "uint8_t":
+                return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+            elif mapped_ty.c_ty == "int16_t" or mapped_ty.c_ty == "uint16_t":
+                return ("InternalUtils.check_arr_16_len(" + arr_name + ", " + fixed_len + ")", "")
+            else:
+                print(arr_ty.c_ty)
+                assert False
         return None
     def primitive_arr_to_hu(self, arr_ty, fixed_len, arr_name, conv_name):
         if arr_ty.rust_obj == "LDKU128":
index d9c3f6ef8364a4f89ec1c2653f5d747455325345..4e706ec10d24d232505a555533a544d3e89a6373 100644 (file)
@@ -8,6 +8,13 @@ public class InternalUtils {
         return arr;
     }
 
+    public static short[] check_arr_16_len(short[] arr, int length) throws IllegalArgumentException {
+        if (arr != null && arr.length != length) {
+            throw new IllegalArgumentException("Array must be of fixed size " + length + " but was of length " + arr.length);
+        }
+        return arr;
+    }
+
     public static byte[] convUInt5Array(UInt5[] u5s) {
         byte[] res = new byte[u5s.length];
         for (int i = 0; i < u5s.length; i++) {
index 402e63996807408b1dde313f5854c65b730d9233..6706ff394da2c30b94acae49ab0c474bc8b9b1ca 100644 (file)
@@ -188,6 +188,16 @@ export function encodeUint8Array (inputArray: Uint8Array|null): number {
        return cArrayPointer;
 }
 /* @internal */
+export function encodeUint16Array (inputArray: Uint16Array|Array<number>|null): number {
+       if (inputArray == null) return 0;
+       const cArrayPointer = wasm.TS_malloc((inputArray.length + 4) * 2);
+       const arrayLengthView = new BigUint64Array(wasm.memory.buffer, cArrayPointer, 1);
+       arrayLengthView[0] = BigInt(inputArray.length);
+       const arrayMemoryView = new Uint16Array(wasm.memory.buffer, cArrayPointer + 8, inputArray.length);
+       arrayMemoryView.set(inputArray);
+       return cArrayPointer;
+}
+/* @internal */
 export function encodeUint32Array (inputArray: Uint32Array|Array<number>|null): number {
        if (inputArray == null) return 0;
        const cArrayPointer = wasm.TS_malloc((inputArray.length + 2) * 4);
@@ -213,6 +223,12 @@ export function check_arr_len(arr: Uint8Array|null, len: number): Uint8Array|nul
        return arr;
 }
 
+/* @internal */
+export function check_16_arr_len(arr: Uint16Array|null, len: number): Uint16Array|null {
+       if (arr !== null && arr.length != len) { throw new Error("Expected array of length " + len + " got " + arr.length); }
+       return arr;
+}
+
 /* @internal */
 export function getArrayLength(arrayPointer: number): number {
        const arraySizeViewer = new BigUint64Array(wasm.memory.buffer, arrayPointer, 1);
@@ -248,15 +264,13 @@ export function decodeUint8Array (arrayPointer: number, free = true): Uint8Array
        }
        return actualArray;
 }
-const decodeUint32Array = (arrayPointer: number, free = true) => {
+/* @internal */
+export function decodeUint16Array (arrayPointer: number, free = true): Uint16Array {
        const arraySize = getArrayLength(arrayPointer);
-       const actualArrayViewer = new Uint32Array(
-               wasm.memory.buffer, // value
-               arrayPointer + 8, // offset (ignoring length bytes)
-               arraySize // uint32 count
-       );
+       const actualArrayViewer = new Uint16Array(wasm.memory.buffer, arrayPointer + 8, arraySize);
        // Clone the contents, TODO: In the future we should wrap the Viewer in a class that
        // will free the underlying memory when it becomes unreachable instead of copying here.
+       // Note that doing so may have edge-case interactions with memory resizing (invalidating the buffer).
        const actualArray = actualArrayViewer.slice(0, arraySize);
        if (free) {
                wasm.TS_free(arrayPointer);
@@ -643,6 +657,7 @@ _Static_assert(sizeof(void*) == 4, "Pointers mut be 32 bits");
 DECL_ARR_TYPE(int64_t, int64_t);
 DECL_ARR_TYPE(uint64_t, uint64_t);
 DECL_ARR_TYPE(int8_t, int8_t);
+DECL_ARR_TYPE(int16_t, int16_t);
 DECL_ARR_TYPE(uint32_t, uint32_t);
 DECL_ARR_TYPE(void*, ptr);
 DECL_ARR_TYPE(char, char);
@@ -724,12 +739,17 @@ import * as bindings from '../bindings.mjs'
     def set_native_arr_contents(self, arr_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
             return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + ")")
+        elif ty_info.c_ty == "int16_tArray":
+            return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + " * 2)")
         else:
             assert False
     def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
-        if ty_info.c_ty == "int8_tArray":
+        if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
             if copy:
-                return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + arr_len + "); FREE(" + arr_name + ")"
+                byte_len = arr_len
+                if ty_info.c_ty == "int16_tArray":
+                    byte_len = arr_len + " * 2"
+                return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + byte_len + "); FREE(" + arr_name + ")"
         assert not copy
         if ty_info.c_ty == "ptrArray":
             return "(void*) " + arr_name + "->elems"
@@ -792,12 +812,16 @@ import * as bindings from '../bindings.mjs'
         if arr_ty.rust_obj == "LDKU128":
             return ("bindings.encodeUint128(" + inner + ")", "")
         if fixed_len is not None:
-            assert mapped_ty.c_ty == "int8_t"
-            inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+            if mapped_ty.c_ty == "int8_t":
+                inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+            elif mapped_ty.c_ty == "int16_t":
+                inner = "bindings.check_16_arr_len(" + arr_name + ", " + fixed_len + ")"
         if mapped_ty.c_ty.endswith("Array"):
             return ("bindings.encodeUint32Array(" + inner + ")", "")
         elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
             return ("bindings.encodeUint8Array(" + inner + ")", "")
+        elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t":
+            return ("bindings.encodeUint16Array(" + inner + ")", "")
         elif mapped_ty.c_ty == "uint32_t":
             return ("bindings.encodeUint32Array(" + inner + ")", "")
         elif mapped_ty.c_ty == "int64_t" or mapped_ty.c_ty == "uint64_t":
@@ -812,6 +836,8 @@ import * as bindings from '../bindings.mjs'
             return "const " + conv_name + ": bigint = bindings.decodeUint128(" + arr_name + ");"
         elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
             return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
+        elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t":
+            return "const " + conv_name + ": Uint16Array = bindings.decodeUint16Array(" + arr_name + ");"
         elif mapped_ty.c_ty == "uint64_t" or mapped_ty.c_ty == "int64_t":
             return "const " + conv_name + ": bigint[] = bindings.decodeUint64Array(" + arr_name + ");"
         else: