]> git.bitcoin.ninja Git - ldk-java/commitdiff
Be more conservative about holding a ref to a cloned object
authorMatt Corallo <git@bluematt.me>
Sat, 3 Aug 2024 02:37:12 +0000 (02:37 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 3 Sep 2024 14:10:11 +0000 (14:10 +0000)
If we clone an object before passing to/from native code, we
generally don't need to hold a reference to that object in the
struct that owns the method.

The one exception we include here is if the object being passed is
either a trait or holds a reference to a trait. Its not clear if
this is required but it seems like a potential sharp edge.

bindingstypes.py
gen_type_mapping.py
genbindings.py

index 8c3f0db8e7e520a5cbde12a338e1e4fcb2b4ebfe..7257e5eb245d857eb9dbff284490274622898a3a 100644 (file)
@@ -1,5 +1,5 @@
 class TypeInfo:
-    def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, subty=None, contains_trait=False):
+    def __init__(self, is_native_primitive, rust_obj, java_ty, java_fn_ty_arg, java_hu_ty, c_ty, is_const, passed_as_ptr, is_ptr, nonnull_ptr, var_name, arr_len, arr_access, is_trait, subty=None, contains_trait=False):
         self.is_native_primitive = is_native_primitive
         self.rust_obj = rust_obj
         self.java_ty = java_ty
@@ -16,6 +16,7 @@ class TypeInfo:
         self.subty = subty
         self.pass_by_ref = is_ptr
         self.requires_clone = None
+        self.is_trait = is_trait
         self.contains_trait = contains_trait
 
     def get_full_rust_ty(self):
index dc9899bdb1f9edbbaaa487a96bb6ed9b4fc02314..0d5020c0573fa44736aeaca1dbb3947be48640ad 100644 (file)
@@ -300,7 +300,7 @@ class TypeMappingGenerator:
                 else:
                     from_hu_conv = (self.consts.get_ptr(ty_info.var_name), self.consts.add_ref("this", ty_info.var_name))
                 opaque_arg_conv = ty_info.rust_obj + " " + ty_info.var_name + "_conv;\n"
-                opaque_arg_conv = opaque_arg_conv + ty_info.var_name + "_conv.inner = untag_ptr(" + ty_info.var_name + ");\n"
+                opaque_arg_conv += ty_info.var_name + "_conv.inner = untag_ptr(" + ty_info.var_name + ");\n"
                 opaque_arg_conv += ty_info.var_name + "_conv.is_owned = ptr_is_owned(" + ty_info.var_name + ");\n"
                 opaque_arg_conv += "CHECK_INNER_FIELD_ACCESS_OR_NULL(" + ty_info.var_name + "_conv);"
 
@@ -313,6 +313,8 @@ class TypeMappingGenerator:
                         # whereas in the first we prefer to clone in C to avoid additional Java code as much as possible.
                         if holds_ref:
                             opaque_arg_conv += "\n" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone(&" + ty_info.var_name + "_conv);"
+                            if not ty_info.pass_by_ref and not ty_info.contains_trait and not ty_info.is_trait:
+                                from_hu_conv = (from_hu_conv[0], "")
                         elif is_nullable:
                             from_hu_conv = (ty_info.var_name + " == null ? " + self.consts.native_zero_ptr + " : " + ty_info.var_name + ".clone_ptr()", "")
                         else:
@@ -415,9 +417,10 @@ class TypeMappingGenerator:
                         to_hu_conv = self.consts.var_decl_statement(ty_info.java_hu_ty, "ret_hu_conv", "new " + ty_info.java_hu_ty + "(null, " + ty_info.var_name + ")") + ";\n" + self.consts.add_ref("ret_hu_conv", "this") + ";",
                         to_hu_conv_name = "ret_hu_conv", from_hu_conv = from_hu_conv)
                 needs_full_clone = not is_free and (not ty_info.is_ptr or ty_info.requires_clone == True) and ty_info.requires_clone != False
+                from_hu_add_ref = ""
+                if ty_info.contains_trait or ty_info.is_trait or needs_full_clone:
+                    from_hu_add_ref = self.consts.add_ref("this", ty_info.var_name)
                 if needs_full_clone:
-                    if "res" in ty_info.var_name: # XXX: This is a stupid hack
-                        needs_full_clone = False
                     if needs_full_clone and (ty_info.rust_obj.replace("LDK", "") + "_clone") in self.clone_fns:
                         # arg_conv is used when converting a function argument from java normally (with holds_ref set),
                         # and when converting a java value being returned from a trait method (with holds_ref unset).
@@ -425,8 +428,12 @@ class TypeMappingGenerator:
                         # whereas in the first we prefer to clone in C to avoid additional Java code as much as possible.
                         if holds_ref:
                             base_conv += "\n" + ty_info.var_name + "_conv = " + ty_info.rust_obj.replace("LDK", "") + "_clone((" + ty_info.rust_obj + "*)untag_ptr(" + ty_info.var_name + "));"
+                            if not ty_info.pass_by_ref and not ty_info.contains_trait and not ty_info.is_trait:
+                                from_hu_add_ref = ""
                         else:
-                            from_hu_conv = (ty_info.var_name + ".clone_ptr()", "")
+                            if not ty_info.pass_by_ref and not ty_info.contains_trait and not ty_info.is_trait:
+                                from_hu_add_ref = ""
+                            from_hu_conv = (ty_info.var_name + ".clone_ptr()", from_hu_add_ref)
                             base_conv += "\n" + "FREE(untag_ptr(" + ty_info.var_name + "));"
                     elif needs_full_clone:
                         base_conv = base_conv + "\n// WARNING: we may need a move here but no clone is available for " + ty_info.rust_obj
@@ -454,8 +461,7 @@ class TypeMappingGenerator:
                         ret_conv = (ret_conv[0] + "*" + ty_info.var_name + "_copy = ", "")
                         ret_conv = (ret_conv[0], ";\n" + self.consts.ptr_c_ty + " " + ty_info.var_name + "_ref = tag_ptr(" + ty_info.var_name + "_copy, true);")
                     if from_hu_conv is None:
-                        from_hu_conv = (self.consts.get_ptr(ty_info.var_name), "")
-                    from_hu_conv = (from_hu_conv[0], self.consts.add_ref("this", ty_info.var_name))
+                        from_hu_conv = (self.consts.get_ptr(ty_info.var_name), from_hu_add_ref)
                     fully_qualified_ty = self.consts.fully_qualified_hu_ty_path(ty_info)
                     to_hu_call = fully_qualified_ty + ".constr_from_ptr(" + ty_info.var_name + ")"
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
@@ -472,7 +478,7 @@ class TypeMappingGenerator:
                     else:
                         ret_conv = (ty_info.rust_obj + "* " + ty_info.var_name + "_conv = MALLOC(sizeof(" + ty_info.rust_obj + "), \"" + ty_info.rust_obj + "\");\n*" + ty_info.var_name + "_conv = ", ";")
                     if from_hu_conv is None:
-                        from_hu_conv = (self.consts.get_ptr(ty_info.var_name), "")
+                        from_hu_conv = (self.consts.get_ptr(ty_info.var_name), from_hu_add_ref)
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                         arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                         ret_conv = ret_conv, ret_conv_name = "tag_ptr(" + ty_info.var_name + "_conv, true)",
@@ -498,7 +504,7 @@ class TypeMappingGenerator:
                     else:
                         to_hu_conv_sfx = ""
                     if from_hu_conv is None:
-                        from_hu_conv = (self.consts.get_ptr(ty_info.var_name), "")
+                        from_hu_conv = (self.consts.get_ptr(ty_info.var_name), from_hu_add_ref)
                     return ConvInfo(ty_info = ty_info, arg_name = ty_info.var_name,
                         arg_conv = base_conv, arg_conv_name = ty_info.var_name + "_conv", arg_conv_cleanup = None,
                         ret_conv = ret_conv, ret_conv_name = ret_conv_name,
index 9ae48c4e234dae675047412f5508fd4ccd39e5d5..432ed9d088adca17f975973e526855c1f8d7039f 100755 (executable)
@@ -262,15 +262,16 @@ def java_c_types(fn_arg, ret_arr_len):
         if is_ptr:
             res.pass_by_ref = True
         java_ty = consts.java_arr_ty_str(res.java_ty)
+        is_trait = res.rust_obj in trait_structs
         if res.is_native_primitive or res.passed_as_ptr:
             return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=java_ty, java_hu_ty=res.java_hu_ty + "[]",
                 java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=res.c_ty + "Array", passed_as_ptr=False, is_ptr=is_ptr,
-                nonnull_ptr=nonnull_ptr, is_const=is_const,
+                nonnull_ptr=nonnull_ptr, is_const=is_const, is_trait = is_trait,
                 var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False)
         else:
             return TypeInfo(rust_obj=fn_arg.split(" ")[0], java_ty=java_ty, java_hu_ty=res.java_hu_ty + "[]",
                 java_fn_ty_arg="[" + res.java_fn_ty_arg, c_ty=consts.ptr_arr, passed_as_ptr=False, is_ptr=is_ptr,
-                nonnull_ptr=nonnull_ptr, is_const=is_const,
+                nonnull_ptr=nonnull_ptr, is_const=is_const, is_trait = is_trait,
                 var_name=res.var_name, arr_len="datalen", arr_access="data", subty=res, is_native_primitive=False)
 
     is_primitive = False
@@ -428,6 +429,7 @@ def java_c_types(fn_arg, ret_arr_len):
 
     var_is_arr = var_is_arr_regex.match(fn_arg)
     subty = None
+    is_trait = rust_obj in trait_structs
     if var_is_arr is not None or ret_arr_len is not None:
         assert(not take_by_ptr)
         assert(not is_ptr)
@@ -453,16 +455,18 @@ def java_c_types(fn_arg, ret_arr_len):
             if var_is_arr.group(1) == "":
                 return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
                     passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name="arg", subty=subty,
-                    arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait)
+                    arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False,
+                    is_trait=is_trait, contains_trait=contains_trait)
             return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg="[" + fn_ty_arg, c_ty=c_ty, is_const=is_const,
                 passed_as_ptr=False, is_ptr=False, nonnull_ptr=nonnull_ptr, var_name=var_is_arr.group(1), subty=subty,
-                arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False, contains_trait=contains_trait)
+                arr_len=var_is_arr.group(2), arr_access=arr_access, is_native_primitive=False,
+                is_trait=is_trait, contains_trait=contains_trait)
 
     if java_hu_ty is None:
         java_hu_ty = java_ty
     return TypeInfo(rust_obj=rust_obj, java_ty=java_ty, java_hu_ty=java_hu_ty, java_fn_ty_arg=fn_ty_arg, c_ty=c_ty, passed_as_ptr=is_ptr or take_by_ptr,
         is_const=is_const, is_ptr=is_ptr, nonnull_ptr=nonnull_ptr, var_name=fn_arg, arr_len=arr_len, arr_access=arr_access, is_native_primitive=is_primitive,
-        contains_trait=contains_trait, subty=subty)
+        contains_trait=contains_trait, subty=subty, is_trait=is_trait)
 
 fn_ptr_regex = re.compile("^extern const ([A-Za-z_0-9\* ]*) \(\*(.*)\)\((.*)\);$")
 fn_ret_arr_regex = re.compile("(.*) \(\*(.*)\((.*)\)\)\[([0-9]*)\];$")