Skip to content

Commit 9b7e9a7

Browse files
Expand 'yield' in internal macro calls. (#57)
1 parent 8c22349 commit 9b7e9a7

File tree

4 files changed

+162
-3
lines changed

4 files changed

+162
-3
lines changed

async-stream-impl/src/lib.rs

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use proc_macro::TokenStream;
22
use proc_macro2::{Group, TokenStream as TokenStream2, TokenTree};
33
use quote::quote;
4-
use syn::parse::Parser;
4+
use syn::parse::{Parse, ParseStream, Parser, Result};
55
use syn::visit_mut::VisitMut;
66

77
struct Scrub<'a> {
@@ -34,6 +34,80 @@ impl<'a> Scrub<'a> {
3434
}
3535
}
3636

37+
struct Partial<T>(T, TokenStream2);
38+
39+
impl<T: Parse> Parse for Partial<T> {
40+
fn parse(input: ParseStream) -> Result<Self> {
41+
Ok(Partial(input.parse()?, input.parse()?))
42+
}
43+
}
44+
45+
fn visit_token_stream_impl(
46+
visitor: &mut Scrub<'_>,
47+
tokens: TokenStream2,
48+
modified: &mut bool,
49+
out: &mut TokenStream2,
50+
) {
51+
use quote::ToTokens;
52+
use quote::TokenStreamExt;
53+
54+
let mut tokens = tokens.into_iter().peekable();
55+
while let Some(tt) = tokens.next() {
56+
match tt {
57+
TokenTree::Ident(i) if i == "yield" => {
58+
let stream = std::iter::once(TokenTree::Ident(i)).chain(tokens).collect();
59+
match syn::parse2(stream) {
60+
Ok(Partial(yield_expr, rest)) => {
61+
let mut expr = syn::Expr::Yield(yield_expr);
62+
visitor.visit_expr_mut(&mut expr);
63+
expr.to_tokens(out);
64+
*modified = true;
65+
tokens = rest.into_iter().peekable();
66+
}
67+
Err(e) => {
68+
out.append_all(&mut e.to_compile_error().into_iter());
69+
*modified = true;
70+
return;
71+
}
72+
}
73+
}
74+
TokenTree::Ident(i) if i == "stream" || i == "try_stream" => {
75+
out.append(TokenTree::Ident(i));
76+
match tokens.peek() {
77+
Some(TokenTree::Punct(p)) if p.as_char() == '!' => {
78+
out.extend(tokens.next()); // !
79+
if let Some(TokenTree::Group(_)) = tokens.peek() {
80+
out.extend(tokens.next()); // { .. } or [ .. ] or ( .. )
81+
}
82+
}
83+
_ => {}
84+
}
85+
}
86+
TokenTree::Group(group) => {
87+
let mut content = group.stream();
88+
*modified |= visitor.visit_token_stream(&mut content);
89+
let mut new = Group::new(group.delimiter(), content);
90+
new.set_span(group.span());
91+
out.append(new);
92+
}
93+
other => out.append(other),
94+
}
95+
}
96+
}
97+
98+
impl Scrub<'_> {
99+
fn visit_token_stream(&mut self, tokens: &mut TokenStream2) -> bool {
100+
let (mut out, mut modified) = (TokenStream2::new(), false);
101+
visit_token_stream_impl(self, tokens.clone(), &mut modified, &mut out);
102+
103+
if modified {
104+
*tokens = out;
105+
}
106+
107+
modified
108+
}
109+
}
110+
37111
impl VisitMut for Scrub<'_> {
38112
fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
39113
match i {
@@ -109,8 +183,20 @@ impl VisitMut for Scrub<'_> {
109183
}
110184
}
111185

112-
fn visit_item_mut(&mut self, _: &mut syn::Item) {
113-
// Don't transform inner items.
186+
fn visit_macro_mut(&mut self, mac: &mut syn::Macro) {
187+
let mac_ident = mac.path.segments.last().map(|p| &p.ident);
188+
if mac_ident.map_or(false, |i| i == "stream" || i == "try_stream") {
189+
return;
190+
}
191+
192+
self.visit_token_stream(&mut mac.tokens);
193+
}
194+
195+
fn visit_item_mut(&mut self, i: &mut syn::Item) {
196+
// Recurse into macros but otherwise don't transform inner items.
197+
if let syn::Item::Macro(i) = i {
198+
self.visit_macro_mut(&mut i.mac);
199+
}
114200
}
115201
}
116202

async-stream/tests/stream.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,42 @@ async fn yield_multi_value() {
8080
assert_eq!("dizzy", values[2]);
8181
}
8282

83+
#[tokio::test]
84+
async fn unit_yield_in_select() {
85+
use tokio::select;
86+
87+
async fn do_stuff_async() {}
88+
89+
let s = stream! {
90+
select! {
91+
_ = do_stuff_async() => yield,
92+
else => yield,
93+
}
94+
};
95+
96+
let values: Vec<_> = s.collect().await;
97+
assert_eq!(values.len(), 1);
98+
}
99+
100+
#[tokio::test]
101+
async fn yield_with_select() {
102+
use tokio::select;
103+
104+
async fn do_stuff_async() {}
105+
async fn more_async_work() {}
106+
107+
let s = stream! {
108+
select! {
109+
_ = do_stuff_async() => yield "hey",
110+
_ = more_async_work() => yield "hey",
111+
else => yield "hey",
112+
}
113+
};
114+
115+
let values: Vec<_> = s.collect().await;
116+
assert_eq!(values, vec!["hey"]);
117+
}
118+
83119
#[tokio::test]
84120
async fn return_stream() {
85121
fn build_stream() -> impl Stream<Item = u32> {
@@ -172,6 +208,27 @@ async fn yield_non_unpin_value() {
172208
assert_eq!(s, vec![0, 1, 2]);
173209
}
174210

211+
#[test]
212+
fn inner_try_stream() {
213+
use async_stream::try_stream;
214+
use tokio::select;
215+
216+
async fn do_stuff_async() {}
217+
218+
let _ = stream! {
219+
select! {
220+
_ = do_stuff_async() => {
221+
let another_s = try_stream! {
222+
yield;
223+
};
224+
let _: Result<(), ()> = Box::pin(another_s).next().await.unwrap();
225+
},
226+
else => {},
227+
}
228+
yield
229+
};
230+
}
231+
175232
#[test]
176233
fn test() {
177234
let t = trybuild::TestCases::new();
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
use async_stream::stream;
2+
3+
fn main() {
4+
async fn work() {}
5+
6+
stream! {
7+
tokio::select! {
8+
_ = work() => yield fn f() {},
9+
}
10+
};
11+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
error: expected expression
2+
--> $DIR/yield_bad_expr_in_macro.rs:8:33
3+
|
4+
8 | _ = work() => yield fn f() {},
5+
| ^^

0 commit comments

Comments
 (0)