Skip to content

Commit 5b9c408

Browse files
committed
Support deriving on more types
1 parent 48f1cfc commit 5b9c408

File tree

1 file changed

+72
-64
lines changed

1 file changed

+72
-64
lines changed

chalk-derive/src/lib.rs

Lines changed: 72 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@ extern crate proc_macro;
33
use proc_macro2::{Span, TokenStream};
44
use quote::quote;
55
use quote::ToTokens;
6-
use syn::{parse_quote, DeriveInput, GenericParam, Ident, TypeParamBound};
6+
use syn::{parse_quote, DeriveInput, Ident, TypeParam, TypeParamBound};
77

88
use synstructure::decl_derive;
99

1010
/// Checks whether a generic parameter has a `: HasInterner` bound
11-
fn has_interner(param: &GenericParam) -> Option<&Ident> {
11+
fn has_interner(param: &TypeParam) -> Option<&Ident> {
1212
bounded_by_trait(param, "HasInterner")
1313
}
1414

1515
/// Checks whether a generic parameter has a `: Interner` bound
16-
fn is_interner(param: &GenericParam) -> Option<&Ident> {
16+
fn is_interner(param: &TypeParam) -> Option<&Ident> {
1717
bounded_by_trait(param, "Interner")
1818
}
1919

@@ -28,48 +28,44 @@ fn has_interner_attr(input: &DeriveInput) -> Option<TokenStream> {
2828
)
2929
}
3030

31-
fn bounded_by_trait<'p>(param: &'p GenericParam, name: &str) -> Option<&'p Ident> {
31+
fn bounded_by_trait<'p>(param: &'p TypeParam, name: &str) -> Option<&'p Ident> {
3232
let name = Some(String::from(name));
33-
match param {
34-
GenericParam::Type(ref t) => t.bounds.iter().find_map(|b| {
35-
if let TypeParamBound::Trait(trait_bound) = b {
36-
if trait_bound
37-
.path
38-
.segments
39-
.last()
40-
.map(|s| s.ident.to_string())
41-
== name
42-
{
43-
return Some(&t.ident);
44-
}
33+
param.bounds.iter().find_map(|b| {
34+
if let TypeParamBound::Trait(trait_bound) = b {
35+
if trait_bound
36+
.path
37+
.segments
38+
.last()
39+
.map(|s| s.ident.to_string())
40+
== name
41+
{
42+
return Some(&param.ident);
4543
}
46-
None
47-
}),
48-
_ => None,
49-
}
44+
}
45+
None
46+
})
5047
}
5148

52-
fn get_generic_param(input: &DeriveInput) -> &GenericParam {
53-
match input.generics.params.len() {
54-
1 => {}
49+
fn get_intern_param(input: &DeriveInput) -> Option<(DeriveKind, &Ident)> {
50+
let mut params = input.generics.type_params().filter_map(|param| {
51+
has_interner(param)
52+
.map(|ident| (DeriveKind::FromHasInterner, ident))
53+
.or_else(|| is_interner(param).map(|ident| (DeriveKind::FromInterner, ident)))
54+
});
5555

56-
0 => panic!(
57-
"deriving this trait requires a single type parameter or a `#[has_interner]` attr"
58-
),
56+
let param = params.next();
57+
assert!(params.next().is_none(), "deriving this trait only works with at most one type parameter that implements HasInterner or Interner");
5958

60-
_ => panic!("deriving this trait only works with a single type parameter"),
61-
};
62-
&input.generics.params[0]
59+
param
6360
}
6461

65-
fn get_generic_param_name(input: &DeriveInput) -> Option<&Ident> {
66-
match get_generic_param(input) {
67-
GenericParam::Type(t) => Some(&t.ident),
68-
_ => None,
69-
}
62+
fn get_intern_param_name(input: &DeriveInput) -> &Ident {
63+
get_intern_param(input)
64+
.expect("deriving this trait requires a parameter that implements HasInterner or Interner")
65+
.1
7066
}
7167

72-
fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
68+
fn try_find_interner(s: &mut synstructure::Structure) -> Option<(TokenStream, DeriveKind)> {
7369
let input = s.ast();
7470

7571
if let Some(arg) = has_interner_attr(input) {
@@ -79,35 +75,40 @@ fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
7975
// struct S {
8076
//
8177
// }
82-
return (arg, DeriveKind::FromHasInternerAttr);
78+
return Some((arg, DeriveKind::FromHasInternerAttr));
8379
}
8480

85-
let generic_param0 = get_generic_param(input);
86-
87-
if let Some(param) = has_interner(generic_param0) {
88-
// HasInterner bound:
89-
//
90-
// Example:
91-
//
92-
// struct Binders<T: HasInterner> { }
93-
s.add_impl_generic(parse_quote! { _I });
94-
95-
s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
96-
s.add_where_predicate(
97-
parse_quote! { #param: ::chalk_ir::interner::HasInterner<Interner = _I> },
98-
);
81+
get_intern_param(input).map(|generic_param0| match generic_param0 {
82+
(DeriveKind::FromHasInterner, param) => {
83+
// HasInterner bound:
84+
//
85+
// Example:
86+
//
87+
// struct Binders<T: HasInterner> { }
88+
s.add_impl_generic(parse_quote! { _I });
89+
90+
s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
91+
s.add_where_predicate(
92+
parse_quote! { #param: ::chalk_ir::interner::HasInterner<Interner = _I> },
93+
);
94+
95+
(quote! { _I }, DeriveKind::FromHasInterner)
96+
}
97+
(DeriveKind::FromInterner, i) => {
98+
// Interner bound:
99+
//
100+
// Example:
101+
//
102+
// struct Foo<I: Interner> { }
103+
(quote! { #i }, DeriveKind::FromInterner)
104+
}
105+
_ => unreachable!(),
106+
})
107+
}
99108

100-
(quote! { _I }, DeriveKind::FromHasInterner)
101-
} else if let Some(i) = is_interner(generic_param0) {
102-
// Interner bound:
103-
//
104-
// Example:
105-
//
106-
// struct Foo<I: Interner> { }
107-
(quote! { #i }, DeriveKind::FromInterner)
108-
} else {
109-
panic!("deriving this trait requires a parameter that implements HasInterner or Interner",);
110-
}
109+
fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
110+
try_find_interner(s)
111+
.expect("deriving this trait requires a `#[has_interner]` attr or a parameter that implements HasInterner or Interner")
111112
}
112113

113114
#[derive(Copy, Clone, PartialEq)]
@@ -174,7 +175,7 @@ fn derive_any_type_visitable(
174175
});
175176

176177
if kind == DeriveKind::FromHasInterner {
177-
let param = get_generic_param_name(input).unwrap();
178+
let param = get_intern_param_name(input);
178179
s.add_where_predicate(parse_quote! { #param: ::chalk_ir::visit::TypeVisitable<#interner> });
179180
}
180181

@@ -278,7 +279,7 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
278279
let input = s.ast();
279280

280281
if kind == DeriveKind::FromHasInterner {
281-
let param = get_generic_param_name(input).unwrap();
282+
let param = get_intern_param_name(input);
282283
s.add_where_predicate(parse_quote! { #param: ::chalk_ir::fold::TypeFoldable<#interner> });
283284
};
284285

@@ -298,7 +299,14 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
298299
}
299300

300301
fn derive_fallible_type_folder(mut s: synstructure::Structure) -> TokenStream {
301-
let (interner, _) = find_interner(&mut s);
302+
let interner = try_find_interner(&mut s).map_or_else(
303+
|| {
304+
s.add_impl_generic(parse_quote! { _I });
305+
s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
306+
quote! { _I }
307+
},
308+
|(interner, _)| interner,
309+
);
302310
s.underscore_const(true);
303311
s.unbound_impl(
304312
quote!(::chalk_ir::fold::FallibleTypeFolder<#interner>),

0 commit comments

Comments
 (0)