Support concretizing generics in template parameters
[ldk-c-bindings] / c-bindings-gen / src / types.rs
index 0b8e10d6ee02b7330573e6034edc8e3512ed3077..acfcc1d8a8c26c0b85a3493ae1c82646733fe38b 100644 (file)
@@ -6,6 +6,7 @@
 // You may not use this file except in accordance with one or both of these
 // licenses.
 
+use std::cell::RefCell;
 use std::collections::{HashMap, HashSet};
 use std::fs::File;
 use std::io::Write;
@@ -82,12 +83,21 @@ pub fn export_status(attrs: &[syn::Attribute]) -> ExportStatus {
                                                        if i == "any" {
                                                                // #[cfg(any(test, feature = ""))]
                                                                if let TokenTree::Group(g) = iter.next().unwrap() {
-                                                                       if let TokenTree::Ident(i) = g.stream().into_iter().next().unwrap() {
-                                                                               if i == "test" || i == "feature" {
-                                                                                       // If its cfg(feature(...)) we assume its test-only
-                                                                                       return ExportStatus::TestOnly;
+                                                                       let mut all_test = true;
+                                                                       for token in g.stream().into_iter() {
+                                                                               if let TokenTree::Ident(i) = token {
+                                                                                       match format!("{}", i).as_str() {
+                                                                                               "test" => {},
+                                                                                               "feature" => {},
+                                                                                               _ => all_test = false,
+                                                                                       }
+                                                                               } else if let TokenTree::Literal(lit) = token {
+                                                                                       if format!("{}", lit) != "fuzztarget" {
+                                                                                               all_test = false;
+                                                                                       }
                                                                                }
                                                                        }
+                                                                       if all_test { return ExportStatus::TestOnly; }
                                                                }
                                                        } else if i == "test" || i == "feature" {
                                                                // If its cfg(feature(...)) we assume its test-only
@@ -150,23 +160,21 @@ pub fn is_enum_opaque(e: &syn::ItemEnum) -> bool {
 /// It maps both direct types as well as Deref<Target = X>, mapping them via the provided
 /// TypeResolver's resolve_path function (ie traits map to the concrete jump table, structs to the
 /// concrete C container struct, etc).
-pub struct GenericTypes<'a> {
-       typed_generics: Vec<HashMap<&'a syn::Ident, (String, Option<&'a syn::Path>)>>,
+#[must_use]
+pub struct GenericTypes<'a, 'b> {
+       parent: Option<&'b GenericTypes<'b, 'b>>,
+       typed_generics: HashMap<&'a syn::Ident, (String, Option<&'a syn::Path>)>,
 }
-impl<'a> GenericTypes<'a> {
+impl<'a, 'p: 'a> GenericTypes<'a, 'p> {
        pub fn new() -> Self {
-               Self { typed_generics: vec![HashMap::new()], }
+               Self { parent: None, typed_generics: HashMap::new(), }
        }
 
        /// push a new context onto the stack, allowing for a new set of generics to be learned which
        /// will override any lower contexts, but which will still fall back to resoltion via lower
        /// contexts.
-       pub fn push_ctx(&mut self) {
-               self.typed_generics.push(HashMap::new());
-       }
-       /// pop the latest context off the stack.
-       pub fn pop_ctx(&mut self) {
-               self.typed_generics.pop();
+       pub fn push_ctx<'c>(&'c self) -> GenericTypes<'a, 'c> {
+               GenericTypes { parent: Some(self), typed_generics: HashMap::new(), }
        }
 
        /// Learn the generics in generics in the current context, given a TypeResolver.
@@ -192,7 +200,7 @@ impl<'a> GenericTypes<'a> {
                                                                        path = "crate::".to_string() + &path;
                                                                        Some(&trait_bound.path)
                                                                } else { None };
-                                                               self.typed_generics.last_mut().unwrap().insert(&type_param.ident, (path, new_ident));
+                                                               self.typed_generics.insert(&type_param.ident, (path, new_ident));
                                                        } else { return false; }
                                                }
                                        }
@@ -208,7 +216,7 @@ impl<'a> GenericTypes<'a> {
                                                if p.qself.is_some() { return false; }
                                                if p.path.leading_colon.is_some() { return false; }
                                                let mut p_iter = p.path.segments.iter();
-                                               if let Some(gen) = self.typed_generics.last_mut().unwrap().get_mut(&p_iter.next().unwrap().ident) {
+                                               if let Some(gen) = self.typed_generics.get_mut(&p_iter.next().unwrap().ident) {
                                                        if gen.0 != "std::ops::Deref" { return false; }
                                                        if &format!("{}", p_iter.next().unwrap().ident) != "Target" { return false; }
 
@@ -227,7 +235,7 @@ impl<'a> GenericTypes<'a> {
                                }
                        }
                }
-               for (_, (_, ident)) in self.typed_generics.last().unwrap().iter() {
+               for (_, (_, ident)) in self.typed_generics.iter() {
                        if ident.is_none() { return false; }
                }
                true
@@ -253,7 +261,7 @@ impl<'a> GenericTypes<'a> {
                                                                        path = "crate::".to_string() + &path;
                                                                        Some(&tr.path)
                                                                } else { None };
-                                                               self.typed_generics.last_mut().unwrap().insert(&t.ident, (path, new_ident));
+                                                               self.typed_generics.insert(&t.ident, (path, new_ident));
                                                        } else { unimplemented!(); }
                                                },
                                                _ => unimplemented!(),
@@ -267,21 +275,21 @@ impl<'a> GenericTypes<'a> {
 
        /// Attempt to resolve an Ident as a generic parameter and return the full path.
        pub fn maybe_resolve_ident<'b>(&'b self, ident: &syn::Ident) -> Option<&'b String> {
-               for gen in self.typed_generics.iter().rev() {
-                       if let Some(res) = gen.get(ident).map(|(a, _)| a) {
-                               return Some(res);
-                       }
+               if let Some(res) = self.typed_generics.get(ident).map(|(a, _)| a) {
+                       return Some(res);
+               }
+               if let Some(parent) = self.parent {
+                       parent.maybe_resolve_ident(ident)
+               } else {
+                       None
                }
-               None
        }
        /// Attempt to resolve a Path as a generic parameter and return the full path. as both a string
        /// and syn::Path.
        pub fn maybe_resolve_path<'b>(&'b self, path: &syn::Path) -> Option<(&'b String, &'a syn::Path)> {
                if let Some(ident) = path.get_ident() {
-                       for gen in self.typed_generics.iter().rev() {
-                               if let Some(res) = gen.get(ident).map(|(a, b)| (a, b.unwrap())) {
-                                       return Some(res);
-                               }
+                       if let Some(res) = self.typed_generics.get(ident).map(|(a, b)| (a, b.unwrap())) {
+                               return Some(res);
                        }
                } else {
                        // Associated types are usually specified as "Self::Generic", so we check for that
@@ -289,14 +297,16 @@ impl<'a> GenericTypes<'a> {
                        let mut it = path.segments.iter();
                        if path.segments.len() == 2 && format!("{}", it.next().unwrap().ident) == "Self" {
                                let ident = &it.next().unwrap().ident;
-                               for gen in self.typed_generics.iter().rev() {
-                                       if let Some(res) = gen.get(ident).map(|(a, b)| (a, b.unwrap())) {
-                                               return Some(res);
-                                       }
+                               if let Some(res) = self.typed_generics.get(ident).map(|(a, b)| (a, b.unwrap())) {
+                                       return Some(res);
                                }
                        }
                }
-               None
+               if let Some(parent) = self.parent {
+                       parent.maybe_resolve_path(path)
+               } else {
+                       None
+               }
        }
 }
 
@@ -403,10 +413,7 @@ impl<'mod_lifetime, 'crate_lft: 'mod_lifetime> ImportResolver<'mod_lifetime, 'cr
                                                        else { process_alias = false; }
                                                }
                                                if process_alias {
-                                                       match &*t.ty {
-                                                               syn::Type::Path(_) => { declared.insert(t.ident.clone(), DeclType::StructImported); },
-                                                               _ => {},
-                                                       }
+                                                       declared.insert(t.ident.clone(), DeclType::StructImported);
                                                }
                                        }
                                },
@@ -500,11 +507,20 @@ impl<'mod_lifetime, 'crate_lft: 'mod_lifetime> ImportResolver<'mod_lifetime, 'cr
        pub fn resolve_imported_refs(&self, mut ty: syn::Type) -> syn::Type {
                match &mut ty {
                        syn::Type::Path(p) => {
-                               if let Some(ident) = p.path.get_ident() {
-                                       if let Some((_, newpath)) = self.imports.get(ident) {
-                                               p.path = newpath.clone();
+eprintln!("rir {:?}", p);
+                               if p.path.segments.len() != 1 { unimplemented!(); }
+                               let mut args = p.path.segments[0].arguments.clone();
+                               if let syn::PathArguments::AngleBracketed(ref mut generics) = &mut args {
+                                       for arg in generics.args.iter_mut() {
+                                               if let syn::GenericArgument::Type(ref mut t) = arg {
+                                                       *t = self.resolve_imported_refs(t.clone());
+                                               }
                                        }
-                               } else { unimplemented!(); }
+                               }
+                               if let Some((_, newpath)) = self.imports.get(single_ident_generic_path_to_ident(&p.path).unwrap()) {
+                                       p.path = newpath.clone();
+                               }
+                               p.path.segments[0].arguments = args;
                        },
                        syn::Type::Reference(r) => {
                                r.elem = Box::new(self.resolve_imported_refs((*r.elem).clone()));
@@ -530,6 +546,53 @@ impl<'mod_lifetime, 'crate_lft: 'mod_lifetime> ImportResolver<'mod_lifetime, 'cr
 #[allow(deprecated)]
 pub type NonRandomHash = hash::BuildHasherDefault<hash::SipHasher>;
 
+/// A public module
+pub struct ASTModule {
+       pub attrs: Vec<syn::Attribute>,
+       pub items: Vec<syn::Item>,
+       pub submods: Vec<String>,
+}
+/// A struct containing the syn::File AST for each file in the crate.
+pub struct FullLibraryAST {
+       pub modules: HashMap<String, ASTModule, NonRandomHash>,
+}
+impl FullLibraryAST {
+       fn load_module(&mut self, module: String, attrs: Vec<syn::Attribute>, mut items: Vec<syn::Item>) {
+               let mut non_mod_items = Vec::with_capacity(items.len());
+               let mut submods = Vec::with_capacity(items.len());
+               for item in items.drain(..) {
+                       match item {
+                               syn::Item::Mod(m) if m.content.is_some() => {
+                                       if export_status(&m.attrs) == ExportStatus::Export {
+                                               if let syn::Visibility::Public(_) = m.vis {
+                                                       let modident = format!("{}", m.ident);
+                                                       let modname = if module != "" {
+                                                               module.clone() + "::" + &modident
+                                                       } else {
+                                                               modident.clone()
+                                                       };
+                                                       self.load_module(modname, m.attrs, m.content.unwrap().1);
+                                                       submods.push(modident);
+                                               } else {
+                                                       non_mod_items.push(syn::Item::Mod(m));
+                                               }
+                                       }
+                               },
+                               syn::Item::Mod(_) => panic!("--pretty=expanded output should never have non-body modules"),
+                               _ => { non_mod_items.push(item); }
+                       }
+               }
+               self.modules.insert(module, ASTModule { attrs, items: non_mod_items, submods });
+       }
+
+       pub fn load_lib(lib: syn::File) -> Self {
+               assert_eq!(export_status(&lib.attrs), ExportStatus::Export);
+               let mut res = Self { modules: HashMap::default() };
+               res.load_module("".to_owned(), lib.attrs, lib.items);
+               res
+       }
+}
+
 /// Top-level struct tracking everything which has been defined while walking the crate.
 pub struct CrateTypes<'a> {
        /// This may contain structs or enums, but only when either is mapped as
@@ -547,14 +610,38 @@ pub struct CrateTypes<'a> {
        /// exists.
        ///
        /// This is used at the end of processing to make C++ wrapper classes
-       pub templates_defined: HashMap<String, bool, NonRandomHash>,
+       pub templates_defined: RefCell<HashMap<String, bool, NonRandomHash>>,
        /// The output file for any created template container types, written to as we find new
        /// template containers which need to be defined.
-       pub template_file: &'a mut File,
+       template_file: RefCell<&'a mut File>,
        /// Set of containers which are clonable
-       pub clonable_types: HashSet<String>,
+       clonable_types: RefCell<HashSet<String>>,
        /// Key impls Value
        pub trait_impls: HashMap<String, Vec<String>>,
+       /// The full set of modules in the crate(s)
+       pub lib_ast: &'a FullLibraryAST,
+}
+
+impl<'a> CrateTypes<'a> {
+       pub fn new(template_file: &'a mut File, libast: &'a FullLibraryAST) -> Self {
+               CrateTypes {
+                       opaques: HashMap::new(), mirrored_enums: HashMap::new(), traits: HashMap::new(),
+                       type_aliases: HashMap::new(), reverse_alias_map: HashMap::new(),
+                       templates_defined: RefCell::new(HashMap::default()),
+                       clonable_types: RefCell::new(HashSet::new()), trait_impls: HashMap::new(),
+                       template_file: RefCell::new(template_file), lib_ast: &libast,
+               }
+       }
+       pub fn set_clonable(&self, object: String) {
+               self.clonable_types.borrow_mut().insert(object);
+       }
+       pub fn is_clonable(&self, object: &str) -> bool {
+               self.clonable_types.borrow().contains(object)
+       }
+       pub fn write_new_template(&self, mangled_container: String, has_destructor: bool, created_container: &[u8]) {
+               self.template_file.borrow_mut().write(created_container).unwrap();
+               self.templates_defined.borrow_mut().insert(mangled_container, has_destructor);
+       }
 }
 
 /// A struct which tracks resolving rust types into C-mapped equivalents, exists for one specific
@@ -562,7 +649,7 @@ pub struct CrateTypes<'a> {
 pub struct TypeResolver<'mod_lifetime, 'crate_lft: 'mod_lifetime> {
        pub orig_crate: &'mod_lifetime str,
        pub module_path: &'mod_lifetime str,
-       pub crate_types: &'mod_lifetime mut CrateTypes<'crate_lft>,
+       pub crate_types: &'mod_lifetime CrateTypes<'crate_lft>,
        types: ImportResolver<'mod_lifetime, 'crate_lft>,
 }
 
@@ -592,7 +679,7 @@ enum ContainerPrefixLocation {
 }
 
 impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
-       pub fn new(orig_crate: &'a str, module_path: &'a str, types: ImportResolver<'a, 'c>, crate_types: &'a mut CrateTypes<'c>) -> Self {
+       pub fn new(orig_crate: &'a str, module_path: &'a str, types: ImportResolver<'a, 'c>, crate_types: &'a CrateTypes<'c>) -> Self {
                Self { orig_crate, module_path, types, crate_types }
        }
 
@@ -627,7 +714,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                }
        }
        pub fn is_clonable(&self, ty: &str) -> bool {
-               if self.crate_types.clonable_types.contains(ty) { return true; }
+               if self.crate_types.is_clonable(ty) { return true; }
                if self.is_primitive(ty) { return true; }
                match ty {
                        "()" => true,
@@ -1923,7 +2010,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
        // *** C Container Type Equivalent and alias Printing ***
        // ******************************************************
 
-       fn write_template_generics<'b, W: std::io::Write>(&mut self, w: &mut W, args: &mut dyn Iterator<Item=&'b syn::Type>, generics: Option<&GenericTypes>, is_ref: bool) -> bool {
+       fn write_template_generics<'b, W: std::io::Write>(&self, w: &mut W, args: &mut dyn Iterator<Item=&'b syn::Type>, generics: Option<&GenericTypes>, is_ref: bool) -> bool {
                for (idx, t) in args.enumerate() {
                        if idx != 0 {
                                write!(w, ", ").unwrap();
@@ -1957,8 +2044,8 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                }
                true
        }
-       fn check_create_container(&mut self, mangled_container: String, container_type: &str, args: Vec<&syn::Type>, generics: Option<&GenericTypes>, is_ref: bool) -> bool {
-               if !self.crate_types.templates_defined.get(&mangled_container).is_some() {
+       fn check_create_container(&self, mangled_container: String, container_type: &str, args: Vec<&syn::Type>, generics: Option<&GenericTypes>, is_ref: bool) -> bool {
+               if !self.crate_types.templates_defined.borrow().get(&mangled_container).is_some() {
                        let mut created_container: Vec<u8> = Vec::new();
 
                        if container_type == "Result" {
@@ -1989,7 +2076,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                                let is_clonable = self.is_clonable(&ok_str) && self.is_clonable(&err_str);
                                write_result_block(&mut created_container, &mangled_container, &ok_str, &err_str, is_clonable);
                                if is_clonable {
-                                       self.crate_types.clonable_types.insert(Self::generated_container_path().to_owned() + "::" + &mangled_container);
+                                       self.crate_types.set_clonable(Self::generated_container_path().to_owned() + "::" + &mangled_container);
                                }
                        } else if container_type == "Vec" {
                                let mut a_ty: Vec<u8> = Vec::new();
@@ -1998,7 +2085,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                                let is_clonable = self.is_clonable(&ty);
                                write_vec_block(&mut created_container, &mangled_container, &ty, is_clonable);
                                if is_clonable {
-                                       self.crate_types.clonable_types.insert(Self::generated_container_path().to_owned() + "::" + &mangled_container);
+                                       self.crate_types.set_clonable(Self::generated_container_path().to_owned() + "::" + &mangled_container);
                                }
                        } else if container_type.ends_with("Tuple") {
                                let mut tuple_args = Vec::new();
@@ -2014,7 +2101,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                                }
                                write_tuple_block(&mut created_container, &mangled_container, &tuple_args, is_clonable);
                                if is_clonable {
-                                       self.crate_types.clonable_types.insert(Self::generated_container_path().to_owned() + "::" + &mangled_container);
+                                       self.crate_types.set_clonable(Self::generated_container_path().to_owned() + "::" + &mangled_container);
                                }
                        } else if container_type == "Option" {
                                let mut a_ty: Vec<u8> = Vec::new();
@@ -2023,14 +2110,12 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                                let is_clonable = self.is_clonable(&ty);
                                write_option_block(&mut created_container, &mangled_container, &ty, is_clonable);
                                if is_clonable {
-                                       self.crate_types.clonable_types.insert(Self::generated_container_path().to_owned() + "::" + &mangled_container);
+                                       self.crate_types.set_clonable(Self::generated_container_path().to_owned() + "::" + &mangled_container);
                                }
                        } else {
                                unreachable!();
                        }
-                       self.crate_types.templates_defined.insert(mangled_container.clone(), true);
-
-                       self.crate_types.template_file.write(&created_container).unwrap();
+                       self.crate_types.write_new_template(mangled_container.clone(), true, &created_container);
                }
                true
        }
@@ -2040,7 +2125,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                } else { unimplemented!(); }
        }
        fn write_c_mangled_container_path_intern<W: std::io::Write>
-                       (&mut self, w: &mut W, args: Vec<&syn::Type>, generics: Option<&GenericTypes>, ident: &str, is_ref: bool, is_mut: bool, ptr_for_ref: bool, in_type: bool) -> bool {
+                       (&self, w: &mut W, args: Vec<&syn::Type>, generics: Option<&GenericTypes>, ident: &str, is_ref: bool, is_mut: bool, ptr_for_ref: bool, in_type: bool) -> bool {
                let mut mangled_type: Vec<u8> = Vec::new();
                if !self.is_transparent_container(ident, is_ref, args.iter().map(|a| *a)) {
                        write!(w, "C{}_", ident).unwrap();
@@ -2152,7 +2237,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                // Make sure the type is actually defined:
                self.check_create_container(String::from_utf8(mangled_type).unwrap(), ident, args, generics, is_ref)
        }
-       fn write_c_mangled_container_path<W: std::io::Write>(&mut self, w: &mut W, args: Vec<&syn::Type>, generics: Option<&GenericTypes>, ident: &str, is_ref: bool, is_mut: bool, ptr_for_ref: bool) -> bool {
+       fn write_c_mangled_container_path<W: std::io::Write>(&self, w: &mut W, args: Vec<&syn::Type>, generics: Option<&GenericTypes>, ident: &str, is_ref: bool, is_mut: bool, ptr_for_ref: bool) -> bool {
                if !self.is_transparent_container(ident, is_ref, args.iter().map(|a| *a)) {
                        write!(w, "{}::", Self::generated_container_path()).unwrap();
                }
@@ -2195,7 +2280,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        false
                }
        }
-       fn write_c_type_intern<W: std::io::Write>(&mut self, w: &mut W, t: &syn::Type, generics: Option<&GenericTypes>, is_ref: bool, is_mut: bool, ptr_for_ref: bool) -> bool {
+       fn write_c_type_intern<W: std::io::Write>(&self, w: &mut W, t: &syn::Type, generics: Option<&GenericTypes>, is_ref: bool, is_mut: bool, ptr_for_ref: bool) -> bool {
                match t {
                        syn::Type::Path(p) => {
                                if p.qself.is_some() {
@@ -2286,14 +2371,14 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        _ => false,
                }
        }
-       pub fn write_c_type<W: std::io::Write>(&mut self, w: &mut W, t: &syn::Type, generics: Option<&GenericTypes>, ptr_for_ref: bool) {
+       pub fn write_c_type<W: std::io::Write>(&self, w: &mut W, t: &syn::Type, generics: Option<&GenericTypes>, ptr_for_ref: bool) {
                assert!(self.write_c_type_intern(w, t, generics, false, false, ptr_for_ref));
        }
-       pub fn understood_c_path(&mut self, p: &syn::Path) -> bool {
+       pub fn understood_c_path(&self, p: &syn::Path) -> bool {
                if p.leading_colon.is_some() { return false; }
                self.write_c_path_intern(&mut std::io::sink(), p, None, false, false, false)
        }
-       pub fn understood_c_type(&mut self, t: &syn::Type, generics: Option<&GenericTypes>) -> bool {
+       pub fn understood_c_type(&self, t: &syn::Type, generics: Option<&GenericTypes>) -> bool {
                self.write_c_type_intern(&mut std::io::sink(), t, generics, false, false, false)
        }
 }