From: Matt Corallo Date: Sat, 3 Aug 2024 02:37:12 +0000 (+0000) Subject: Be more conservative about holding a ref to a cloned object X-Git-Tag: v0.0.124.0^2~10 X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=29a9573986a462b45368ad49a93d1362e86cac4a;p=ldk-java Be more conservative about holding a ref to a cloned object If we clone an object before passing to/from native code, we generally don't need to hold a reference to that object in the struct that owns the method. The one exception we include here is if the object being passed is either a trait or holds a reference to a trait. Its not clear if this is required but it seems like a potential sharp edge. --- diff --git a/bindingstypes.py b/bindingstypes.py index 8c3f0db8..7257e5eb 100644 --- a/bindingstypes.py +++ b/bindingstypes.py @@ -1,5 +1,5 @@ class TypeInfo: - def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, subty=None, contains_trait=False): + def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, is_trait, subty=None, contains_trait=False): self.is_native_primitive = is_native_primitive self.rust_obj = rust_obj self.java_ty = java_ty @@ -16,6 +16,7 @@ class TypeInfo: self.subty = subty self.pass_by_ref = is_ptr self.requires_clone = None + self.is_trait = is_trait self.contains_trait = contains_trait def get_full_rust_ty(self): diff --git a/gen_type_mapping.py b/gen_type_mapping.py index dc9899bd..0d5020c0 100644 --- a/gen_type_mapping.py +++ b/gen_type_mapping.py @@ -300,7 +300,7 @@ class TypeMappingGenerator: else: from_hu_conv = (self.consts.get_ptr(ty_info.var_name), self.consts.add_ref("this", ty_info.var_name)) opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n" - opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = untag_ptr(" + ty_info.var_name + ");\n" + opaque_arg_conv += ty_info.var_name + "_conv.inner = untag_ptr(" + ty_info.var_name + ");\n" opaque_arg_conv += ty_info.var_name + "_conv.is_owned = ptr_is_owned(" + ty_info.var_name + ");\n" opaque_arg_conv += "CHECK_INNER_FIELD_ACCESS_OR_NULL(" + ty_info.var_name + "_conv);" @@ -313,6 +313,8 @@ class TypeMappingGenerator: # whereas in the first we prefer to clone in C to avoid additional Java code as much as possible. if holds_ref: opaque_arg_conv += "\n" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);" + if not ty_info.pass_by_ref and not ty_info.contains_trait and not ty_info.is_trait: + from_hu_conv = (from_hu_conv[0], "") elif is_nullable: from_hu_conv = (ty_info.var_name + " == null ? " + self.consts.native_zero_ptr + " : " + ty_info.var_name + ".clone_ptr()", "") else: @@ -415,9 +417,10 @@ class TypeMappingGenerator: to_hu_conv = self.consts.var_decl_statement(ty_info.java_hu_ty, "ret_hu_conv", "new " + ty_info.java_hu_ty + "(null, " + ty_info.var_name + ")") + ";\n" + self.consts.add_ref("ret_hu_conv", "this") + ";", to_hu_conv_name = "ret_hu_conv", from_hu_conv = from_hu_conv) needs_full_clone = not is_free and (not ty_info.is_ptr or ty_info.requires_clone == True) and ty_info.requires_clone != False + from_hu_add_ref = "" + if ty_info.contains_trait or ty_info.is_trait or needs_full_clone: + from_hu_add_ref = self.consts.add_ref("this", ty_info.var_name) if needs_full_clone: - if "res" in ty_info.var_name: # XXX: This is a stupid hack - needs_full_clone = False if needs_full_clone and (ty_info.rust_obj.replace("LDK", "") + "_clone") in self.clone_fns: # arg_conv is used when converting a function argument from java normally (with holds_ref set), # and when converting a java value being returned from a trait method (with holds_ref unset). @@ -425,8 +428,12 @@ class TypeMappingGenerator: # whereas in the first we prefer to clone in C to avoid additional Java code as much as possible. if holds_ref: base_conv += "\n" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone((" + ty_info.rust_obj + "*)untag_ptr(" + ty_info.var_name + "));" + if not ty_info.pass_by_ref and not ty_info.contains_trait and not ty_info.is_trait: + from_hu_add_ref = "" else: - from_hu_conv = (ty_info.var_name + ".clone_ptr()", "") + if not ty_info.pass_by_ref and not ty_info.contains_trait and not ty_info.is_trait: + from_hu_add_ref = "" + from_hu_conv = (ty_info.var_name + ".clone_ptr()", from_hu_add_ref) base_conv += "\n" + "FREE(untag_ptr(" + ty_info.var_name + "));" elif needs_full_clone: base_conv = base_conv + "\n// WARNING: we may need a move here but no clone is available for " + ty_info.rust_obj @@ -454,8 +461,7 @@ class TypeMappingGenerator: ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = ", "") ret_conv = (ret_conv[0], ";\n" + self.consts.ptr_c_ty + " " + ty_info.var_name + "_ref = tag_ptr(" + ty_info.var_name + "_copy, true);") if from_hu_conv is None: - from_hu_conv = (self.consts.get_ptr(ty_info.var_name), "") - from_hu_conv = (from_hu_conv[0], self.consts.add_ref("this", ty_info.var_name)) + from_hu_conv = (self.consts.get_ptr(ty_info.var_name), from_hu_add_ref) fully_qualified_ty = self.consts.fully_qualified_hu_ty_path(ty_info) to_hu_call = fully_qualified_ty + ".constr_from_ptr(" + ty_info.var_name + ")" return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, @@ -472,7 +478,7 @@ class TypeMappingGenerator: else: ret_conv = (ty_info.rust_obj + "* " + ty_info.var_name + "_conv = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*" + ty_info.var_name + "_conv = ", ";") if from_hu_conv is None: - from_hu_conv = (self.consts.get_ptr(ty_info.var_name), "") + from_hu_conv = (self.consts.get_ptr(ty_info.var_name), from_hu_add_ref) return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None, ret_conv = ret_conv, ret_conv_name = "tag_ptr(" + ty_info.var_name + "_conv, true)", @@ -498,7 +504,7 @@ class TypeMappingGenerator: else: to_hu_conv_sfx = "" if from_hu_conv is None: - from_hu_conv = (self.consts.get_ptr(ty_info.var_name), "") + from_hu_conv = (self.consts.get_ptr(ty_info.var_name), from_hu_add_ref) return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name, arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None, ret_conv = ret_conv, ret_conv_name = ret_conv_name, diff --git a/genbindings.py b/genbindings.py index 9ae48c4e..432ed9d0 100755 --- a/genbindings.py +++ b/genbindings.py @@ -262,15 +262,16 @@ def java_c_types(fn_arg, ret_arr_len): if is_ptr: res.pass_by_ref = True java_ty = consts.java_arr_ty_str(res.java_ty) + is_trait = res.rust_obj in trait_structs if res.is_native_primitive or res.passed_as_ptr: return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=java_ty, java_hu_ty=res.java_hu_ty + "[]", java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=res.c_ty + "Array", passed_as_ptr=False, is_ptr=is_ptr, - nonnull_ptr=nonnull_ptr, is_const=is_const, + nonnull_ptr=nonnull_ptr, is_const=is_const, is_trait = is_trait, var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False) else: return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=java_ty, java_hu_ty=res.java_hu_ty + "[]", java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=consts.ptr_arr, passed_as_ptr=False, is_ptr=is_ptr, - nonnull_ptr=nonnull_ptr, is_const=is_const, + nonnull_ptr=nonnull_ptr, is_const=is_const, is_trait = is_trait, var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False) is_primitive = False @@ -428,6 +429,7 @@ def java_c_types(fn_arg, ret_arr_len): var_is_arr = var_is_arr_regex.match(fn_arg) subty = None + is_trait = rust_obj in trait_structs if var_is_arr is not None or ret_arr_len is not None: assert(not take_by_ptr) assert(not is_ptr) @@ -453,16 +455,18 @@ def java_c_types(fn_arg, ret_arr_len): if var_is_arr.group(1) == "": return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg", subty=subty, - arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait) + arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, + is_trait=is_trait, contains_trait=contains_trait) return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const, passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1), subty=subty, - arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait) + arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, + is_trait=is_trait, contains_trait=contains_trait) if java_hu_ty is None: java_hu_ty = java_ty return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr, is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive, - contains_trait=contains_trait, subty=subty) + contains_trait=contains_trait, subty=subty, is_trait=is_trait) fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$") fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")