Support (fixed-length) arrays of 16-bit integers
[ldk-java] / java_strings.py
index d956e0e11678ba66e3db3bbe4b2f7a941a8263cd..650e36bf92bbd996e4649a3ce5cacbb289853fc6 100644 (file)
@@ -426,6 +426,7 @@ _Static_assert(sizeof(void*) <= 8, "Pointers must fit into 64 bits");
 
 typedef jlongArray int64_tArray;
 typedef jbyteArray int8_tArray;
+typedef jshortArray int16_tArray;
 
 static inline jstring str_ref_to_java(JNIEnv *env, const char* chars, size_t len) {
        // Sadly we need to create a temporary because Java can't accept a char* without a 0-terminator
@@ -525,14 +526,17 @@ import javax.annotation.Nullable;
     def set_native_arr_contents(self, arr_name, arr_len, ty_info):
         if ty_info.c_ty == "int8_tArray":
             return ("(*env)->SetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
+        elif ty_info.c_ty == "int16_tArray":
+            return ("(*env)->SetShortArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", ", ")")
         else:
             assert False
     def get_native_arr_contents(self, arr_name, dest_name, arr_len, ty_info, copy):
-        if ty_info.c_ty == "int8_tArray":
+        if ty_info.c_ty == "int8_tArray" or ty_info.c_ty == "int16_tArray":
+            fn_ty = "Byte" if ty_info.c_ty == "int8_tArray" else "Short"
             if copy:
-                return "(*env)->GetByteArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
+                return "(*env)->Get" + fn_ty + "ArrayRegion(env, " + arr_name + ", 0, " + arr_len + ", " + dest_name + ")"
             else:
-                return "(*env)->GetByteArrayElements (env, " + arr_name + ", NULL)"
+                return "(*env)->Get" + fn_ty + "ArrayElements (env, " + arr_name + ", NULL)"
         elif not ty_info.java_ty[:len(ty_info.java_ty) - 2].endswith("[]"):
             return "(*env)->Get" + ty_info.subty.java_ty.title() + "ArrayElements (env, " + arr_name + ", NULL)"
         else:
@@ -616,7 +620,13 @@ import javax.annotation.Nullable;
         if arr_ty.rust_obj == "LDKU128":
             return ("" + arr_name + ".getLEBytes()", "")
         if fixed_len is not None:
-            return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+            if mapped_ty.c_ty == "int8_t" or mapped_ty.c_ty == "uint8_t":
+                return ("InternalUtils.check_arr_len(" + arr_name + ", " + fixed_len + ")", "")
+            elif mapped_ty.c_ty == "int16_t" or mapped_ty.c_ty == "uint16_t":
+                return ("InternalUtils.check_arr_16_len(" + arr_name + ", " + fixed_len + ")", "")
+            else:
+                print(arr_ty.c_ty)
+                assert False
         return None
     def primitive_arr_to_hu(self, arr_ty, fixed_len, arr_name, conv_name):
         if arr_ty.rust_obj == "LDKU128":