@@ -3,17 +3,17 @@ extern crate proc_macro;
3
3
use proc_macro2:: { Span , TokenStream } ;
4
4
use quote:: quote;
5
5
use quote:: ToTokens ;
6
- use syn:: { parse_quote, DeriveInput , GenericParam , Ident , TypeParamBound } ;
6
+ use syn:: { parse_quote, DeriveInput , Ident , TypeParam , TypeParamBound } ;
7
7
8
8
use synstructure:: decl_derive;
9
9
10
10
/// Checks whether a generic parameter has a `: HasInterner` bound
11
- fn has_interner ( param : & GenericParam ) -> Option < & Ident > {
11
+ fn has_interner ( param : & TypeParam ) -> Option < & Ident > {
12
12
bounded_by_trait ( param, "HasInterner" )
13
13
}
14
14
15
15
/// Checks whether a generic parameter has a `: Interner` bound
16
- fn is_interner ( param : & GenericParam ) -> Option < & Ident > {
16
+ fn is_interner ( param : & TypeParam ) -> Option < & Ident > {
17
17
bounded_by_trait ( param, "Interner" )
18
18
}
19
19
@@ -28,48 +28,44 @@ fn has_interner_attr(input: &DeriveInput) -> Option<TokenStream> {
28
28
)
29
29
}
30
30
31
- fn bounded_by_trait < ' p > ( param : & ' p GenericParam , name : & str ) -> Option < & ' p Ident > {
31
+ fn bounded_by_trait < ' p > ( param : & ' p TypeParam , name : & str ) -> Option < & ' p Ident > {
32
32
let name = Some ( String :: from ( name) ) ;
33
- match param {
34
- GenericParam :: Type ( ref t) => t. bounds . iter ( ) . find_map ( |b| {
35
- if let TypeParamBound :: Trait ( trait_bound) = b {
36
- if trait_bound
37
- . path
38
- . segments
39
- . last ( )
40
- . map ( |s| s. ident . to_string ( ) )
41
- == name
42
- {
43
- return Some ( & t. ident ) ;
44
- }
33
+ param. bounds . iter ( ) . find_map ( |b| {
34
+ if let TypeParamBound :: Trait ( trait_bound) = b {
35
+ if trait_bound
36
+ . path
37
+ . segments
38
+ . last ( )
39
+ . map ( |s| s. ident . to_string ( ) )
40
+ == name
41
+ {
42
+ return Some ( & param. ident ) ;
45
43
}
46
- None
47
- } ) ,
48
- _ => None ,
49
- }
44
+ }
45
+ None
46
+ } )
50
47
}
51
48
52
- fn get_generic_param ( input : & DeriveInput ) -> & GenericParam {
53
- match input. generics . params . len ( ) {
54
- 1 => { }
49
+ fn get_intern_param ( input : & DeriveInput ) -> Option < ( DeriveKind , & Ident ) > {
50
+ let mut params = input. generics . type_params ( ) . filter_map ( |param| {
51
+ has_interner ( param)
52
+ . map ( |ident| ( DeriveKind :: FromHasInterner , ident) )
53
+ . or_else ( || is_interner ( param) . map ( |ident| ( DeriveKind :: FromInterner , ident) ) )
54
+ } ) ;
55
55
56
- 0 => panic ! (
57
- "deriving this trait requires a single type parameter or a `#[has_interner]` attr"
58
- ) ,
56
+ let param = params. next ( ) ;
57
+ assert ! ( params. next( ) . is_none( ) , "deriving this trait only works with at most one type parameter that implements HasInterner or Interner" ) ;
59
58
60
- _ => panic ! ( "deriving this trait only works with a single type parameter" ) ,
61
- } ;
62
- & input. generics . params [ 0 ]
59
+ param
63
60
}
64
61
65
- fn get_generic_param_name ( input : & DeriveInput ) -> Option < & Ident > {
66
- match get_generic_param ( input) {
67
- GenericParam :: Type ( t) => Some ( & t. ident ) ,
68
- _ => None ,
69
- }
62
+ fn get_intern_param_name ( input : & DeriveInput ) -> & Ident {
63
+ get_intern_param ( input)
64
+ . expect ( "deriving this trait requires a parameter that implements HasInterner or Interner" )
65
+ . 1
70
66
}
71
67
72
- fn find_interner ( s : & mut synstructure:: Structure ) -> ( TokenStream , DeriveKind ) {
68
+ fn try_find_interner ( s : & mut synstructure:: Structure ) -> Option < ( TokenStream , DeriveKind ) > {
73
69
let input = s. ast ( ) ;
74
70
75
71
if let Some ( arg) = has_interner_attr ( input) {
@@ -79,35 +75,40 @@ fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
79
75
// struct S {
80
76
//
81
77
// }
82
- return ( arg, DeriveKind :: FromHasInternerAttr ) ;
78
+ return Some ( ( arg, DeriveKind :: FromHasInternerAttr ) ) ;
83
79
}
84
80
85
- let generic_param0 = get_generic_param ( input) ;
86
-
87
- if let Some ( param) = has_interner ( generic_param0) {
88
- // HasInterner bound:
89
- //
90
- // Example:
91
- //
92
- // struct Binders<T: HasInterner> { }
93
- s. add_impl_generic ( parse_quote ! { _I } ) ;
94
-
95
- s. add_where_predicate ( parse_quote ! { _I: :: chalk_ir:: interner:: Interner } ) ;
96
- s. add_where_predicate (
97
- parse_quote ! { #param: :: chalk_ir:: interner:: HasInterner <Interner = _I> } ,
98
- ) ;
81
+ get_intern_param ( input) . map ( |generic_param0| match generic_param0 {
82
+ ( DeriveKind :: FromHasInterner , param) => {
83
+ // HasInterner bound:
84
+ //
85
+ // Example:
86
+ //
87
+ // struct Binders<T: HasInterner> { }
88
+ s. add_impl_generic ( parse_quote ! { _I } ) ;
89
+
90
+ s. add_where_predicate ( parse_quote ! { _I: :: chalk_ir:: interner:: Interner } ) ;
91
+ s. add_where_predicate (
92
+ parse_quote ! { #param: :: chalk_ir:: interner:: HasInterner <Interner = _I> } ,
93
+ ) ;
94
+
95
+ ( quote ! { _I } , DeriveKind :: FromHasInterner )
96
+ }
97
+ ( DeriveKind :: FromInterner , i) => {
98
+ // Interner bound:
99
+ //
100
+ // Example:
101
+ //
102
+ // struct Foo<I: Interner> { }
103
+ ( quote ! { #i } , DeriveKind :: FromInterner )
104
+ }
105
+ _ => unreachable ! ( ) ,
106
+ } )
107
+ }
99
108
100
- ( quote ! { _I } , DeriveKind :: FromHasInterner )
101
- } else if let Some ( i) = is_interner ( generic_param0) {
102
- // Interner bound:
103
- //
104
- // Example:
105
- //
106
- // struct Foo<I: Interner> { }
107
- ( quote ! { #i } , DeriveKind :: FromInterner )
108
- } else {
109
- panic ! ( "deriving this trait requires a parameter that implements HasInterner or Interner" , ) ;
110
- }
109
+ fn find_interner ( s : & mut synstructure:: Structure ) -> ( TokenStream , DeriveKind ) {
110
+ try_find_interner ( s)
111
+ . expect ( "deriving this trait requires a `#[has_interner]` attr or a parameter that implements HasInterner or Interner" )
111
112
}
112
113
113
114
#[ derive( Copy , Clone , PartialEq ) ]
@@ -174,7 +175,7 @@ fn derive_any_type_visitable(
174
175
} ) ;
175
176
176
177
if kind == DeriveKind :: FromHasInterner {
177
- let param = get_generic_param_name ( input) . unwrap ( ) ;
178
+ let param = get_intern_param_name ( input) ;
178
179
s. add_where_predicate ( parse_quote ! { #param: :: chalk_ir:: visit:: TypeVisitable <#interner> } ) ;
179
180
}
180
181
@@ -278,7 +279,7 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
278
279
let input = s. ast ( ) ;
279
280
280
281
if kind == DeriveKind :: FromHasInterner {
281
- let param = get_generic_param_name ( input) . unwrap ( ) ;
282
+ let param = get_intern_param_name ( input) ;
282
283
s. add_where_predicate ( parse_quote ! { #param: :: chalk_ir:: fold:: TypeFoldable <#interner> } ) ;
283
284
} ;
284
285
@@ -298,7 +299,14 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
298
299
}
299
300
300
301
fn derive_fallible_type_folder ( mut s : synstructure:: Structure ) -> TokenStream {
301
- let ( interner, _) = find_interner ( & mut s) ;
302
+ let interner = try_find_interner ( & mut s) . map_or_else (
303
+ || {
304
+ s. add_impl_generic ( parse_quote ! { _I } ) ;
305
+ s. add_where_predicate ( parse_quote ! { _I: :: chalk_ir:: interner:: Interner } ) ;
306
+ quote ! { _I }
307
+ } ,
308
+ |( interner, _) | interner,
309
+ ) ;
302
310
s. underscore_const ( true ) ;
303
311
s. unbound_impl (
304
312
quote ! ( :: chalk_ir:: fold:: FallibleTypeFolder <#interner>) ,
0 commit comments