Support (fixed-length) arrays of 16-bit integers
[ldk-java] / typescript_strings.py
index 402e63996807408b1dde313f5854c65b730d9233..6706ff394da2c30b94acae49ab0c474bc8b9b1ca 100644 (file)
@@ -188,6 +188,16 @@ export function encodeUint8Array (inputArray: Uint8Array|null): number {
        return cArrayPointer;
 }
 /* @internal */
+export function encodeUint16Array (inputArray: Uint16Array|Array<number>|null): number {
+       if (inputArray == null) return 0;
+       const cArrayPointer = wasm.TS_malloc((inputArray.length + 4) * 2);
+       const arrayLengthView = new BigUint64Array(wasm.memory.buffer, cArrayPointer, 1);
+       arrayLengthView[0] = BigInt(inputArray.length);
+       const arrayMemoryView = new Uint16Array(wasm.memory.buffer, cArrayPointer + 8, inputArray.length);
+       arrayMemoryView.set(inputArray);
+       return cArrayPointer;
+}
+/* @internal */
 export function encodeUint32Array (inputArray: Uint32Array|Array<number>|null): number {
        if (inputArray == null) return 0;
        const cArrayPointer = wasm.TS_malloc((inputArray.length + 2) * 4);
@@ -213,6 +223,12 @@ export function check_arr_len(arr: Uint8Array|null, len: number): Uint8Array|nul
        return arr;
 }
 
+/* @internal */
+export function check_16_arr_len(arr: Uint16Array|null, len: number): Uint16Array|null {
+       if (arr !== null && arr.length != len) { throw new Error("Expected array of length " + len + " got " + arr.length); }
+       return arr;
+}
+
 /* @internal */
 export function getArrayLength(arrayPointer: number): number {
        const arraySizeViewer = new BigUint64Array(wasm.memory.buffer, arrayPointer, 1);
@@ -248,15 +264,13 @@ export function decodeUint8Array (arrayPointer: number, free = true): Uint8Array
        }
        return actualArray;
 }
-const decodeUint32Array = (arrayPointer: number, free = true) => {
+/* @internal */
+export function decodeUint16Array (arrayPointer: number, free = true): Uint16Array {
        const arraySize = getArrayLength(arrayPointer);
-       const actualArrayViewer = new Uint32Array(
-               wasm.memory.buffer, // value
-               arrayPointer + 8, // offset (ignoring length bytes)
-               arraySize // uint32 count
-       );
+       const actualArrayViewer = new Uint16Array(wasm.memory.buffer, arrayPointer + 8, arraySize);
        // Clone the contents, TODO: In the future we should wrap the Viewer in a class that
        // will free the underlying memory when it becomes unreachable instead of copying here.
+       // Note that doing so may have edge-case interactions with memory resizing (invalidating the buffer).
        const actualArray = actualArrayViewer.slice(0, arraySize);
        if (free) {
                wasm.TS_free(arrayPointer);
@@ -643,6 +657,7 @@ _Static_assert(sizeof(void*) == 4, "Pointers mut be 32 bits");
 DECL_ARR_TYPE(int64_t, int64_t);
 DECL_ARR_TYPE(uint64_t, uint64_t);
 DECL_ARR_TYPE(int8_t, int8_t);
+DECL_ARR_TYPE(int16_t, int16_t);
 DECL_ARR_TYPE(uint32_t, uint32_t);
 DECL_ARR_TYPE(void*, ptr);
 DECL_ARR_TYPE(char, char);
@@ -724,12 +739,17 @@ import * as bindings from '../bindings.mjs'
     def set_native_arr_contents(self, arr_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
             return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + ")")
+        elif ty_info.c_ty == "int16_tArray":
+            return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + " * 2)")
         else:
             assert False
     def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
-        if ty_info.c_ty == "int8_tArray":
+        if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
             if copy:
-                return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + arr_len + "); FREE(" + arr_name + ")"
+                byte_len = arr_len
+                if ty_info.c_ty == "int16_tArray":
+                    byte_len = arr_len + " * 2"
+                return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + byte_len + "); FREE(" + arr_name + ")"
         assert not copy
         if ty_info.c_ty == "ptrArray":
             return "(void*) " + arr_name + "->elems"
@@ -792,12 +812,16 @@ import * as bindings from '../bindings.mjs'
         if arr_ty.rust_obj == "LDKU128":
             return ("bindings.encodeUint128(" + inner + ")", "")
         if fixed_len is not None:
-            assert mapped_ty.c_ty == "int8_t"
-            inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+            if mapped_ty.c_ty == "int8_t":
+                inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+            elif mapped_ty.c_ty == "int16_t":
+                inner = "bindings.check_16_arr_len(" + arr_name + ", " + fixed_len + ")"
         if mapped_ty.c_ty.endswith("Array"):
             return ("bindings.encodeUint32Array(" + inner + ")", "")
         elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
             return ("bindings.encodeUint8Array(" + inner + ")", "")
+        elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t":
+            return ("bindings.encodeUint16Array(" + inner + ")", "")
         elif mapped_ty.c_ty == "uint32_t":
             return ("bindings.encodeUint32Array(" + inner + ")", "")
         elif mapped_ty.c_ty == "int64_t" or mapped_ty.c_ty == "uint64_t":
@@ -812,6 +836,8 @@ import * as bindings from '../bindings.mjs'
             return "const " + conv_name + ": bigint = bindings.decodeUint128(" + arr_name + ");"
         elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
             return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
+        elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t":
+            return "const " + conv_name + ": Uint16Array = bindings.decodeUint16Array(" + arr_name + ");"
         elif mapped_ty.c_ty == "uint64_t" or mapped_ty.c_ty == "int64_t":
             return "const " + conv_name + ": bigint[] = bindings.decodeUint64Array(" + arr_name + ");"
         else: