else:
arr_name = "ret"
arr_len = ret_arr_len
- if ty_info.c_ty == "int8_tArray":
+ if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
(set_pfx, set_sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_len, ty_info)
- ret_conv = ("int8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "")
+ ret_conv = (ty_info.c_ty + " " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_len, ty_info) + ";\n" + set_pfx, "")
arg_conv_cleanup = None
if not arr_len.isdigit():
arg_conv = ty_info.rust_obj + " " + arr_name + "_ref;\n"
arg_conv = arg_conv + arr_name + "_ref." + arr_len + " = " + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + ";\n"
- if (not ty_info.is_ptr or not holds_ref) and ty_info.rust_obj != "LDKu8slice":
+ if (not ty_info.is_ptr or not holds_ref) and (ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKu16slice"):
arg_conv = arg_conv + arr_name + "_ref." + ty_info.arr_access + " = MALLOC(" + arr_name + "_ref." + arr_len + ", \"" + ty_info.rust_obj + " Bytes\");\n"
arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_ref." + ty_info.arr_access, arr_name + "_ref." + arr_len, ty_info, True) + ";"
else:
if ty_info.rust_obj == "LDKTransaction" or ty_info.rust_obj == "LDKWitness":
arg_conv = arg_conv + "\n" + arr_name + "_ref.data_is_owned = " + str(holds_ref).lower() + ";"
ret_conv = (ty_info.rust_obj + " " + arr_name + "_var = ", "")
- ret_conv = (ret_conv[0], ";\nint8_tArray " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_name + "_var." + arr_len, ty_info) + ";\n")
+ ret_conv = (ret_conv[0], ";\n" + ty_info.c_ty + " " + arr_name + "_arr = " + self.consts.create_native_arr_call(arr_name + "_var." + arr_len, ty_info) + ";\n")
(pfx, sfx) = self.consts.set_native_arr_contents(arr_name + "_arr", arr_name + "_var." + arr_len, ty_info)
ret_conv = (ret_conv[0], ret_conv[1] + pfx + arr_name + "_var." + ty_info.arr_access + sfx + ";")
- if not holds_ref and ty_info.rust_obj != "LDKu8slice":
+ if not holds_ref and (ty_info.rust_obj != "LDKu8slice" and ty_info.rust_obj != "LDKu16slice"):
ret_conv = (ret_conv[0], ret_conv[1] + "\n" + ty_info.rust_obj.replace("LDK", "") + "_free(" + arr_name + "_var);")
from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, None, arr_name)
to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv")
from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, arr_len, arr_name)
to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv")
else:
- arg_conv = "unsigned char " + arr_name + "_arr[" + arr_len + "];\n"
+ # Note that we just blindly assume we should be using unsigned integers here.
+ arg_conv = "u" + ty_info.subty.c_ty + " " + arr_name + "_arr[" + arr_len + "];\n"
arg_conv = arg_conv + "CHECK(" + self.consts.get_native_arr_len_call[0] + arr_name + self.consts.get_native_arr_len_call[1] + " == " + arr_len + ");\n"
arg_conv = arg_conv + self.consts.get_native_arr_contents(arr_name, arr_name + "_arr", arr_len, ty_info, True) + ";\n"
- arg_conv = arg_conv + "unsigned char (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
+ arg_conv = arg_conv + "u" + ty_info.subty.c_ty + " (*" + arr_name + "_ref)[" + arr_len + "] = &" + arr_name + "_arr;"
ret_conv = (ret_conv[0] + "*", set_sfx + ";")
from_hu_conv = self.consts.primitive_arr_from_hu(ty_info, arr_len, arr_name)
to_hu_conv = self.consts.primitive_arr_to_hu(ty_info, None, arr_name, arr_name + "_conv")
assert var_is_arr_regex.match(fn_arg[8:])
rust_obj = "LDKThirtyTwoBytes"
arr_access = "data"
+ elif fn_arg.startswith("LDKEightU16s"):
+ fn_arg = "uint16_t (*" + fn_arg[13:] + ")[8]"
+ assert var_is_arr_regex.match(fn_arg[9:])
+ rust_obj = "LDKEightU16s"
+ arr_access = "data"
elif fn_arg.startswith("LDKU128"):
if fn_arg == "LDKU128":
fn_arg = "LDKU128 arg"
typedef jlongArray int64_tArray;
typedef jbyteArray int8_tArray;
+typedef jshortArray int16_tArray;
static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
// Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
def set_native_arr_contents(self, arr_name, arr_len, ty_info):
if ty_info.c_ty == "int8_tArray":
return ("(*env)->SetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
+ elif ty_info.c_ty == "int16_tArray":
+ return ("(*env)->SetShortArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
else:
assert False
def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
- if ty_info.c_ty == "int8_tArray":
+ if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
+ fn_ty = "Byte" if ty_info.c_ty == "int8_tArray" else "Short"
if copy:
- return "(*env)->GetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
+ return "(*env)->Get" + fn_ty + "ArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
else:
- return "(*env)->GetByteArrayElements (env, " + arr_name + ", NULL)"
+ return "(*env)->Get" + fn_ty + "ArrayElements (env, " + arr_name + ", NULL)"
elif not ty_info.java_ty[:len(ty_info.java_ty) - 2].endswith("[]"):
return "(*env)->Get" + ty_info.subty.java_ty.title() + "ArrayElements (env, " + arr_name + ", NULL)"
else:
if arr_ty.rust_obj == "LDKU128":
return ("" + arr_name + ".getLEBytes()", "")
if fixed_len is not None:
- return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+ if mapped_ty.c_ty == "int8_t" or mapped_ty.c_ty == "uint8_t":
+ return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+ elif mapped_ty.c_ty == "int16_t" or mapped_ty.c_ty == "uint16_t":
+ return ("InternalUtils.check_arr_16_len(" + arr_name + ", " + fixed_len + ")", "")
+ else:
+ print(arr_ty.c_ty)
+ assert False
return None
def primitive_arr_to_hu(self, arr_ty, fixed_len, arr_name, conv_name):
if arr_ty.rust_obj == "LDKU128":
return arr;
}
+ public static short[] check_arr_16_len(short[] arr, int length) throws IllegalArgumentException {
+ if (arr != null && arr.length != length) {
+ throw new IllegalArgumentException("Array must be of fixed size " + length + " but was of length " + arr.length);
+ }
+ return arr;
+ }
+
public static byte[] convUInt5Array(UInt5[] u5s) {
byte[] res = new byte[u5s.length];
for (int i = 0; i < u5s.length; i++) {
return cArrayPointer;
}
/* @internal */
+export function encodeUint16Array (inputArray: Uint16Array|Array<number>|null): number {
+ if (inputArray == null) return 0;
+ const cArrayPointer = wasm.TS_malloc((inputArray.length + 4) * 2);
+ const arrayLengthView = new BigUint64Array(wasm.memory.buffer, cArrayPointer, 1);
+ arrayLengthView[0] = BigInt(inputArray.length);
+ const arrayMemoryView = new Uint16Array(wasm.memory.buffer, cArrayPointer + 8, inputArray.length);
+ arrayMemoryView.set(inputArray);
+ return cArrayPointer;
+}
+/* @internal */
export function encodeUint32Array (inputArray: Uint32Array|Array<number>|null): number {
if (inputArray == null) return 0;
const cArrayPointer = wasm.TS_malloc((inputArray.length + 2) * 4);
return arr;
}
+/* @internal */
+export function check_16_arr_len(arr: Uint16Array|null, len: number): Uint16Array|null {
+ if (arr !== null && arr.length != len) { throw new Error("Expected array of length " + len + " got " + arr.length); }
+ return arr;
+}
+
/* @internal */
export function getArrayLength(arrayPointer: number): number {
const arraySizeViewer = new BigUint64Array(wasm.memory.buffer, arrayPointer, 1);
}
return actualArray;
}
-const decodeUint32Array = (arrayPointer: number, free = true) => {
+/* @internal */
+export function decodeUint16Array (arrayPointer: number, free = true): Uint16Array {
const arraySize = getArrayLength(arrayPointer);
- const actualArrayViewer = new Uint32Array(
- wasm.memory.buffer, // value
- arrayPointer + 8, // offset (ignoring length bytes)
- arraySize // uint32 count
- );
+ const actualArrayViewer = new Uint16Array(wasm.memory.buffer, arrayPointer + 8, arraySize);
// Clone the contents, TODO: In the future we should wrap the Viewer in a class that
// will free the underlying memory when it becomes unreachable instead of copying here.
+ // Note that doing so may have edge-case interactions with memory resizing (invalidating the buffer).
const actualArray = actualArrayViewer.slice(0, arraySize);
if (free) {
wasm.TS_free(arrayPointer);
DECL_ARR_TYPE(int64_t, int64_t);
DECL_ARR_TYPE(uint64_t, uint64_t);
DECL_ARR_TYPE(int8_t, int8_t);
+DECL_ARR_TYPE(int16_t, int16_t);
DECL_ARR_TYPE(uint32_t, uint32_t);
DECL_ARR_TYPE(void*, ptr);
DECL_ARR_TYPE(char, char);
def set_native_arr_contents(self, arr_name, arr_len, ty_info):
if ty_info.c_ty == "int8_tArray":
return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + ")")
+ elif ty_info.c_ty == "int16_tArray":
+ return ("memcpy(" + arr_name + "->elems, ", ", " + arr_len + " * 2)")
else:
assert False
def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
- if ty_info.c_ty == "int8_tArray":
+ if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
if copy:
- return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + arr_len + "); FREE(" + arr_name + ")"
+ byte_len = arr_len
+ if ty_info.c_ty == "int16_tArray":
+ byte_len = arr_len + " * 2"
+ return "memcpy(" + dest_name + ", " + arr_name + "->elems, " + byte_len + "); FREE(" + arr_name + ")"
assert not copy
if ty_info.c_ty == "ptrArray":
return "(void*) " + arr_name + "->elems"
if arr_ty.rust_obj == "LDKU128":
return ("bindings.encodeUint128(" + inner + ")", "")
if fixed_len is not None:
- assert mapped_ty.c_ty == "int8_t"
- inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+ if mapped_ty.c_ty == "int8_t":
+ inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
+ elif mapped_ty.c_ty == "int16_t":
+ inner = "bindings.check_16_arr_len(" + arr_name + ", " + fixed_len + ")"
if mapped_ty.c_ty.endswith("Array"):
return ("bindings.encodeUint32Array(" + inner + ")", "")
elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
return ("bindings.encodeUint8Array(" + inner + ")", "")
+ elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t":
+ return ("bindings.encodeUint16Array(" + inner + ")", "")
elif mapped_ty.c_ty == "uint32_t":
return ("bindings.encodeUint32Array(" + inner + ")", "")
elif mapped_ty.c_ty == "int64_t" or mapped_ty.c_ty == "uint64_t":
return "const " + conv_name + ": bigint = bindings.decodeUint128(" + arr_name + ");"
elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
+ elif mapped_ty.c_ty == "uint16_t" or mapped_ty.c_ty == "int16_t":
+ return "const " + conv_name + ": Uint16Array = bindings.decodeUint16Array(" + arr_name + ");"
elif mapped_ty.c_ty == "uint64_t" or mapped_ty.c_ty == "int64_t":
return "const " + conv_name + ": bigint[] = bindings.decodeUint64Array(" + arr_name + ");"
else: