Add function-call generic parameters when they're resovable
authorMatt Corallo <git@bluematt.me>
Wed, 2 Mar 2022 03:19:46 +0000 (03:19 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 2 Mar 2022 03:56:34 +0000 (03:56 +0000)
We really shouldn't ever actually need to do this - rustc should be
smart enough to figure out what we're doing given all the
parameters are concrete types, but this appears to be failing
sometimes.

c-bindings-gen/src/blocks.rs
c-bindings-gen/src/main.rs
c-bindings-gen/src/types.rs

index 5034b0902c1140fb0002f2098971578aee87aa5e..e97fb65e88e4c0a58dfcf82e959e48515aecbc9f 100644 (file)
@@ -745,27 +745,47 @@ pub fn write_method_call_params<W: std::io::Write>(w: &mut W, sig: &syn::Signatu
 pub fn maybe_write_generics<W: std::io::Write>(w: &mut W, generics: &syn::Generics, types: &TypeResolver, concrete_lifetimes: bool) {
        let mut gen_types = GenericTypes::new(None);
        assert!(gen_types.learn_generics(generics, types));
-       if !generics.params.is_empty() {
-               write!(w, "<").unwrap();
-               for (idx, generic) in generics.params.iter().enumerate() {
-                       match generic {
-                               syn::GenericParam::Type(type_param) => {
-                                       write!(w, "{}", if idx != 0 { ", " } else { "" }).unwrap();
-                                       let type_ident = &type_param.ident;
-                                       types.write_c_type_in_generic_param(w, &syn::parse_quote!(#type_ident), Some(&gen_types), false);
-                               },
-                               syn::GenericParam::Lifetime(lt) => {
-                                       if concrete_lifetimes {
-                                               write!(w, "'static").unwrap();
-                                       } else {
-                                               write!(w, "{}'{}", if idx != 0 { ", " } else { "" }, lt.lifetime.ident).unwrap();
+       if generics.params.is_empty() { return; }
+       for generic in generics.params.iter() {
+               match generic {
+                       syn::GenericParam::Type(type_param) => {
+                               for bound in type_param.bounds.iter() {
+                                       match bound {
+                                               syn::TypeParamBound::Trait(t) => {
+                                                       if let Some(trait_bound) = types.maybe_resolve_path(&t.path, None) {
+                                                               if types.skip_path(&trait_bound) {
+                                                                       // Just hope rust deduces generic params if some bounds are skipable.
+                                                                       return;
+                                                               }
+                                                       }
+                                               }
+                                               _ => {},
                                        }
-                               },
-                               _ => unimplemented!(),
+                               }
                        }
+                       _ => {},
+               }
+       }
+
+       write!(w, "<").unwrap();
+       for (idx, generic) in generics.params.iter().enumerate() {
+               match generic {
+                       syn::GenericParam::Type(type_param) => {
+                               write!(w, "{}", if idx != 0 { ", " } else { "" }).unwrap();
+                               let type_ident = &type_param.ident;
+                               types.write_c_type_in_generic_param(w, &syn::parse_quote!(#type_ident), Some(&gen_types), false);
+                       },
+                       syn::GenericParam::Lifetime(lt) => {
+                               if concrete_lifetimes {
+                                       write!(w, "'static").unwrap();
+                               } else {
+                                       write!(w, "{}'{}", if idx != 0 { ", " } else { "" }, lt.lifetime.ident).unwrap();
+                               }
+                       },
+                       _ => unimplemented!(),
                }
-               write!(w, ">").unwrap();
        }
+       write!(w, ">").unwrap();
 }
 
 pub fn maybe_write_lifetime_generics<W: std::io::Write>(w: &mut W, generics: &syn::Generics, types: &TypeResolver) {
index c6a74f66bcc5bbce40fc6cc05679e19aefb66d63..97f44a4c4c458e22ce017c0b8c53d8524077cea2 100644 (file)
@@ -1647,10 +1647,20 @@ fn writeln_fn<'a, 'b, W: std::io::Write>(w: &mut W, f: &'a syn::ItemFn, types: &
        writeln_fn_docs(w, &f.attrs, "", types, Some(&gen_types), f.sig.inputs.iter(), &f.sig.output);
 
        write!(w, "#[no_mangle]\npub extern \"C\" fn {}(", f.sig.ident).unwrap();
+
+
        write_method_params(w, &f.sig, "", types, Some(&gen_types), false, true);
        write!(w, " {{\n\t").unwrap();
        write_method_var_decl_body(w, &f.sig, "", types, Some(&gen_types), false);
-       write!(w, "{}::{}(", types.module_path, f.sig.ident).unwrap();
+       write!(w, "{}::{}", types.module_path, f.sig.ident).unwrap();
+
+       let mut function_generic_args = Vec::new();
+       maybe_write_generics(&mut function_generic_args, &f.sig.generics, types, true);
+       if !function_generic_args.is_empty() {
+               write!(w, "::{}", String::from_utf8(function_generic_args).unwrap()).unwrap();
+       }
+       write!(w, "(").unwrap();
+
        write_method_call_params(w, &f.sig, "", types, Some(&gen_types), "", false);
        writeln!(w, "\n}}\n").unwrap();
 }
index 21b0458ceae6d5bf24b7cb76c7bb44bdbdca3366..88d66403931c34b2792e932309e2dfa08f09b226 100644 (file)
@@ -827,7 +827,7 @@ impl<'a, 'c: 'a> TypeResolver<'a, 'c> {
        // *************************************************
 
        /// Returns true we if can just skip passing this to C entirely
-       fn skip_path(&self, full_path: &str) -> bool {
+       pub fn skip_path(&self, full_path: &str) -> bool {
                full_path == "bitcoin::secp256k1::Secp256k1" ||
                full_path == "bitcoin::secp256k1::Signing" ||
                full_path == "bitcoin::secp256k1::Verification"