From 26d603f8001cbe519244b34d0480d86d820c47f9 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 26 Aug 2024 18:52:07 +0000 Subject: [PATCH] Add a `*Ref` struct for traits and `Deref` to it rather than `Self` Because LDK often takes `Deref` as type parameters, we'd implemented `Deref { type Target=Self; .. }` for the traits defined in the bindings crate. This worked well, but because Rust auto-`Deref`s, it can lead to spurious compilation failures due to infinite recursion trying to deref. In the past we've worked around this by coming up with alternative compilation strategies when faced with `Deref` recursion, but we don't strictly need to. Instead, here, we introduce duplicates of all our `Trait` structs which become the `Deref` `Target`. This way, we can `Deref` into the `Trait` and maintain LDK compatibility, without having any infinite `Deref` recursion issues. One complication is traits which contain associated types to define `Deref`s to another associated type, e.g. trait A { type B; type C: Deref; } In this case, `B` needs to be the `TraitRef` and `C` needs to be the `Trait`. We add code specifically to detect this case. --- c-bindings-gen/src/main.rs | 77 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 74 insertions(+), 3 deletions(-) diff --git a/c-bindings-gen/src/main.rs b/c-bindings-gen/src/main.rs index 85036a2..08a4328 100644 --- a/c-bindings-gen/src/main.rs +++ b/c-bindings-gen/src/main.rs @@ -225,6 +225,11 @@ fn do_write_impl_trait(w: &mut W, trait_path: &str, _trait_na writeln!(w, "\t\tlet vec = (self.write)(self.this_arg);").unwrap(); writeln!(w, "\t\tw.write_all(vec.as_slice())").unwrap(); writeln!(w, "\t}}\n}}").unwrap(); + writeln!(w, "impl {} for {}Ref {{", trait_path, for_obj).unwrap(); + writeln!(w, "\tfn write(&self, w: &mut W) -> Result<(), crate::c_types::io::Error> {{").unwrap(); + writeln!(w, "\t\tlet vec = (self.0.write)(self.0.this_arg);").unwrap(); + writeln!(w, "\t\tw.write_all(vec.as_slice())").unwrap(); + writeln!(w, "\t}}\n}}").unwrap(); }, _ => panic!(), } @@ -450,6 +455,47 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty ($t: expr, $impl_accessor: expr, $type_resolver: expr, $generic_impls: expr) => { let mut trait_gen_types = gen_types.push_ctx(); assert!(trait_gen_types.learn_generics_with_impls(&$t.generics, $generic_impls, $type_resolver)); + + let mut ref_types = HashSet::new(); + for item in $t.items.iter() { + if let syn::TraitItem::Type(ref t) = &item { + if t.default.is_some() || t.generics.lt_token.is_some() { panic!("10"); } + let mut bounds_iter = t.bounds.iter(); + loop { + match bounds_iter.next().unwrap() { + syn::TypeParamBound::Trait(tr) => { + match $type_resolver.resolve_path(&tr.path, None).as_str() { + "core::ops::Deref"|"core::ops::DerefMut"|"std::ops::Deref"|"std::ops::DerefMut" => { + // Handle cases like + // trait A { + // type B; + // type C: Deref; + // } + // by tracking if we have any B's here and making them + // the *Ref types below. + if let syn::PathArguments::AngleBracketed(args) = &tr.path.segments.iter().last().unwrap().arguments { + if let syn::GenericArgument::Binding(bind) = args.args.iter().last().unwrap() { + assert_eq!(format!("{}", bind.ident), "Target"); + if let syn::Type::Path(p) = &bind.ty { + assert!(p.qself.is_none()); + let mut segs = p.path.segments.iter(); + assert_eq!(format!("{}", segs.next().unwrap().ident), "Self"); + ref_types.insert(format!("{}", segs.next().unwrap().ident)); + assert!(segs.next().is_none()); + } else { panic!(); } + } + } + }, + _ => {}, + } + break; + } + syn::TypeParamBound::Lifetime(_) => {}, + } + } + } + } + for item in $t.items.iter() { match item { syn::TraitItem::Method(m) => { @@ -538,7 +584,11 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty loop { match bounds_iter.next().unwrap() { syn::TypeParamBound::Trait(tr) => { - writeln!(w, "\ttype {} = crate::{};", t.ident, $type_resolver.resolve_path(&tr.path, Some(&gen_types))).unwrap(); + write!(w, "\ttype {} = crate::{}", t.ident, $type_resolver.resolve_path(&tr.path, Some(&gen_types))).unwrap(); + if ref_types.contains(&format!("{}", t.ident)) { + write!(w, "Ref").unwrap(); + } + writeln!(w, ";").unwrap(); for bound in bounds_iter { if let syn::TypeParamBound::Trait(t) = bound { // We only allow for `Sized` here. @@ -581,10 +631,15 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty writeln!(w, "impl core::cmp::Eq for {} {{}}", trait_name).unwrap(); writeln!(w, "impl core::cmp::PartialEq for {} {{", trait_name).unwrap(); writeln!(w, "\tfn eq(&self, o: &Self) -> bool {{ (self.eq)(self.this_arg, o) }}\n}}").unwrap(); + writeln!(w, "impl core::cmp::Eq for {}Ref {{}}", trait_name).unwrap(); + writeln!(w, "impl core::cmp::PartialEq for {}Ref {{", trait_name).unwrap(); + writeln!(w, "\tfn eq(&self, o: &Self) -> bool {{ (self.0.eq)(self.0.this_arg, &o.0) }}\n}}").unwrap(); }, ("std::hash::Hash", _, _)|("core::hash::Hash", _, _) => { writeln!(w, "impl core::hash::Hash for {} {{", trait_name).unwrap(); writeln!(w, "\tfn hash(&self, hasher: &mut H) {{ hasher.write_u64((self.hash)(self.this_arg)) }}\n}}").unwrap(); + writeln!(w, "impl core::hash::Hash for {}Ref {{", trait_name).unwrap(); + writeln!(w, "\tfn hash(&self, hasher: &mut H) {{ hasher.write_u64((self.0.hash)(self.0.this_arg)) }}\n}}").unwrap(); }, ("Send", _, _) => {}, ("Sync", _, _) => {}, ("Clone", _, _) => { @@ -598,6 +653,10 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty writeln!(w, "\tfn clone(&self) -> Self {{").unwrap(); writeln!(w, "\t\t{}_clone(self)", trait_name).unwrap(); writeln!(w, "\t}}\n}}").unwrap(); + writeln!(w, "impl Clone for {}Ref {{", trait_name).unwrap(); + writeln!(w, "\tfn clone(&self) -> Self {{").unwrap(); + writeln!(w, "\t\tSelf({}_clone(&self.0))", trait_name).unwrap(); + writeln!(w, "\t}}\n}}").unwrap(); }, ("std::fmt::Debug", _, _)|("core::fmt::Debug", _, _) => { writeln!(w, "impl core::fmt::Debug for {} {{", trait_name).unwrap(); @@ -605,6 +664,11 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty writeln!(w, "\t\tf.write_str((self.debug_str)(self.this_arg).into_str())").unwrap(); writeln!(w, "\t}}").unwrap(); writeln!(w, "}}").unwrap(); + writeln!(w, "impl core::fmt::Debug for {}Ref {{", trait_name).unwrap(); + writeln!(w, "\tfn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {{").unwrap(); + writeln!(w, "\t\tf.write_str((self.0.debug_str)(self.0.this_arg).into_str())").unwrap(); + writeln!(w, "\t}}").unwrap(); + writeln!(w, "}}").unwrap(); }, (s, i, generic_args) => { if let Some(supertrait) = types.crate_types.traits.get(s) { @@ -621,9 +685,16 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty write!(w, " {}", $s).unwrap(); maybe_write_generics(w, &$supertrait.generics, $generic_args, types, false); writeln!(w, " for {} {{", trait_name).unwrap(); - impl_trait_for_c!($supertrait, format!(".{}", $i), &resolver, $generic_args); writeln!(w, "}}").unwrap(); + + write!(w, "impl").unwrap(); + maybe_write_lifetime_generics(w, &$supertrait.generics, types); + write!(w, " {}", $s).unwrap(); + maybe_write_generics(w, &$supertrait.generics, $generic_args, types, false); + writeln!(w, " for {}Ref {{", trait_name).unwrap(); + impl_trait_for_c!($supertrait, format!(".0.{}", $i), &resolver, $generic_args); + writeln!(w, "}}").unwrap(); } } impl_supertrait!(s, supertrait, i, generic_args); @@ -651,7 +722,7 @@ fn writeln_trait<'a, 'b, W: std::io::Write>(w: &mut W, t: &'a syn::ItemTrait, ty impl_trait_for_c!(t, "", types, &syn::PathArguments::None); writeln!(w, "}}\n").unwrap(); - writeln!(w, "struct {}Ref({});", trait_name, trait_name).unwrap(); + writeln!(w, "pub struct {}Ref({});", trait_name, trait_name).unwrap(); write!(w, "impl").unwrap(); maybe_write_lifetime_generics(w, &t.generics, types); write!(w, " rust{}", t.ident).unwrap(); -- 2.39.5