1
1
use crate :: lifetime:: CollectLifetimes ;
2
2
use crate :: parse:: Item ;
3
- use crate :: receiver:: ReplaceReceiver ;
3
+ use crate :: receiver:: { has_self_in_block , has_self_in_sig , ReplaceReceiver } ;
4
4
use proc_macro2:: { Span , TokenStream } ;
5
5
use quote:: { quote, ToTokens } ;
6
6
use std:: mem;
7
7
use syn:: punctuated:: Punctuated ;
8
8
use syn:: visit_mut:: VisitMut ;
9
9
use syn:: {
10
10
parse_quote, ArgCaptured , ArgSelfRef , Block , FnArg , GenericParam , Generics , Ident , ImplItem ,
11
- Lifetime , MethodSig , Pat , PatIdent , Path , ReturnType , Token , TraitItem , Type , TypeParamBound ,
12
- WhereClause ,
11
+ Lifetime , MethodSig , Pat , PatIdent , Path , ReturnType , Token , TraitItem , Type , TypeParam ,
12
+ TypeParamBound , WhereClause ,
13
13
} ;
14
14
15
15
impl ToTokens for Item {
@@ -47,12 +47,16 @@ pub fn expand(input: &mut Item) {
47
47
} ;
48
48
for inner in & mut input. items {
49
49
if let TraitItem :: Method ( method) = inner {
50
- if method. sig . asyncness . is_some ( ) {
51
- if let Some ( block) = & mut method. default {
52
- transform_block ( context, & method. sig , block) ;
50
+ let sig = & mut method. sig ;
51
+ if sig. asyncness . is_some ( ) {
52
+ let block = & mut method. default ;
53
+ let mut has_self = has_self_in_sig ( sig) ;
54
+ if let Some ( block) = block {
55
+ has_self |= has_self_in_block ( block) ;
56
+ transform_block ( context, sig, block, has_self) ;
53
57
}
54
58
let has_default = method. default . is_some ( ) ;
55
- transform_sig ( context, & mut method . sig , has_default) ;
59
+ transform_sig ( context, sig, has_self , has_default) ;
56
60
}
57
61
}
58
62
}
@@ -65,9 +69,12 @@ pub fn expand(input: &mut Item) {
65
69
} ;
66
70
for inner in & mut input. items {
67
71
if let ImplItem :: Method ( method) = inner {
68
- if method. sig . asyncness . is_some ( ) {
69
- transform_block ( context, & method. sig , & mut method. block ) ;
70
- transform_sig ( context, & mut method. sig , false ) ;
72
+ let sig = & mut method. sig ;
73
+ if sig. asyncness . is_some ( ) {
74
+ let block = & mut method. block ;
75
+ let has_self = has_self_in_sig ( sig) || has_self_in_block ( block) ;
76
+ transform_block ( context, sig, block, has_self) ;
77
+ transform_sig ( context, sig, has_self, false ) ;
71
78
}
72
79
}
73
80
}
@@ -88,19 +95,14 @@ pub fn expand(input: &mut Item) {
88
95
// 'life1: 'async_trait,
89
96
// T: 'async_trait,
90
97
// Self: Sync + 'async_trait;
91
- fn transform_sig ( context : Context , sig : & mut MethodSig , has_default : bool ) {
98
+ fn transform_sig ( context : Context , sig : & mut MethodSig , has_self : bool , has_default : bool ) {
92
99
sig. decl . fn_token . span = sig. asyncness . take ( ) . unwrap ( ) . span ;
93
100
94
101
let ret = match & sig. decl . output {
95
102
ReturnType :: Default => quote ! ( ( ) ) ,
96
103
ReturnType :: Type ( _, ret) => quote ! ( #ret) ,
97
104
} ;
98
105
99
- let has_self = match sig. decl . inputs . iter_mut ( ) . next ( ) {
100
- Some ( FnArg :: SelfRef ( _) ) | Some ( FnArg :: SelfValue ( _) ) => true ,
101
- _ => false ,
102
- } ;
103
-
104
106
let mut elided = CollectLifetimes :: new ( ) ;
105
107
for arg in sig. decl . inputs . iter_mut ( ) {
106
108
match arg {
@@ -146,10 +148,10 @@ fn transform_sig(context: Context, sig: &mut MethodSig, has_default: bool) {
146
148
}
147
149
sig. decl . generics . params . push ( parse_quote ! ( #lifetime) ) ;
148
150
if has_self {
149
- let bound: Ident = match & sig. decl . inputs [ 0 ] {
150
- FnArg :: SelfRef ( ArgSelfRef {
151
+ let bound: Ident = match sig. decl . inputs . iter ( ) . next ( ) {
152
+ Some ( FnArg :: SelfRef ( ArgSelfRef {
151
153
mutability : None , ..
152
- } ) => parse_quote ! ( Sync ) ,
154
+ } ) ) => parse_quote ! ( Sync ) ,
153
155
_ => parse_quote ! ( Send ) ,
154
156
} ;
155
157
let assume_bound = match context {
@@ -204,7 +206,7 @@ fn transform_sig(context: Context, sig: &mut MethodSig, has_default: bool) {
204
206
// _self + x
205
207
// }
206
208
// Pin::from(Box::new(async_trait_method::<T, Self>(self, x)))
207
- fn transform_block ( context : Context , sig : & MethodSig , block : & mut Block ) {
209
+ fn transform_block ( context : Context , sig : & mut MethodSig , block : & mut Block , has_self : bool ) {
208
210
let inner = Ident :: new ( & format ! ( "__{}" , sig. ident) , sig. ident . span ( ) ) ;
209
211
let args = sig
210
212
. decl
@@ -251,6 +253,7 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
251
253
. map ( |param| param. ident . clone ( ) )
252
254
. collect :: < Vec < _ > > ( ) ;
253
255
256
+ let mut self_bound = None :: < TypeParamBound > ;
254
257
match standalone. decl . inputs . iter_mut ( ) . next ( ) {
255
258
Some ( arg @ FnArg :: SelfRef ( _) ) => {
256
259
let ( lifetime, mutability) = match arg {
@@ -262,19 +265,14 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
262
265
_ => unreachable ! ( ) ,
263
266
} ;
264
267
match context {
265
- Context :: Trait { name , generics , .. } => {
266
- let bound = match mutability {
267
- Some ( _) => quote ! ( Send ) ,
268
- None => quote ! ( Sync ) ,
269
- } ;
268
+ Context :: Trait { .. } => {
269
+ self_bound = Some ( match mutability {
270
+ Some ( _) => parse_quote ! ( core :: marker :: Send ) ,
271
+ None => parse_quote ! ( core :: marker :: Sync ) ,
272
+ } ) ;
270
273
* arg = parse_quote ! {
271
274
_self: & #lifetime #mutability AsyncTrait
272
275
} ;
273
- let ( _, generics, _) = generics. split_for_impl ( ) ;
274
- standalone. decl . generics . params . push ( parse_quote ! {
275
- AsyncTrait : ?Sized + #name #generics + core:: marker:: #bound
276
- } ) ;
277
- types. push ( Ident :: new ( "Self" , Span :: call_site ( ) ) ) ;
278
276
}
279
277
Context :: Impl { receiver, .. } => {
280
278
* arg = parse_quote ! {
@@ -284,15 +282,11 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
284
282
}
285
283
}
286
284
Some ( arg @ FnArg :: SelfValue ( _) ) => match context {
287
- Context :: Trait { name, generics, .. } => {
285
+ Context :: Trait { .. } => {
286
+ self_bound = Some ( parse_quote ! ( core:: marker:: Send ) ) ;
288
287
* arg = parse_quote ! {
289
288
_self: AsyncTrait
290
289
} ;
291
- let ( _, generics, _) = generics. split_for_impl ( ) ;
292
- standalone. decl . generics . params . push ( parse_quote ! {
293
- AsyncTrait : ?Sized + #name #generics + core:: marker:: Send
294
- } ) ;
295
- types. push ( Ident :: new ( "Self" , Span :: call_site ( ) ) ) ;
296
290
}
297
291
Context :: Impl { receiver, .. } => {
298
292
* arg = parse_quote ! {
@@ -303,6 +297,20 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
303
297
_ => { }
304
298
}
305
299
300
+ if let Context :: Trait { name, generics, .. } = context {
301
+ if has_self {
302
+ let ( _, generics, _) = generics. split_for_impl ( ) ;
303
+ let mut self_param: TypeParam = parse_quote ! ( AsyncTrait : ?Sized + #name #generics) ;
304
+ self_param. bounds . extend ( self_bound) ;
305
+ standalone
306
+ . decl
307
+ . generics
308
+ . params
309
+ . push ( GenericParam :: Type ( self_param) ) ;
310
+ types. push ( Ident :: new ( "Self" , Span :: call_site ( ) ) ) ;
311
+ }
312
+ }
313
+
306
314
if let Some ( where_clause) = & mut standalone. decl . generics . where_clause {
307
315
// Work around an input bound like `where Self::Output: Send` expanding
308
316
// to `where <AsyncTrait>::Output: Send` which is illegal syntax because
0 commit comments