Process all type aliases as C types, leaning on annotations to skip
[ldk-c-bindings] / c-bindings-gen / src / types.rs
index 007441ed5625e4f4bba85aa5ea7057c822a253c7..d53801959deb67c8fa2f568626719587aabb9d3e 100644 (file)
@@ -110,8 +110,7 @@ pub fn export_status(attrs: &[syn::Attribute]) -> ExportStatus {
                                                                        }
                                                                        if all_test { return ExportStatus::TestOnly; }
                                                                }
-                                                       } else if i == "test" || i == "feature" {
-                                                               // If its cfg(feature(...)) we assume its test-only
+                                                       } else if i == "test" {
                                                                return ExportStatus::TestOnly;
                                                        }
                                                }
@@ -353,7 +352,7 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> {
        }
 }
 
-trait ResolveType<'a> { fn resolve_type(&'a self, ty: &'a syn::Type) -> &'a syn::Type; }
+pub trait ResolveType<'a> { fn resolve_type(&'a self, ty: &'a syn::Type) -> &'a syn::Type; }
 impl<'a, 'b, 'c: 'a + 'b> ResolveType<'c> for Option<&GenericTypes<'a, 'b>> {
        fn resolve_type(&'c self, ty: &'c syn::Type) -> &'c syn::Type {
                if let Some(us) = self {
@@ -514,14 +513,7 @@ impl<'mod_lifetime, 'crate_lft: 'mod_lifetime> ImportResolver<'mod_lifetime, 'cr
                                },
                                syn::Item::Type(t) if export_status(&t.attrs) == ExportStatus::Export => {
                                        if let syn::Visibility::Public(_) = t.vis {
-                                               let mut process_alias = true;
-                                               for tok in t.generics.params.iter() {
-                                                       if let syn::GenericParam::Lifetime(_) = tok {}
-                                                       else { process_alias = false; }
-                                               }
-                                               if process_alias {
-                                                       declared.insert(t.ident.clone(), DeclType::StructImported { generics: &t.generics });
-                                               }
+                                               declared.insert(t.ident.clone(), DeclType::StructImported { generics: &t.generics });
                                        }
                                },
                                syn::Item::Enum(e) => {
@@ -724,6 +716,7 @@ fn initial_clonable_types() -> HashSet<String> {
        let mut res = HashSet::new();
        res.insert("crate::c_types::u5".to_owned());
        res.insert("crate::c_types::ThirtyTwoBytes".to_owned());
+       res.insert("crate::c_types::SecretKey".to_owned());
        res.insert("crate::c_types::PublicKey".to_owned());
        res.insert("crate::c_types::Transaction".to_owned());
        res.insert("crate::c_types::TxOut".to_owned());
@@ -739,6 +732,8 @@ pub struct CrateTypes<'a> {
        /// This may contain structs or enums, but only when either is mapped as
        /// struct X { inner: *mut originalX, .. }
        pub opaques: HashMap<String, (&'a syn::Ident, &'a syn::Generics)>,
+       /// structs that weren't exposed
+       pub priv_structs: HashMap<String, &'a syn::Generics>,
        /// Enums which are mapped as C enums with conversion functions
        pub mirrored_enums: HashMap<String, &'a syn::ItemEnum>,
        /// Traits which are mapped as a pointer + jump table
@@ -768,7 +763,7 @@ impl<'a> CrateTypes<'a> {
                CrateTypes {
                        opaques: HashMap::new(), mirrored_enums: HashMap::new(), traits: HashMap::new(),
                        type_aliases: HashMap::new(), reverse_alias_map: HashMap::new(),
-                       templates_defined: RefCell::new(HashMap::default()),
+                       templates_defined: RefCell::new(HashMap::default()), priv_structs: HashMap::new(),
                        clonable_types: RefCell::new(initial_clonable_types()), trait_impls: HashMap::new(),
                        template_file: RefCell::new(template_file), lib_ast: &libast,
                }
@@ -828,7 +823,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
        // *************************************************
 
        /// Returns true we if can just skip passing this to C entirely
-       fn skip_path(&self, full_path: &str) -> bool {
+       pub fn skip_path(&self, full_path: &str) -> bool {
                full_path == "bitcoin::secp256k1::Secp256k1" ||
                full_path == "bitcoin::secp256k1::Signing" ||
                full_path == "bitcoin::secp256k1::Verification"
@@ -1134,7 +1129,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "std::time::Duration"|"core::time::Duration" => Some(""),
                        "std::time::SystemTime" => Some(""),
                        "std::io::Error" if !is_ref => Some("crate::c_types::IOError::from_rust("),
-                       "core::fmt::Arguments" => Some("format!(\"{}\", "),
+                       "core::fmt::Arguments" => Some("alloc::format!(\"{}\", "),
 
                        "core::convert::Infallible" => Some("panic!(\"Cannot construct an Infallible: "),
 
@@ -1297,6 +1292,22 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        assert!(args.next().is_none());
                        match inner {
                                syn::Type::Reference(_) => true,
+                               syn::Type::Array(a) => {
+                                       if let syn::Expr::Lit(l) = &a.len {
+                                               if let syn::Lit::Int(i) = &l.lit {
+                                                       if i.base10_digits().parse::<usize>().unwrap() >= 32 {
+                                                               let mut buf = Vec::new();
+                                                               self.write_rust_type(&mut buf, generics, &a.elem);
+                                                               let ty = String::from_utf8(buf).unwrap();
+                                                               ty == "u8"
+                                                       } else {
+                                                               // Blindly assume that if we're trying to create an empty value for an
+                                                               // array < 32 entries that all-0s may be a valid state.
+                                                               unimplemented!();
+                                                       }
+                                               } else { unimplemented!(); }
+                                       } else { unimplemented!(); }
+                               },
                                syn::Type::Path(p) => {
                                        if let Some(resolved) = self.maybe_resolve_path(&p.path, generics) {
                                                if self.c_type_has_inner_from_path(&resolved) { return true; }
@@ -2611,8 +2622,15 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                                                if !self.is_primitive(&resolved) { return false; }
                                                if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(len), .. }) = &a.len {
                                                        if self.c_type_from_path(&format!("[{}; {}]", resolved, len.base10_digits()), is_ref, ptr_for_ref).is_none() { return false; }
-                                                       write!(w, "_{}{}", resolved, len.base10_digits()).unwrap();
-                                                       write!(mangled_type, "_{}{}", resolved, len.base10_digits()).unwrap();
+                                                       if in_type || args.len() != 1 {
+                                                               write!(w, "_{}{}", resolved, len.base10_digits()).unwrap();
+                                                               write!(mangled_type, "_{}{}", resolved, len.base10_digits()).unwrap();
+                                                       } else {
+                                                               let arrty = format!("[{}; {}]", resolved, len.base10_digits());
+                                                               let realty = self.c_type_from_path(&arrty, is_ref, ptr_for_ref).unwrap_or(&arrty);
+                                                               write!(w, "{}", realty).unwrap();
+                                                               write!(mangled_type, "{}", realty).unwrap();
+                                                       }
                                                } else { return false; }
                                        } else { return false; }
                                },