Skip to content

Commit e29e385

Browse files
authored
Merge pull request #12 from dtolnay/self
Fix Self argument in trait method without self
2 parents a93ce20 + 81304f5 commit e29e385

File tree

3 files changed

+89
-37
lines changed

3 files changed

+89
-37
lines changed

src/expand.rs

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
use crate::lifetime::CollectLifetimes;
22
use crate::parse::Item;
3-
use crate::receiver::ReplaceReceiver;
3+
use crate::receiver::{has_self_in_block, has_self_in_sig, ReplaceReceiver};
44
use proc_macro2::{Span, TokenStream};
55
use quote::{quote, ToTokens};
66
use std::mem;
77
use syn::punctuated::Punctuated;
88
use syn::visit_mut::VisitMut;
99
use syn::{
1010
parse_quote, ArgCaptured, ArgSelfRef, Block, FnArg, GenericParam, Generics, Ident, ImplItem,
11-
Lifetime, MethodSig, Pat, PatIdent, Path, ReturnType, Token, TraitItem, Type, TypeParamBound,
12-
WhereClause,
11+
Lifetime, MethodSig, Pat, PatIdent, Path, ReturnType, Token, TraitItem, Type, TypeParam,
12+
TypeParamBound, WhereClause,
1313
};
1414

1515
impl ToTokens for Item {
@@ -47,12 +47,16 @@ pub fn expand(input: &mut Item) {
4747
};
4848
for inner in &mut input.items {
4949
if let TraitItem::Method(method) = inner {
50-
if method.sig.asyncness.is_some() {
51-
if let Some(block) = &mut method.default {
52-
transform_block(context, &method.sig, block);
50+
let sig = &mut method.sig;
51+
if sig.asyncness.is_some() {
52+
let block = &mut method.default;
53+
let mut has_self = has_self_in_sig(sig);
54+
if let Some(block) = block {
55+
has_self |= has_self_in_block(block);
56+
transform_block(context, sig, block, has_self);
5357
}
5458
let has_default = method.default.is_some();
55-
transform_sig(context, &mut method.sig, has_default);
59+
transform_sig(context, sig, has_self, has_default);
5660
}
5761
}
5862
}
@@ -65,9 +69,12 @@ pub fn expand(input: &mut Item) {
6569
};
6670
for inner in &mut input.items {
6771
if let ImplItem::Method(method) = inner {
68-
if method.sig.asyncness.is_some() {
69-
transform_block(context, &method.sig, &mut method.block);
70-
transform_sig(context, &mut method.sig, false);
72+
let sig = &mut method.sig;
73+
if sig.asyncness.is_some() {
74+
let block = &mut method.block;
75+
let has_self = has_self_in_sig(sig) || has_self_in_block(block);
76+
transform_block(context, sig, block, has_self);
77+
transform_sig(context, sig, has_self, false);
7178
}
7279
}
7380
}
@@ -88,19 +95,14 @@ pub fn expand(input: &mut Item) {
8895
// 'life1: 'async_trait,
8996
// T: 'async_trait,
9097
// Self: Sync + 'async_trait;
91-
fn transform_sig(context: Context, sig: &mut MethodSig, has_default: bool) {
98+
fn transform_sig(context: Context, sig: &mut MethodSig, has_self: bool, has_default: bool) {
9299
sig.decl.fn_token.span = sig.asyncness.take().unwrap().span;
93100

94101
let ret = match &sig.decl.output {
95102
ReturnType::Default => quote!(()),
96103
ReturnType::Type(_, ret) => quote!(#ret),
97104
};
98105

99-
let has_self = match sig.decl.inputs.iter_mut().next() {
100-
Some(FnArg::SelfRef(_)) | Some(FnArg::SelfValue(_)) => true,
101-
_ => false,
102-
};
103-
104106
let mut elided = CollectLifetimes::new();
105107
for arg in sig.decl.inputs.iter_mut() {
106108
match arg {
@@ -146,10 +148,10 @@ fn transform_sig(context: Context, sig: &mut MethodSig, has_default: bool) {
146148
}
147149
sig.decl.generics.params.push(parse_quote!(#lifetime));
148150
if has_self {
149-
let bound: Ident = match &sig.decl.inputs[0] {
150-
FnArg::SelfRef(ArgSelfRef {
151+
let bound: Ident = match sig.decl.inputs.iter().next() {
152+
Some(FnArg::SelfRef(ArgSelfRef {
151153
mutability: None, ..
152-
}) => parse_quote!(Sync),
154+
})) => parse_quote!(Sync),
153155
_ => parse_quote!(Send),
154156
};
155157
let assume_bound = match context {
@@ -204,7 +206,7 @@ fn transform_sig(context: Context, sig: &mut MethodSig, has_default: bool) {
204206
// _self + x
205207
// }
206208
// Pin::from(Box::new(async_trait_method::<T, Self>(self, x)))
207-
fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
209+
fn transform_block(context: Context, sig: &mut MethodSig, block: &mut Block, has_self: bool) {
208210
let inner = Ident::new(&format!("__{}", sig.ident), sig.ident.span());
209211
let args = sig
210212
.decl
@@ -251,6 +253,7 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
251253
.map(|param| param.ident.clone())
252254
.collect::<Vec<_>>();
253255

256+
let mut self_bound = None::<TypeParamBound>;
254257
match standalone.decl.inputs.iter_mut().next() {
255258
Some(arg @ FnArg::SelfRef(_)) => {
256259
let (lifetime, mutability) = match arg {
@@ -262,19 +265,14 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
262265
_ => unreachable!(),
263266
};
264267
match context {
265-
Context::Trait { name, generics, .. } => {
266-
let bound = match mutability {
267-
Some(_) => quote!(Send),
268-
None => quote!(Sync),
269-
};
268+
Context::Trait { .. } => {
269+
self_bound = Some(match mutability {
270+
Some(_) => parse_quote!(core::marker::Send),
271+
None => parse_quote!(core::marker::Sync),
272+
});
270273
*arg = parse_quote! {
271274
_self: &#lifetime #mutability AsyncTrait
272275
};
273-
let (_, generics, _) = generics.split_for_impl();
274-
standalone.decl.generics.params.push(parse_quote! {
275-
AsyncTrait: ?Sized + #name #generics + core::marker::#bound
276-
});
277-
types.push(Ident::new("Self", Span::call_site()));
278276
}
279277
Context::Impl { receiver, .. } => {
280278
*arg = parse_quote! {
@@ -284,15 +282,11 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
284282
}
285283
}
286284
Some(arg @ FnArg::SelfValue(_)) => match context {
287-
Context::Trait { name, generics, .. } => {
285+
Context::Trait { .. } => {
286+
self_bound = Some(parse_quote!(core::marker::Send));
288287
*arg = parse_quote! {
289288
_self: AsyncTrait
290289
};
291-
let (_, generics, _) = generics.split_for_impl();
292-
standalone.decl.generics.params.push(parse_quote! {
293-
AsyncTrait: ?Sized + #name #generics + core::marker::Send
294-
});
295-
types.push(Ident::new("Self", Span::call_site()));
296290
}
297291
Context::Impl { receiver, .. } => {
298292
*arg = parse_quote! {
@@ -303,6 +297,20 @@ fn transform_block(context: Context, sig: &MethodSig, block: &mut Block) {
303297
_ => {}
304298
}
305299

300+
if let Context::Trait { name, generics, .. } = context {
301+
if has_self {
302+
let (_, generics, _) = generics.split_for_impl();
303+
let mut self_param: TypeParam = parse_quote!(AsyncTrait: ?Sized + #name #generics);
304+
self_param.bounds.extend(self_bound);
305+
standalone
306+
.decl
307+
.generics
308+
.params
309+
.push(GenericParam::Type(self_param));
310+
types.push(Ident::new("Self", Span::call_site()));
311+
}
312+
}
313+
306314
if let Some(where_clause) = &mut standalone.decl.generics.where_clause {
307315
// Work around an input bound like `where Self::Output: Send` expanding
308316
// to `where <AsyncTrait>::Output: Send` which is illegal syntax because

src/receiver.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,41 @@
11
use std::mem;
22
use syn::punctuated::Punctuated;
33
use syn::visit_mut::{self, VisitMut};
4-
use syn::{parse_quote, ExprPath, Item, Path, QSelf, Type, TypePath};
4+
use syn::{
5+
parse_quote, ArgSelf, ArgSelfRef, Block, ExprPath, Item, MethodSig, Path, QSelf, Type, TypePath,
6+
};
7+
8+
pub fn has_self_in_sig(sig: &mut MethodSig) -> bool {
9+
let mut visitor = HasSelf(false);
10+
visitor.visit_method_sig_mut(sig);
11+
visitor.0
12+
}
13+
14+
pub fn has_self_in_block(block: &mut Block) -> bool {
15+
let mut visitor = HasSelf(false);
16+
visitor.visit_block_mut(block);
17+
visitor.0
18+
}
19+
20+
struct HasSelf(bool);
21+
22+
impl VisitMut for HasSelf {
23+
fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
24+
self.0 |= ty.path.segments[0].ident == "Self";
25+
}
26+
27+
fn visit_arg_self_mut(&mut self, _arg: &mut ArgSelf) {
28+
self.0 = true;
29+
}
30+
31+
fn visit_arg_self_ref_mut(&mut self, _arg: &mut ArgSelfRef) {
32+
self.0 = true;
33+
}
34+
35+
fn visit_item_mut(&mut self, _: &mut Item) {
36+
// Do not recurse into nested items.
37+
}
38+
}
539

640
pub struct ReplaceReceiver {
741
pub with: Type,

tests/test.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,13 @@ mod issue2 {
148148
}
149149
}
150150
}
151+
152+
// https://github.com/dtolnay/async-trait/issues/9
153+
mod issue9 {
154+
use async_trait::async_trait;
155+
156+
#[async_trait]
157+
pub trait Issue3: Sized + Send {
158+
async fn f(_x: Self) {}
159+
}
160+
}

0 commit comments

Comments
 (0)