Add support for U128, passed as [u8; 16] but with human wrappers
authorMatt Corallo <git@bluematt.me>
Tue, 27 Dec 2022 02:02:24 +0000 (02:02 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 11 Jan 2023 21:14:44 +0000 (21:14 +0000)
genbindings.py
java_strings.py
src/main/java/org/ldk/util/UInt128.java [new file with mode: 0644]
typescript_strings.py

index 47df927a08c4f9f71e03ef4c8b62dac239a7d25d..fb3ad6819a16e7cace4b7ce0cb3c9708c8181179 100755 (executable)
@@ -139,6 +139,16 @@ def java_c_types(fn_arg, ret_arr_len):
         assert var_is_arr_regex.match(fn_arg[8:])
         rust_obj = "LDKThirtyTwoBytes"
         arr_access = "data"
+    elif fn_arg.startswith("LDKU128"):
+        if fn_arg == "LDKU128":
+            fn_arg = "LDKU128 arg"
+        if fn_arg.startswith("LDKU128*") or fn_arg.startswith("LDKU128 *"):
+            fn_arg = "uint8_t (" + fn_arg[8:] + ")[16]"
+        else:
+            fn_arg = "uint8_t (*" + fn_arg[8:] + ")[16]"
+        assert var_is_arr_regex.match(fn_arg[8:])
+        rust_obj = "LDKU128"
+        arr_access = "le_bytes"
     elif fn_arg.startswith("LDKTxid"):
         fn_arg = "uint8_t (*" + fn_arg[8:] + ")[32]"
         assert var_is_arr_regex.match(fn_arg[8:])
@@ -381,6 +391,8 @@ def java_c_types(fn_arg, ret_arr_len):
         else:
             java_ty = java_ty + "[]"
             java_hu_ty = java_ty
+        if rust_obj == "LDKU128":
+            java_hu_ty = consts.u128_native_ty
         c_ty = c_ty + "Array"
 
         subty = java_c_types(arr_ty, None)
index 6308ed518dd795e953f7ced3c3bdca265631fc75..151edba5ef5641e39092ba922751b8cd1235f645 100644 (file)
@@ -574,6 +574,7 @@ import javax.annotation.Nullable;
         self.file_ext = ".java"
         self.ptr_c_ty = "int64_t"
         self.ptr_native_ty = "long"
+        self.u128_native_ty = "UInt128"
         self.usize_c_ty = "int64_t"
         self.usize_native_ty = "long"
         self.native_zero_ptr = "0"
@@ -704,10 +705,14 @@ import javax.annotation.Nullable;
 
     def primitive_arr_from_hu(self, arr_ty, fixed_len, arr_name):
         mapped_ty = arr_ty.subty
+        if arr_ty.rust_obj == "LDKU128":
+            return ("" + arr_name + ".getLEBytes()", "")
         if fixed_len is not None:
             return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
         return None
     def primitive_arr_to_hu(self, arr_ty, fixed_len, arr_name, conv_name):
+        if arr_ty.rust_obj == "LDKU128":
+            return "org.ldk.util.UInt128 " + conv_name + " = new org.ldk.util.UInt128(" + arr_name + ");"
         return None
 
     def java_arr_ty_str(self, elem_ty_str):
@@ -729,7 +734,7 @@ import javax.annotation.Nullable;
     def fully_qualified_hu_ty_path(self, ty):
         if ty.java_fn_ty_arg.startswith("L") and ty.java_fn_ty_arg.endswith(";"):
             return ty.java_fn_ty_arg.strip("L;").replace("/", ".")
-        if ty.java_hu_ty == "UnqualifiedError" or ty.java_hu_ty == "UInt5" or ty.java_hu_ty == "WitnessVersion":
+        if ty.java_hu_ty == "UnqualifiedError" or ty.java_hu_ty == "UInt128" or ty.java_hu_ty == "UInt5" or ty.java_hu_ty == "WitnessVersion":
             return "org.ldk.util." + ty.java_hu_ty
         if not ty.is_native_primitive and ty.rust_obj is not None and not "[]" in ty.java_hu_ty:
             return "org.ldk.structs." + ty.java_hu_ty
diff --git a/src/main/java/org/ldk/util/UInt128.java b/src/main/java/org/ldk/util/UInt128.java
new file mode 100644 (file)
index 0000000..5b5f457
--- /dev/null
@@ -0,0 +1,42 @@
+package org.ldk.util;
+
+import java.util.Arrays;
+
+/**
+ * A 5-bit unsigned integer
+ */
+public class UInt128 {
+    private byte[] le_bytes;
+
+    /**
+     * Constructs a 128-bit integer from its little-endian representation
+     */
+    public UInt128(byte[] le_bytes) {
+        if (le_bytes.length != 16) {
+            throw new IllegalArgumentException();
+        }
+        this.le_bytes = le_bytes;
+    }
+
+    /**
+     * Constructs a 128-bit integer from a long, ignoring the sign bit
+     */
+    public UInt128(long val) {
+        byte[] le_bytes = new byte[16];
+        for (int i = 0; i < 8; i++)
+            le_bytes[i] = (byte) ((val >> i*8) & 0xff);
+        this.le_bytes = le_bytes;
+    }
+
+    /**
+     * @return The value as 16 little endian bytes
+     */
+    public byte[] getLEBytes() { return le_bytes; }
+
+    @Override public boolean equals(Object o) {
+        if (o == null || !(o instanceof UInt128)) return false;
+        return Arrays.equals(le_bytes, ((UInt128) o).le_bytes);
+    }
+
+    @Override public int hashCode() { return Arrays.hashCode(le_bytes); }
+}
index 37d12b61818efbfa35fedc92e79b717364449a2c..15e35d459983d3c1827a513340cf7e317b8d9727 100644 (file)
@@ -166,6 +166,16 @@ export function WitnessVersionArrToBytes(inputArray: Array<WitnessVersion>): Uin
 
 
 
+/* @internal */
+export function encodeUint128 (inputVal: bigint): number {
+       if (inputVal >= 0x10000000000000000000000000000000n) throw "U128s cannot exceed 128 bits";
+       const cArrayPointer = wasm.TS_malloc(16 + 8);
+       const arrayLengthView = new BigUint64Array(wasm.memory.buffer, cArrayPointer, 1);
+       arrayLengthView[0] = BigInt(16);
+       const arrayMemoryView = new Uint8Array(wasm.memory.buffer, cArrayPointer + 8, 16);
+       for (var i = 0; i < 16; i++) arrayMemoryView[i] = Number((inputVal >> BigInt(i)*8n) & 0xffn);
+       return cArrayPointer;
+}
 /* @internal */
 export function encodeUint8Array (inputArray: Uint8Array|null): number {
        if (inputArray == null) return 0;
@@ -210,6 +220,21 @@ export function getArrayLength(arrayPointer: number): number {
        return Number(len % (2n ** 32n));
 }
 /* @internal */
+export function decodeUint128 (arrayPointer: number, free = true): bigint {
+       const arraySize = getArrayLength(arrayPointer);
+       if (arraySize != 16) throw "Need 16 bytes for a uint128";
+       const actualArrayViewer = new Uint8Array(wasm.memory.buffer, arrayPointer + 8, arraySize);
+       var val = 0n;
+       for (var i = 0; i < 16; i++) {
+               val <<= 8n;
+               val |= BigInt(actualArrayViewer[i]!);
+       }
+       if (free) {
+               wasm.TS_free(arrayPointer);
+       }
+       return val;
+}
+/* @internal */
 export function decodeUint8Array (arrayPointer: number, free = true): Uint8Array {
        const arraySize = getArrayLength(arrayPointer);
        const actualArrayViewer = new Uint8Array(wasm.memory.buffer, arrayPointer + 8, arraySize);
@@ -678,6 +703,7 @@ import * as bindings from '../bindings.mjs'
         self.file_ext = ".mts"
         self.ptr_c_ty = "uint64_t"
         self.ptr_native_ty = "bigint"
+        self.u128_native_ty = "bigint"
         self.usize_c_ty = "uint32_t"
         self.usize_native_ty = "number"
         self.native_zero_ptr = "0n"
@@ -760,6 +786,8 @@ import * as bindings from '../bindings.mjs'
     def primitive_arr_from_hu(self, arr_ty, fixed_len, arr_name):
         mapped_ty = arr_ty.subty
         inner = arr_name
+        if arr_ty.rust_obj == "LDKU128":
+            return ("bindings.encodeUint128(" + inner + ")", "")
         if fixed_len is not None:
             assert mapped_ty.c_ty == "int8_t"
             inner = "bindings.check_arr_len(" + arr_name + ", " + fixed_len + ")"
@@ -777,7 +805,9 @@ import * as bindings from '../bindings.mjs'
 
     def primitive_arr_to_hu(self, arr_ty, fixed_len, arr_name, conv_name):
         mapped_ty = arr_ty.subty
-        if mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
+        if arr_ty.rust_obj == "LDKU128":
+            return "const " + conv_name + ": bigint = bindings.decodeUint128(" + arr_name + ");"
+        elif mapped_ty.c_ty == "uint8_t" or mapped_ty.c_ty == "int8_t":
             return "const " + conv_name + ": Uint8Array = bindings.decodeUint8Array(" + arr_name + ");"
         elif mapped_ty.c_ty == "uint64_t" or mapped_ty.c_ty == "int64_t":
             return "const " + conv_name + ": bigint[] = bindings.decodeUint64Array(" + arr_name + ");"