1
- use proc_macro2:: TokenStream ;
1
+ use proc_macro2:: { Ident , Span , TokenStream } ;
2
2
use quote:: quote;
3
3
use syn:: punctuated:: Punctuated ;
4
- use syn:: token:: { Brace , Colon } ;
5
- use syn:: { FieldsNamed , FnArg , ItemFn } ;
4
+ use syn:: token:: { Colon , PathSep } ;
5
+ use syn:: {
6
+ ConstParam , FnArg , GenericParam , ItemFn , LifetimeParam , Pat , PatIdent , PatType , Path ,
7
+ PathSegment , Type , TypeParam , TypePath ,
8
+ } ;
6
9
7
10
#[ proc_macro_attribute]
8
11
pub fn with_simd (
@@ -40,20 +43,13 @@ pub fn with_simd(
40
43
block,
41
44
} = item. clone ( ) ;
42
45
43
- let mut struct_generics = Punctuated :: new ( ) ;
44
- let mut struct_generics_lifetimes = Vec :: new ( ) ;
45
- let mut struct_generics_names = Vec :: new ( ) ;
46
+ let mut struct_generics = Vec :: new ( ) ;
46
47
let mut struct_field_names = Vec :: new ( ) ;
47
- let mut struct_fields = FieldsNamed {
48
- brace_token : Brace {
49
- span : sig. paren_token . span ,
50
- } ,
51
- named : Punctuated :: new ( ) ,
52
- } ;
48
+ let mut struct_field_types = Vec :: new ( ) ;
53
49
54
50
let mut first_non_lifetime = usize:: MAX ;
55
51
for ( idx, param) in sig. generics . params . clone ( ) . into_pairs ( ) . enumerate ( ) {
56
- let ( param, comma ) = param. into_tuple ( ) ;
52
+ let ( param, _ ) = param. into_tuple ( ) ;
57
53
match & param {
58
54
syn:: GenericParam :: Lifetime ( _) => { }
59
55
_ => {
@@ -63,17 +59,7 @@ pub fn with_simd(
63
59
}
64
60
}
65
61
}
66
- match & param {
67
- syn:: GenericParam :: Type ( ty) => struct_generics_names. push ( ty. ident . clone ( ) ) ,
68
- syn:: GenericParam :: Lifetime ( lt) => struct_generics_lifetimes. push ( lt. lifetime . clone ( ) ) ,
69
- syn:: GenericParam :: Const ( const_) => struct_generics_names. push ( const_. ident . clone ( ) ) ,
70
- } ;
71
- struct_generics. push_value ( param) ;
72
- if let Some ( comma) = comma {
73
- struct_generics. push_punct ( comma) ;
74
- }
75
62
}
76
-
77
63
let mut new_fn_sig = sig. clone ( ) ;
78
64
new_fn_sig. generics . params = new_fn_sig
79
65
. generics
@@ -83,49 +69,61 @@ pub fn with_simd(
83
69
. filter ( |( idx, _) | * idx != first_non_lifetime)
84
70
. map ( |( _, arg) | arg)
85
71
. collect ( ) ;
86
- new_fn_sig. inputs = new_fn_sig. inputs . into_iter ( ) . skip ( 1 ) . collect ( ) ;
72
+ new_fn_sig. inputs = new_fn_sig
73
+ . inputs
74
+ . into_iter ( )
75
+ . skip ( 1 )
76
+ . enumerate ( )
77
+ . map ( |( idx, arg) | {
78
+ FnArg :: Typed ( PatType {
79
+ attrs : Vec :: new ( ) ,
80
+ pat : Box :: new ( Pat :: Ident ( PatIdent {
81
+ attrs : Vec :: new ( ) ,
82
+ by_ref : None ,
83
+ mutability : None ,
84
+ ident : Ident :: new ( & format ! ( "__{idx}" ) , Span :: call_site ( ) ) ,
85
+ subpat : None ,
86
+ } ) ) ,
87
+ colon_token : Colon {
88
+ spans : [ Span :: call_site ( ) ] ,
89
+ } ,
90
+ ty : match arg {
91
+ FnArg :: Typed ( ty) => ty. ty ,
92
+ FnArg :: Receiver ( _) => panic ! ( ) ,
93
+ } ,
94
+ } )
95
+ } )
96
+ . collect ( ) ;
87
97
new_fn_sig. ident = name. clone ( ) ;
98
+ let mut param_ty = Vec :: new ( ) ;
88
99
89
- for param in sig . inputs . clone ( ) . into_pairs ( ) . skip ( 1 ) {
90
- let ( param, comma ) = param. into_tuple ( ) ;
100
+ for ( idx , param) in new_fn_sig . inputs . clone ( ) . into_pairs ( ) . enumerate ( ) {
101
+ let ( param, _ ) = param. into_tuple ( ) ;
91
102
let FnArg :: Typed ( param) = param. clone ( ) else {
92
- return quote ! {
93
- :: core:: compile_error!( :: core:: concat!(
94
- "pulp::with_simd only accepts free functions"
95
- ) ) ;
96
- #item
97
- }
98
- . into ( ) ;
103
+ panic ! ( ) ;
99
104
} ;
100
-
101
105
let name = * param. pat ;
102
106
let syn:: Pat :: Ident ( name) = name else {
103
- return quote ! {
104
- :: core:: compile_error!( :: core:: concat!(
105
- "pulp::with_simd requires function parameters to be idents"
106
- ) ) ;
107
- #item
108
- }
109
- . into ( ) ;
107
+ panic ! ( ) ;
110
108
} ;
111
109
110
+ let anon_ty = Ident :: new ( & format ! ( "__T{idx}" ) , Span :: call_site ( ) ) ;
111
+
112
112
struct_field_names. push ( name. ident . clone ( ) ) ;
113
- let field = syn:: Field {
114
- attrs : param. attrs ,
115
- vis : syn:: Visibility :: Public ( syn:: token:: Pub {
116
- span : proc_macro2:: Span :: call_site ( ) ,
117
- } ) ,
118
- mutability : syn:: FieldMutability :: None ,
119
- ident : Some ( name. ident ) ,
120
- colon_token : Some ( Colon {
121
- spans : [ proc_macro2:: Span :: call_site ( ) ] ,
122
- } ) ,
123
- ty : * param. ty ,
124
- } ;
125
- struct_fields. named . push_value ( field) ;
126
- if let Some ( comma) = comma {
127
- struct_fields. named . push_punct ( comma) ;
128
- }
113
+ let mut ty = Punctuated :: < _ , PathSep > :: new ( ) ;
114
+ ty. push_value ( PathSegment {
115
+ ident : anon_ty. clone ( ) ,
116
+ arguments : syn:: PathArguments :: None ,
117
+ } ) ;
118
+ struct_field_types. push ( Type :: Path ( TypePath {
119
+ qself : None ,
120
+ path : Path {
121
+ leading_colon : None ,
122
+ segments : ty,
123
+ } ,
124
+ } ) ) ;
125
+ struct_generics. push ( anon_ty) ;
126
+ param_ty. push ( * param. ty ) ;
129
127
}
130
128
131
129
let output_ty = match sig. output . clone ( ) {
@@ -136,33 +134,79 @@ pub fn with_simd(
136
134
let fn_name = sig. ident . clone ( ) ;
137
135
138
136
let arch = attr. value ;
137
+ let new_fn_generics = new_fn_sig. generics . clone ( ) ;
138
+ let params = new_fn_generics. params . clone ( ) ;
139
+ let generics = params. into_iter ( ) . collect :: < Vec < _ > > ( ) ;
140
+ let non_lt_generics_names = generics
141
+ . iter ( )
142
+ . map ( |p| match p {
143
+ GenericParam :: Type ( TypeParam { ident, .. } )
144
+ | GenericParam :: Const ( ConstParam { ident, .. } ) => {
145
+ quote ! { #ident, }
146
+ }
147
+ _ => quote ! { } ,
148
+ } )
149
+ . collect :: < Vec < _ > > ( ) ;
150
+ let generics_decl = generics
151
+ . iter ( )
152
+ . map ( |p| match p {
153
+ GenericParam :: Lifetime ( LifetimeParam {
154
+ lifetime,
155
+ colon_token,
156
+ bounds,
157
+ ..
158
+ } ) => {
159
+ quote ! { #lifetime #colon_token #bounds }
160
+ }
161
+ GenericParam :: Type ( TypeParam {
162
+ ident,
163
+ colon_token,
164
+ bounds,
165
+ ..
166
+ } ) => {
167
+ quote ! { #ident #colon_token #bounds }
168
+ }
169
+ GenericParam :: Const ( ConstParam {
170
+ const_token,
171
+ ident,
172
+ colon_token,
173
+ ty,
174
+ ..
175
+ } ) => {
176
+ quote ! { #const_token #ident #colon_token #ty }
177
+ }
178
+ } )
179
+ . collect :: < Vec < _ > > ( ) ;
180
+ let generics_where_clause = new_fn_generics. where_clause ;
139
181
140
- quote ! {
182
+ let code = quote ! {
141
183
#( #attrs) *
142
184
#vis #new_fn_sig {
143
185
#[ allow( non_camel_case_types) ]
144
- struct #name<#struct_generics> #struct_fields
186
+ struct #name<#( # struct_generics, ) * > ( # ( #struct_field_types , ) * ) ;
145
187
146
- impl <#struct_generics> :: pulp:: WithSimd for #name<#( #struct_generics_lifetimes, ) *
147
- #( #struct_generics_names, ) * > { type Output = #output_ty;
188
+ impl <#( #generics_decl, ) * > :: pulp:: WithSimd for #name<
189
+ #( #param_ty, ) *
190
+ > #generics_where_clause {
191
+ type Output = #output_ty;
148
192
149
193
#[ inline( always) ]
150
- fn with_simd<__S: :: pulp:: Simd >( self , __simd: __S) -> <Self as
151
- :: pulp :: WithSimd > :: Output { let Self { #( #struct_field_names, ) * } = self ;
194
+ fn with_simd<__S: :: pulp:: Simd >( self , __simd: __S) -> <Self as :: pulp :: WithSimd > :: Output {
195
+ let Self ( #( #struct_field_names, ) * ) = self ;
152
196
#[ allow( unused_unsafe) ]
153
197
unsafe {
154
198
#fn_name:: <__S,
155
- #( #struct_generics_names , ) *
199
+ #( #non_lt_generics_names ) *
156
200
>( __simd, #( #struct_field_names, ) * )
157
201
}
158
202
}
159
203
}
160
204
161
- ( #arch) . dispatch( #name:: <# ( #struct_generics_names , ) * > { #( #struct_field_names, ) * } )
205
+ ( #arch) . dispatch( #name ( #( #struct_field_names, ) * ) )
162
206
}
163
207
164
208
#( #attrs) *
165
209
#vis #sig #block
166
- }
167
- . into ( )
210
+ } ;
211
+ code . into ( )
168
212
}
0 commit comments