Move overly-specific check to an assertion where its relevant
[ldk-c-bindings] / c-bindings-gen / src / types.rs
index 228d6e8f221fda5a13424f6fe66832c67e302c4d..c96a3b9a2fd51220057d89381c95e77e2f5d11a6 100644 (file)
@@ -227,12 +227,13 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> {
                                                                non_lifetimes_processed = true;
                                                                if path != "std::ops::Deref" && path != "core::ops::Deref" {
                                                                        new_typed_generics.insert(&type_param.ident, Some(path));
-                                                               } else if trait_bound.path.segments.len() == 1 {
+                                                               } else {
                                                                        // If we're templated on Deref<Target = ConcreteThing>, store
                                                                        // the reference type in `default_generics` which handles full
                                                                        // types and not just paths.
                                                                        if let syn::PathArguments::AngleBracketed(ref args) =
                                                                                        trait_bound.path.segments[0].arguments {
+                                                                               assert_eq!(trait_bound.path.segments.len(), 1);
                                                                                for subargument in args.args.iter() {
                                                                                        match subargument {
                                                                                                syn::GenericArgument::Lifetime(_) => {},
@@ -268,7 +269,8 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> {
                                                if p.qself.is_some() { return false; }
                                                if p.path.leading_colon.is_some() { return false; }
                                                let mut p_iter = p.path.segments.iter();
-                                               if let Some(gen) = new_typed_generics.get_mut(&p_iter.next().unwrap().ident) {
+                                               let p_ident = &p_iter.next().unwrap().ident;
+                                               if let Some(gen) = new_typed_generics.get_mut(p_ident) {
                                                        if gen.is_some() { return false; }
                                                        if &format!("{}", p_iter.next().unwrap().ident) != "Target" {return false; }
 
@@ -281,7 +283,14 @@ impl<'a, 'p: 'a> GenericTypes<'a, 'p> {
                                                                        if non_lifetimes_processed { return false; }
                                                                        non_lifetimes_processed = true;
                                                                        assert_simple_bound(&trait_bound);
-                                                                       *gen = Some(types.resolve_path(&trait_bound.path, None));
+                                                                       let resolved = types.resolve_path(&trait_bound.path, None);
+                                                                       let ty = syn::Type::Path(syn::TypePath {
+                                                                               qself: None, path: string_path_to_syn_path(&resolved)
+                                                                       });
+                                                                       let ref_ty = parse_quote!(&#ty);
+                                                                       self.default_generics.insert(p_ident, (ty, ref_ty));
+
+                                                                       *gen = Some(resolved);
                                                                }
                                                        }
                                                } else { return false; }
@@ -438,6 +447,10 @@ impl<'mod_lifetime, 'crate_lft: 'mod_lifetime> ImportResolver<'mod_lifetime, 'cr
                                        new_path = format!("{}::{}{}", crate_name, $ident, $path_suffix);
                                        let crate_name_ident = format_ident!("{}", crate_name);
                                        path.push(parse_quote!(#crate_name_ident));
+                               } else if format!("{}", $ident) == "self" {
+                                       let mut path_iter = partial_path.rsplitn(2, "::");
+                                       path_iter.next().unwrap();
+                                       new_path = path_iter.next().unwrap().to_owned();
                                } else {
                                        new_path = format!("{}{}{}", partial_path, $ident, $path_suffix);
                                }
@@ -452,7 +465,8 @@ impl<'mod_lifetime, 'crate_lft: 'mod_lifetime> ImportResolver<'mod_lifetime, 'cr
                        },
                        syn::UseTree::Name(n) => {
                                push_path!(n.ident, "");
-                               imports.insert(n.ident.clone(), (new_path, syn::Path { leading_colon: Some(syn::Token![::](Span::call_site())), segments: path }));
+                               let imported_ident = syn::Ident::new(new_path.rsplitn(2, "::").next().unwrap(), Span::call_site());
+                               imports.insert(imported_ident, (new_path, syn::Path { leading_colon: Some(syn::Token![::](Span::call_site())), segments: path }));
                        },
                        syn::UseTree::Group(g) => {
                                for i in g.items.iter() {
@@ -711,6 +725,10 @@ impl FullLibraryAST {
 fn initial_clonable_types() -> HashSet<String> {
        let mut res = HashSet::new();
        res.insert("crate::c_types::u5".to_owned());
+       res.insert("crate::c_types::FourBytes".to_owned());
+       res.insert("crate::c_types::TwelveBytes".to_owned());
+       res.insert("crate::c_types::SixteenBytes".to_owned());
+       res.insert("crate::c_types::TwentyBytes".to_owned());
        res.insert("crate::c_types::ThirtyTwoBytes".to_owned());
        res.insert("crate::c_types::SecretKey".to_owned());
        res.insert("crate::c_types::PublicKey".to_owned());
@@ -718,8 +736,17 @@ fn initial_clonable_types() -> HashSet<String> {
        res.insert("crate::c_types::TxOut".to_owned());
        res.insert("crate::c_types::Signature".to_owned());
        res.insert("crate::c_types::RecoverableSignature".to_owned());
+       res.insert("crate::c_types::Bech32Error".to_owned());
        res.insert("crate::c_types::Secp256k1Error".to_owned());
        res.insert("crate::c_types::IOError".to_owned());
+       res.insert("crate::c_types::Error".to_owned());
+       res.insert("crate::c_types::Str".to_owned());
+
+       // Because some types are manually-mapped to CVec_u8Z we may end up checking if its clonable
+       // before we ever get to constructing the type fully via
+       // `write_c_mangled_container_path_intern` (which will add it here too), so we have to manually
+       // add it on startup.
+       res.insert("crate::c_types::derived::CVec_u8Z".to_owned());
        res
 }
 
@@ -827,7 +854,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
        /// Returns true we if can just skip passing this to C entirely
        fn no_arg_path_to_rust(&self, full_path: &str) -> &str {
                if full_path == "bitcoin::secp256k1::Secp256k1" {
-                       "secp256k1::SECP256K1"
+                       "secp256k1::global::SECP256K1"
                } else { unimplemented!(); }
        }
 
@@ -874,7 +901,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
 
                        "std::time::Duration"|"core::time::Duration" => Some("u64"),
                        "std::time::SystemTime" => Some("u64"),
-                       "std::io::Error" => Some("crate::c_types::IOError"),
+                       "std::io::Error"|"lightning::io::Error" => Some("crate::c_types::IOError"),
                        "core::fmt::Arguments" if is_ref => Some("crate::c_types::Str"),
 
                        "core::convert::Infallible" => Some("crate::c_types::NotConstructable"),
@@ -890,7 +917,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "bitcoin::bech32::u5"|"bech32::u5" => Some("crate::c_types::u5"),
                        "core::num::NonZeroU8" => Some("u8"),
 
-                       "bitcoin::secp256k1::PublicKey" => Some("crate::c_types::PublicKey"),
+                       "secp256k1::PublicKey"|"bitcoin::secp256k1::PublicKey" => Some("crate::c_types::PublicKey"),
                        "bitcoin::secp256k1::ecdsa::Signature" => Some("crate::c_types::Signature"),
                        "bitcoin::secp256k1::ecdsa::RecoverableSignature" => Some("crate::c_types::RecoverableSignature"),
                        "bitcoin::secp256k1::SecretKey" if is_ref  => Some("*const [u8; 32]"),
@@ -901,6 +928,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "bitcoin::blockdata::transaction::Transaction"|"bitcoin::Transaction" => Some("crate::c_types::Transaction"),
                        "bitcoin::blockdata::transaction::TxOut" if !is_ref => Some("crate::c_types::TxOut"),
                        "bitcoin::network::constants::Network" => Some("crate::bitcoin::network::Network"),
+                       "bitcoin::util::address::WitnessVersion" => Some("crate::c_types::WitnessVersion"),
                        "bitcoin::blockdata::block::BlockHeader" if is_ref  => Some("*const [u8; 80]"),
                        "bitcoin::blockdata::block::Block" if is_ref  => Some("crate::c_types::u8slice"),
 
@@ -954,7 +982,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
 
                        "str" if is_ref => Some(""),
                        "alloc::string::String"|"String" => Some(""),
-                       "std::io::Error" if !is_ref => Some(""),
+                       "std::io::Error"|"lightning::io::Error" => Some(""),
                        // Note that we'll panic for String if is_ref, as we only have non-owned memory, we
                        // cannot create a &String.
 
@@ -986,6 +1014,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "bitcoin::blockdata::transaction::OutPoint" => Some("crate::c_types::C_to_bitcoin_outpoint("),
                        "bitcoin::blockdata::transaction::TxOut" if !is_ref => Some(""),
                        "bitcoin::network::constants::Network" => Some(""),
+                       "bitcoin::util::address::WitnessVersion" => Some(""),
                        "bitcoin::blockdata::block::BlockHeader" => Some("&::bitcoin::consensus::encode::deserialize(unsafe { &*"),
                        "bitcoin::blockdata::block::Block" if is_ref => Some("&::bitcoin::consensus::encode::deserialize("),
 
@@ -1040,7 +1069,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
 
                        "str" if is_ref => Some(".into_str()"),
                        "alloc::string::String"|"String" => Some(".into_string()"),
-                       "std::io::Error" if !is_ref => Some(".to_rust()"),
+                       "std::io::Error"|"lightning::io::Error" => Some(".to_rust()"),
 
                        "core::convert::Infallible" => Some("\")"),
 
@@ -1067,6 +1096,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "bitcoin::blockdata::transaction::OutPoint" => Some(")"),
                        "bitcoin::blockdata::transaction::TxOut" if !is_ref => Some(".into_rust()"),
                        "bitcoin::network::constants::Network" => Some(".into_bitcoin()"),
+                       "bitcoin::util::address::WitnessVersion" => Some(".into()"),
                        "bitcoin::blockdata::block::BlockHeader" => Some(" }).unwrap()"),
                        "bitcoin::blockdata::block::Block" => Some(".to_slice()).unwrap()"),
 
@@ -1132,7 +1162,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
 
                        "std::time::Duration"|"core::time::Duration" => Some(""),
                        "std::time::SystemTime" => Some(""),
-                       "std::io::Error" if !is_ref => Some("crate::c_types::IOError::from_rust("),
+                       "std::io::Error"|"lightning::io::Error" => Some("crate::c_types::IOError::from_rust("),
                        "core::fmt::Arguments" => Some("alloc::format!(\"{}\", "),
 
                        "core::convert::Infallible" => Some("panic!(\"Cannot construct an Infallible: "),
@@ -1159,6 +1189,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "bitcoin::blockdata::transaction::OutPoint" => Some("crate::c_types::bitcoin_to_C_outpoint("),
                        "bitcoin::blockdata::transaction::TxOut" if !is_ref => Some("crate::c_types::TxOut::from_rust("),
                        "bitcoin::network::constants::Network" => Some("crate::bitcoin::network::Network::from_bitcoin("),
+                       "bitcoin::util::address::WitnessVersion" => Some(""),
                        "bitcoin::blockdata::block::BlockHeader" if is_ref => Some("&local_"),
                        "bitcoin::blockdata::block::Block" if is_ref => Some("crate::c_types::u8slice::from_slice(&local_"),
 
@@ -1208,7 +1239,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
 
                        "std::time::Duration"|"core::time::Duration" => Some(".as_secs()"),
                        "std::time::SystemTime" => Some(".duration_since(::std::time::SystemTime::UNIX_EPOCH).expect(\"Times must be post-1970\").as_secs()"),
-                       "std::io::Error" if !is_ref => Some(")"),
+                       "std::io::Error"|"lightning::io::Error" => Some(")"),
                        "core::fmt::Arguments" => Some(").into()"),
 
                        "core::convert::Infallible" => Some("\")"),
@@ -1234,6 +1265,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        "bitcoin::blockdata::transaction::OutPoint" => Some(")"),
                        "bitcoin::blockdata::transaction::TxOut" if !is_ref => Some(")"),
                        "bitcoin::network::constants::Network" => Some(")"),
+                       "bitcoin::util::address::WitnessVersion" => Some(".into()"),
                        "bitcoin::blockdata::block::BlockHeader" if is_ref => Some(""),
                        "bitcoin::blockdata::block::Block" if is_ref => Some(")"),
 
@@ -1485,7 +1517,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
 
                                if let Some(t) = single_contained {
                                        match t {
-                                               syn::Type::Reference(_)|syn::Type::Path(_)|syn::Type::Slice(_) => {
+                                               syn::Type::Reference(_)|syn::Type::Path(_)|syn::Type::Slice(_)|syn::Type::Array(_) => {
                                                        let mut v = Vec::new();
                                                        let ret_ref = self.write_empty_rust_val_check_suffix(generics, &mut v, t);
                                                        let s = String::from_utf8(v).unwrap();
@@ -1781,7 +1813,6 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        syn::Type::Path(p) => {
                                let resolved = self.resolve_path(&p.path, generics);
                                if let Some(arr_ty) = self.is_real_type_array(&resolved) {
-                                       write!(w, ".data").unwrap();
                                        return self.write_empty_rust_val_check_suffix(generics, w, &arr_ty);
                                }
                                if self.crate_types.opaques.get(&resolved).is_some() {
@@ -1801,7 +1832,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
                        syn::Type::Array(a) => {
                                if let syn::Expr::Lit(l) = &a.len {
                                        if let syn::Lit::Int(i) = &l.lit {
-                                               write!(w, " == [0; {}]", i.base10_digits()).unwrap();
+                                               write!(w, ".data == [0; {}]", i.base10_digits()).unwrap();
                                                EmptyValExpectedTy::NonPointer
                                        } else { unimplemented!(); }
                                } else { unimplemented!(); }