Ensure `TypeInfo` always has `subty` set if its an array
authorMatt Corallo <git@bluematt.me>
Mon, 10 Jan 2022 00:50:23 +0000 (00:50 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 10 Jan 2022 06:33:14 +0000 (06:33 +0000)
genbindings.py

index 34434b2ee33e4327b29073bb58d20b3861a11e33..f426b23fc995cbdd513e958cc15a2a5f28d306ae 100755 (executable)
@@ -230,6 +230,7 @@ def java_c_types(fn_arg, ret_arr_len):
     arr_len = None
     mapped_type = []
     java_type_plural = None
+    arr_ty = None
     if fn_arg.startswith("void"):
         java_ty = "void"
         c_ty = "void"
@@ -240,6 +241,7 @@ def java_c_types(fn_arg, ret_arr_len):
         java_ty = "boolean"
         c_ty = "jboolean"
         fn_ty_arg = "Z"
+        arr_ty = "bool"
         fn_arg = fn_arg[4:].strip()
         is_primitive = True
     elif fn_arg.startswith("uint8_t"):
@@ -247,6 +249,7 @@ def java_c_types(fn_arg, ret_arr_len):
         java_ty = mapped_type[0]
         c_ty = "int8_t"
         fn_ty_arg = "B"
+        arr_ty = "uint8_t"
         fn_arg = fn_arg[7:].strip()
         is_primitive = True
     elif fn_arg.startswith("LDKu5"):
@@ -254,12 +257,14 @@ def java_c_types(fn_arg, ret_arr_len):
         java_hu_ty = "UInt5"
         rust_obj = "LDKu5"
         c_ty = "int8_t"
+        arr_ty = "uint8_t"
         fn_ty_arg = "B"
         fn_arg = fn_arg[6:].strip()
     elif fn_arg.startswith("uint16_t"):
         mapped_type = consts.c_type_map['uint16_t']
         java_ty = mapped_type[0]
         c_ty = "int16_t"
+        arr_ty = "uint16_t"
         fn_ty_arg = "S"
         fn_arg = fn_arg[8:].strip()
         is_primitive = True
@@ -267,6 +272,7 @@ def java_c_types(fn_arg, ret_arr_len):
         mapped_type = consts.c_type_map['uint32_t']
         java_ty = mapped_type[0]
         c_ty = "int32_t"
+        arr_ty = "uint32_t"
         fn_ty_arg = "I"
         fn_arg = fn_arg[8:].strip()
         is_primitive = True
@@ -277,10 +283,12 @@ def java_c_types(fn_arg, ret_arr_len):
         fn_ty_arg = "J"
         if fn_arg.startswith("uint64_t"):
             c_ty = "int64_t"
+            arr_ty = "uint64_t"
             fn_arg = fn_arg[8:].strip()
         else:
             java_ty = consts.ptr_native_ty
             c_ty = "int64_t"
+            arr_ty = "uintptr_t"
             rust_obj = "uintptr_t"
             fn_arg = fn_arg[9:].strip()
         is_primitive = True
@@ -288,10 +296,12 @@ def java_c_types(fn_arg, ret_arr_len):
         java_ty = consts.java_type_map["String"]
         java_hu_ty = consts.java_hu_type_map["String"]
         c_ty = "const char*"
+        arr_ty = "LDKStr"
         fn_ty_arg = "Ljava/lang/String;"
         fn_arg = fn_arg[6:].strip()
     elif fn_arg.startswith("LDKStr"):
         rust_obj = "LDKStr"
+        arr_ty = "LDKStr"
         java_ty = consts.java_type_map["String"]
         java_hu_ty = consts.java_hu_type_map["String"]
         c_ty = "jstring"
@@ -301,6 +311,7 @@ def java_c_types(fn_arg, ret_arr_len):
         arr_len = "len"
     else:
         ma = var_ty_regex.match(fn_arg)
+        arr_ty = ma.group(1).strip()
         if ma.group(1).strip() in unitary_enums:
             assert ma.group(1).strip().startswith("LDK")
             java_ty = ma.group(1).strip()[3:]
@@ -335,6 +346,7 @@ def java_c_types(fn_arg, ret_arr_len):
         fn_ty_arg = "J"
 
     var_is_arr = var_is_arr_regex.match(fn_arg)
+    subty = None
     if var_is_arr is not None or ret_arr_len is not None:
         assert(not take_by_ptr)
         assert(not is_ptr)
@@ -344,20 +356,28 @@ def java_c_types(fn_arg, ret_arr_len):
         else:
             java_ty = java_ty + "[]"
         c_ty = c_ty + "Array"
+
+        subty = java_c_types(arr_ty, None)
+        if subty is None:
+            assert java_c_types_none_allowed
+            return None
+        if is_ptr:
+            subty.pass_by_ref = True
+
         if var_is_arr is not None:
             if var_is_arr.group(1) == "":
                 return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
-                    passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg",
+                    passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg", subty=subty,
                     arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait)
             return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
-                passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1),
+                passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1), subty=subty,
                 arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait)
 
     if java_hu_ty is None:
         java_hu_ty = java_ty
     return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr,
         is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive,
-        contains_trait=contains_trait)
+        contains_trait=contains_trait, subty=subty)
 
 fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$")
 fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")