From c98ebb54cc7f57f8c4c31d01ec69c81b172e49b1 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 23 Jun 2022 16:13:06 +0000 Subject: [PATCH] Generate mutable references in `default_generics` when relevant As we move towards resolving generics via the new `Type`-based interface instead of the string-based one, we need to ensure we retain the mutability flag when resolving references, which we do here. --- c-bindings-gen/src/blocks.rs | 2 +- c-bindings-gen/src/types.rs | 29 +++++++++++++++++++---------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/c-bindings-gen/src/blocks.rs b/c-bindings-gen/src/blocks.rs index 5072f8a..84a971b 100644 --- a/c-bindings-gen/src/blocks.rs +++ b/c-bindings-gen/src/blocks.rs @@ -576,7 +576,7 @@ pub fn write_method_params(w: &mut W, sig: &syn::Signature, t }, _ => unimplemented!(), } - w.write(&c_type).unwrap(); + w.write_all(&c_type).unwrap(); } } } diff --git a/c-bindings-gen/src/types.rs b/c-bindings-gen/src/types.rs index 16cc067..8f5e667 100644 --- a/c-bindings-gen/src/types.rs +++ b/c-bindings-gen/src/types.rs @@ -190,7 +190,7 @@ pub struct GenericTypes<'a, 'b> { self_ty: Option, parent: Option<&'b GenericTypes<'b, 'b>>, typed_generics: HashMap<&'a syn::Ident, String>, - default_generics: HashMap<&'a syn::Ident, (syn::Type, syn::Type)>, + default_generics: HashMap<&'a syn::Ident, (syn::Type, syn::Type, syn::Type)>, } impl<'a, 'p: 'a> GenericTypes<'a, 'p> { pub fn new(self_ty: Option) -> Self { @@ -226,6 +226,10 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> { if non_lifetimes_processed { return false; } non_lifetimes_processed = true; if path != "std::ops::Deref" && path != "core::ops::Deref" { + let p = string_path_to_syn_path(&path); + let ref_ty = parse_quote!(&#p); + let mut_ref_ty = parse_quote!(&mut #p); + self.default_generics.insert(&type_param.ident, (syn::Type::Path(syn::TypePath { qself: None, path: p }), ref_ty, mut_ref_ty)); new_typed_generics.insert(&type_param.ident, Some(path)); } else { // If we're templated on Deref, store @@ -240,7 +244,7 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> { syn::GenericArgument::Binding(ref b) => { if &format!("{}", b.ident) != "Target" { return false; } let default = &b.ty; - self.default_generics.insert(&type_param.ident, (parse_quote!(&#default), parse_quote!(&#default))); + self.default_generics.insert(&type_param.ident, (parse_quote!(&#default), parse_quote!(&#default), parse_quote!(&mut #default))); break 'bound_loop; }, _ => unimplemented!(), @@ -255,7 +259,7 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> { } if let Some(default) = type_param.default.as_ref() { assert!(type_param.bounds.is_empty()); - self.default_generics.insert(&type_param.ident, (default.clone(), parse_quote!(&#default))); + self.default_generics.insert(&type_param.ident, (default.clone(), parse_quote!(&#default), parse_quote!(&mut #default))); } }, _ => {}, @@ -288,10 +292,11 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> { qself: None, path: string_path_to_syn_path(&resolved) }); let ref_ty = parse_quote!(&#ty); + let mut_ref_ty = parse_quote!(&mut #ty); if types.crate_types.traits.get(&resolved).is_some() { - self.default_generics.insert(p_ident, (ty, ref_ty)); + self.default_generics.insert(p_ident, (ty, ref_ty, mut_ref_ty)); } else { - self.default_generics.insert(p_ident, (ref_ty.clone(), ref_ty)); + self.default_generics.insert(p_ident, (ref_ty.clone(), ref_ty, mut_ref_ty)); } *gen = Some(resolved); @@ -383,16 +388,20 @@ impl<'a, 'b, 'c: 'a + 'b> ResolveType<'c> for Option<&GenericTypes<'a, 'b>> { match ty { syn::Type::Path(p) => { if let Some(ident) = p.path.get_ident() { - if let Some((ty, _)) = us.default_generics.get(ident) { - return ty; + if let Some((ty, _, _)) = us.default_generics.get(ident) { + return self.resolve_type(ty); } } }, - syn::Type::Reference(syn::TypeReference { elem, .. }) => { + syn::Type::Reference(syn::TypeReference { elem, mutability, .. }) => { if let syn::Type::Path(p) = &**elem { if let Some(ident) = p.path.get_ident() { - if let Some((_, refty)) = us.default_generics.get(ident) { - return refty; + if let Some((_, refty, mut_ref_ty)) = us.default_generics.get(ident) { + if mutability.is_some() { + return self.resolve_type(mut_ref_ty); + } else { + return self.resolve_type(refty); + } } } } -- 2.39.5