Make enum-contains-trait detection more robust
authorMatt Corallo <git@bluematt.me>
Thu, 11 Nov 2021 17:34:21 +0000 (17:34 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 11 Nov 2021 17:34:21 +0000 (17:34 +0000)
bindingstypes.py
gen_type_mapping.py
genbindings.py

index edb19565140c0c435892615810e59986414b6778..44b4182eda41814785c219754e51db1891176dee 100644 (file)
@@ -1,5 +1,5 @@
 class TypeInfo:
-    def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, subty=None):
+    def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, subty=None, contains_trait=False):
         self.is_native_primitive = is_native_primitive
         self.rust_obj = rust_obj
         self.java_ty = java_ty
@@ -16,6 +16,7 @@ class TypeInfo:
         self.subty = subty
         self.pass_by_ref = is_ptr
         self.requires_clone = None
+        self.contains_trait = contains_trait
 
     def get_full_rust_ty(self):
         ret = ""
index 81bb0f99b8ebcaa5236073fa627514c8d224a212..cffdc83965bfe51dc9075b11541d05e336dd0ffc 100644 (file)
@@ -374,15 +374,16 @@ class TypeMappingGenerator:
                     if needs_full_clone and (ty_info.rust_obj.replace("LDK", "") + "_clone") not in self.clone_fns:
                         # We really need a full clone here, but for now we just implement
                         # a manual clone explicitly for Option<Trait>s
-                        if ty_info.rust_obj.startswith("LDKCOption"):
+                        if ty_info.contains_trait:
+                            assert ty_info.rust_obj.startswith("LDKCOption") # We don't support contained traits for anything else yet
                             optional_ty = ty_info.rust_obj[11:-1]
-                            if "LDK" + optional_ty in self.trait_structs:
-                                base_conv += "\nif (" + ty_info.var_name + "_conv.tag == " + ty_info.rust_obj + "_Some) {"
-                                base_conv += "\n\t// Manually implement clone for Java trait instances"
-                                optional_ty_info = self.java_c_types("LDK" + optional_ty + " " + ty_info.var_name, None)
-                                base_conv += self.consts.trait_struct_inc_refcnt(optional_ty_info).\
-                                    replace("\n", "\n\t").replace(ty_info.var_name + "_conv", ty_info.var_name + "_conv.some")
-                                base_conv += "\n}"
+                            assert "LDK" + optional_ty in self.trait_structs # We don't support contained traits for anything else yet
+                            base_conv += "\nif (" + ty_info.var_name + "_conv.tag == " + ty_info.rust_obj + "_Some) {"
+                            base_conv += "\n\t// Manually implement clone for Java trait instances"
+                            optional_ty_info = self.java_c_types("LDK" + optional_ty + " " + ty_info.var_name, None)
+                            base_conv += self.consts.trait_struct_inc_refcnt(optional_ty_info).\
+                                replace("\n", "\n\t").replace(ty_info.var_name + "_conv", ty_info.var_name + "_conv.some")
+                            base_conv += "\n}"
                     ret_conv = ("uint64_t " + ty_info.var_name + "_ref = ((uint64_t)&", ") | 1;")
                     if not holds_ref:
                         ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_copy = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n", "")
index c46644b8ee746eb76a801db18c4bf2c44eff2f1c..a2c66863677fe9d7fe758553c1d9a595561af8f8 100755 (executable)
@@ -94,7 +94,8 @@ def doc_to_params_ret_nullable(doc):
     return (params, ret_null)
 
 unitary_enums = set()
-complex_enums = set()
+# Map from enum name to "contains trait object"
+complex_enums = {}
 opaque_structs = set()
 trait_structs = {}
 result_types = set()
@@ -225,6 +226,7 @@ def java_c_types(fn_arg, ret_arr_len):
                 var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False)
 
     is_primitive = False
+    contains_trait = False
     arr_len = None
     mapped_type = []
     java_type_plural = None
@@ -316,6 +318,10 @@ def java_c_types(fn_arg, ret_arr_len):
             fn_ty_arg = "J"
             fn_arg = ma.group(2).strip()
             rust_obj = ma.group(1).strip()
+            if rust_obj in trait_structs:
+                contains_trait = True
+            elif rust_obj in complex_enums:
+                contains_trait = complex_enums[rust_obj]
             take_by_ptr = True
 
     if fn_arg.startswith(" *") or fn_arg.startswith("*"):
@@ -339,15 +345,16 @@ def java_c_types(fn_arg, ret_arr_len):
             if var_is_arr.group(1) == "":
                 return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
                     passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg",
-                    arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False)
+                    arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait)
             return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
                 passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1),
-                arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False)
+                arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait)
 
     if java_hu_ty is None:
         java_hu_ty = java_ty
     return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr,
-        is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive)
+        is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive,
+        contains_trait=contains_trait)
 
 fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$")
 fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")
@@ -539,10 +546,10 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
 
     def map_complex_enum(struct_name, union_enum_items, inline_enum_variants, enum_doc_comment):
         java_hu_type = struct_name.replace("LDK", "").replace("COption", "Option")
-        complex_enums.add(struct_name)
 
         enum_variants = []
         tag_field_lines = union_enum_items["field_lines"]
+        contains_trait = False
         for idx, (struct_line, _) in enumerate(tag_field_lines):
             if idx == 0:
                 assert(struct_line == "typedef enum %s_Tag {" % struct_name)
@@ -560,6 +567,7 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
                     for idx, (field, field_docs) in enumerate(enum_var_lines):
                         if idx != 0 and idx < len(enum_var_lines) - 2 and field.strip() != "":
                             field_ty = type_mapping_generator.java_c_types(field.strip(' ;'), None)
+                            contains_trait |= field_ty.contains_trait
                             if field_docs is not None and doc_to_field_nullable(field_docs):
                                 field_conv = type_mapping_generator.map_type_with_info(field_ty, False, None, False, True, True)
                             else:
@@ -569,10 +577,13 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}",
                 elif camel_to_snake(variant_name) in inline_enum_variants:
                     # TODO: If we ever have a rust enum Variant(Option<Struct>) we need to pipe
                     # docs through to there, and then potentially mark the field nullable.
-                    fields.append((type_mapping_generator.map_type(inline_enum_variants[camel_to_snake(variant_name)] + " " + camel_to_snake(variant_name), False, None, False, True), None))
+                    mapped = type_mapping_generator.map_type(inline_enum_variants[camel_to_snake(variant_name)] + " " + camel_to_snake(variant_name), False, None, False, True)
+                    contains_trait |= mapped.ty_info.contains_trait
+                    fields.append((mapped, None))
                     enum_variants.append(ComplexEnumVariantInfo(variant_name, fields, True))
                 else:
                     enum_variants.append(ComplexEnumVariantInfo(variant_name, fields, True))
+        complex_enums[struct_name] = contains_trait
 
         with open(f"{sys.argv[3]}/structs/{java_hu_type}{consts.file_ext}", "w") as out_java_enum:
             (out_java_addendum, out_java_enum_addendum, out_c_addendum) = consts.map_complex_enum(struct_name, enum_variants, camel_to_snake, enum_doc_comment)