From d294ad273a3d99aa2857efb10b225a2e20c84a16 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 11 Nov 2021 17:34:21 +0000 Subject: [PATCH] Make enum-contains-trait detection more robust --- bindingstypes.py | 3 ++- gen_type_mapping.py | 17 +++++++++-------- genbindings.py | 23 +++++++++++++++++------ 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/bindingstypes.py b/bindingstypes.py index edb19565..44b4182e 100644 --- a/bindingstypes.py +++ b/bindingstypes.py @@ -1,5 +1,5 @@ class TypeInfo: - def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, subty=None): + def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, subty=None, contains_trait=False): self.is_native_primitive = is_native_primitive self.rust_obj = rust_obj self.java_ty = java_ty @@ -16,6 +16,7 @@ class TypeInfo: self.subty = subty self.pass_by_ref = is_ptr self.requires_clone = None + self.contains_trait = contains_trait def get_full_rust_ty(self): ret = "" diff --git a/gen_type_mapping.py b/gen_type_mapping.py index 81bb0f99..cffdc839 100644 --- a/gen_type_mapping.py +++ b/gen_type_mapping.py @@ -374,15 +374,16 @@ class TypeMappingGenerator: if needs_full_clone and (ty_info.rust_obj.replace("LDK", "") + "_clone") not in self.clone_fns: # We really need a full clone here, but for now we just implement # a manual clone explicitly for Options - if ty_info.rust_obj.startswith("LDKCOption"): + if ty_info.contains_trait: + assert ty_info.rust_obj.startswith("LDKCOption") # We don't support contained traits for anything else yet optional_ty = ty_info.rust_obj[11:-1] - if "LDK" + optional_ty in self.trait_structs: - base_conv += "\nif (" + ty_info.var_name + "_conv.tag == " + ty_info.rust_obj + "_Some) {" - base_conv += "\n\t// Manually implement clone for Java trait instances" - optional_ty_info = self.java_c_types("LDK" + optional_ty + " " + ty_info.var_name, None) - base_conv += self.consts.trait_struct_inc_refcnt(optional_ty_info).\ - replace("\n", "\n\t").replace(ty_info.var_name + "_conv", ty_info.var_name + "_conv.some") - base_conv += "\n}" + assert "LDK" + optional_ty in self.trait_structs # We don't support contained traits for anything else yet + base_conv += "\nif (" + ty_info.var_name + "_conv.tag == " + ty_info.rust_obj + "_Some) {" + base_conv += "\n\t// Manually implement clone for Java trait instances" + optional_ty_info = self.java_c_types("LDK" + optional_ty + " " + ty_info.var_name, None) + base_conv += self.consts.trait_struct_inc_refcnt(optional_ty_info).\ + replace("\n", "\n\t").replace(ty_info.var_name + "_conv", ty_info.var_name + "_conv.some") + base_conv += "\n}" ret_conv = ("uint64_t " + ty_info.var_name + "_ref = ((uint64_t)&", ") | 1;") if not holds_ref: ret_conv = (ty_info.rust_obj + " *" + ty_info.var_name + "_copy = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n", "") diff --git a/genbindings.py b/genbindings.py index c46644b8..a2c66863 100755 --- a/genbindings.py +++ b/genbindings.py @@ -94,7 +94,8 @@ def doc_to_params_ret_nullable(doc): return (params, ret_null) unitary_enums = set() -complex_enums = set() +# Map from enum name to "contains trait object" +complex_enums = {} opaque_structs = set() trait_structs = {} result_types = set() @@ -225,6 +226,7 @@ def java_c_types(fn_arg, ret_arr_len): var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False) is_primitive = False + contains_trait = False arr_len = None mapped_type = [] java_type_plural = None @@ -316,6 +318,10 @@ def java_c_types(fn_arg, ret_arr_len): fn_ty_arg = "J" fn_arg = ma.group(2).strip() rust_obj = ma.group(1).strip() + if rust_obj in trait_structs: + contains_trait = True + elif rust_obj in complex_enums: + contains_trait = complex_enums[rust_obj] take_by_ptr = True if fn_arg.startswith(" *") or fn_arg.startswith("*"): @@ -339,15 +345,16 @@ def java_c_types(fn_arg, ret_arr_len): if var_is_arr.group(1) == "": return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg", - arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False) + arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait) return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1), - arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False) + arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait) if java_hu_ty is None: java_hu_ty = java_ty return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr, - is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive) + is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive, + contains_trait=contains_trait) fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$") fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$") @@ -539,10 +546,10 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}", def map_complex_enum(struct_name, union_enum_items, inline_enum_variants, enum_doc_comment): java_hu_type = struct_name.replace("LDK", "").replace("COption", "Option") - complex_enums.add(struct_name) enum_variants = [] tag_field_lines = union_enum_items["field_lines"] + contains_trait = False for idx, (struct_line, _) in enumerate(tag_field_lines): if idx == 0: assert(struct_line == "typedef enum %s_Tag {" % struct_name) @@ -560,6 +567,7 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}", for idx, (field, field_docs) in enumerate(enum_var_lines): if idx != 0 and idx < len(enum_var_lines) - 2 and field.strip() != "": field_ty = type_mapping_generator.java_c_types(field.strip(' ;'), None) + contains_trait |= field_ty.contains_trait if field_docs is not None and doc_to_field_nullable(field_docs): field_conv = type_mapping_generator.map_type_with_info(field_ty, False, None, False, True, True) else: @@ -569,10 +577,13 @@ with open(sys.argv[1]) as in_h, open(f"{sys.argv[2]}/bindings{consts.file_ext}", elif camel_to_snake(variant_name) in inline_enum_variants: # TODO: If we ever have a rust enum Variant(Option) we need to pipe # docs through to there, and then potentially mark the field nullable. - fields.append((type_mapping_generator.map_type(inline_enum_variants[camel_to_snake(variant_name)] + " " + camel_to_snake(variant_name), False, None, False, True), None)) + mapped = type_mapping_generator.map_type(inline_enum_variants[camel_to_snake(variant_name)] + " " + camel_to_snake(variant_name), False, None, False, True) + contains_trait |= mapped.ty_info.contains_trait + fields.append((mapped, None)) enum_variants.append(ComplexEnumVariantInfo(variant_name, fields, True)) else: enum_variants.append(ComplexEnumVariantInfo(variant_name, fields, True)) + complex_enums[struct_name] = contains_trait with open(f"{sys.argv[3]}/structs/{java_hu_type}{consts.file_ext}", "w") as out_java_enum: (out_java_addendum, out_java_enum_addendum, out_c_addendum) = consts.map_complex_enum(struct_name, enum_variants, camel_to_snake, enum_doc_comment) -- 2.30.2