Skip to content

Commit 623c556

Browse files
authored
Let the ? operator work natively in try_stream!. (#53)
Insteads of desugaring `?` in the macro, we can have the async block itself return `Result<(), E>`, and adjust the supporting code so that `?` just works. The benefit is that this allows `?` operators that are hidden behind macros.
1 parent 22a36ee commit 623c556

File tree

4 files changed

+98
-29
lines changed

4 files changed

+98
-29
lines changed

async-stream-impl/src/lib.rs

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ use syn::parse::Parser;
55
use syn::visit_mut::VisitMut;
66

77
struct Scrub<'a> {
8-
/// Whether the stream is a try stream.
9-
is_try: bool,
108
/// The unit expression, `()`.
119
unit: Box<syn::Expr>,
1210
has_yielded: bool,
@@ -24,9 +22,8 @@ fn parse_input(input: TokenStream) -> syn::Result<(TokenStream2, Vec<syn::Stmt>)
2422
}
2523

2624
impl<'a> Scrub<'a> {
27-
fn new(is_try: bool, crate_path: &'a TokenStream2) -> Self {
25+
fn new(crate_path: &'a TokenStream2) -> Self {
2826
Self {
29-
is_try,
3027
unit: syn::parse_quote!(()),
3128
has_yielded: false,
3229
crate_path,
@@ -44,26 +41,7 @@ impl VisitMut for Scrub<'_> {
4441

4542
// let ident = &self.yielder;
4643

47-
*i = if self.is_try {
48-
syn::parse_quote! { __yield_tx.send(::core::result::Result::Ok(#value_expr)).await }
49-
} else {
50-
syn::parse_quote! { __yield_tx.send(#value_expr).await }
51-
};
52-
}
53-
syn::Expr::Try(try_expr) => {
54-
syn::visit_mut::visit_expr_try_mut(self, try_expr);
55-
// let ident = &self.yielder;
56-
let e = &try_expr.expr;
57-
58-
*i = syn::parse_quote! {
59-
match #e {
60-
::core::result::Result::Ok(v) => v,
61-
::core::result::Result::Err(e) => {
62-
__yield_tx.send(::core::result::Result::Err(e.into())).await;
63-
return;
64-
}
65-
}
66-
};
44+
*i = syn::parse_quote! { __yield_tx.send(#value_expr).await };
6745
}
6846
syn::Expr::Closure(_) | syn::Expr::Async(_) => {
6947
// Don't transform inner closures or async blocks.
@@ -124,7 +102,7 @@ pub fn stream_inner(input: TokenStream) -> TokenStream {
124102
Err(e) => return e.to_compile_error().into(),
125103
};
126104

127-
let mut scrub = Scrub::new(false, &crate_path);
105+
let mut scrub = Scrub::new(&crate_path);
128106

129107
for mut stmt in &mut stmts {
130108
scrub.visit_stmt_mut(&mut stmt);
@@ -158,7 +136,7 @@ pub fn try_stream_inner(input: TokenStream) -> TokenStream {
158136
Err(e) => return e.to_compile_error().into(),
159137
};
160138

161-
let mut scrub = Scrub::new(true, &crate_path);
139+
let mut scrub = Scrub::new(&crate_path);
162140

163141
for mut stmt in &mut stmts {
164142
scrub.visit_stmt_mut(&mut stmt);
@@ -174,9 +152,13 @@ pub fn try_stream_inner(input: TokenStream) -> TokenStream {
174152

175153
quote!({
176154
let (mut __yield_tx, __yield_rx) = #crate_path::yielder::pair();
177-
#crate_path::AsyncStream::new(__yield_rx, async move {
155+
#crate_path::AsyncTryStream::new(__yield_rx, async move {
178156
#dummy_yield
179-
#(#stmts)*
157+
let () = {
158+
#(#stmts)*
159+
};
160+
#[allow(unreachable_code)]
161+
Ok(())
180162
})
181163
})
182164
.into()

async-stream/src/async_stream.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,70 @@ where
7575
}
7676
}
7777
}
78+
79+
#[doc(hidden)]
80+
#[derive(Debug)]
81+
pub struct AsyncTryStream<T, U> {
82+
rx: Receiver<T>,
83+
done: bool,
84+
generator: U,
85+
}
86+
87+
impl<T, U> AsyncTryStream<T, U> {
88+
#[doc(hidden)]
89+
pub fn new(rx: Receiver<T>, generator: U) -> AsyncTryStream<T, U> {
90+
AsyncTryStream {
91+
rx,
92+
done: false,
93+
generator,
94+
}
95+
}
96+
}
97+
98+
impl<T, U, E> FusedStream for AsyncTryStream<T, U>
99+
where
100+
U: Future<Output = Result<(), E>>,
101+
{
102+
fn is_terminated(&self) -> bool {
103+
self.done
104+
}
105+
}
106+
107+
impl<T, U, E> Stream for AsyncTryStream<T, U>
108+
where
109+
U: Future<Output = Result<(), E>>,
110+
{
111+
type Item = Result<T, E>;
112+
113+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
114+
unsafe {
115+
let me = Pin::get_unchecked_mut(self);
116+
117+
if me.done {
118+
return Poll::Ready(None);
119+
}
120+
121+
let mut dst = None;
122+
let res = {
123+
let _enter = me.rx.enter(&mut dst);
124+
Pin::new_unchecked(&mut me.generator).poll(cx)
125+
};
126+
127+
me.done = res.is_ready();
128+
129+
if let Poll::Ready(Err(e)) = res {
130+
return Poll::Ready(Some(Err(e)));
131+
}
132+
133+
if let Some(val) = dst.take() {
134+
return Poll::Ready(Some(Ok(val)));
135+
}
136+
137+
if me.done {
138+
Poll::Ready(None)
139+
} else {
140+
Poll::Pending
141+
}
142+
}
143+
}
144+
}

async-stream/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ pub mod yielder;
164164

165165
// Used by the macro, but not intended to be accessed publicly.
166166
#[doc(hidden)]
167-
pub use crate::async_stream::AsyncStream;
167+
pub use crate::async_stream::{AsyncStream, AsyncTryStream};
168168

169169
#[doc(hidden)]
170170
pub use async_stream_impl;

async-stream/tests/try_stream.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use async_stream::try_stream;
22

33
use futures_core::stream::Stream;
4+
use futures_util::pin_mut;
45
use futures_util::stream::StreamExt;
56

67
#[tokio::test]
@@ -78,3 +79,22 @@ async fn multi_try() {
7879
values
7980
);
8081
}
82+
83+
macro_rules! try_macro {
84+
($e:expr) => {
85+
$e?
86+
};
87+
}
88+
89+
#[tokio::test]
90+
async fn try_in_macro() {
91+
let s = try_stream! {
92+
yield "hi";
93+
try_macro!(Err("bye"));
94+
};
95+
pin_mut!(s);
96+
97+
assert_eq!(s.next().await, Some(Ok("hi")));
98+
assert_eq!(s.next().await, Some(Err("bye")));
99+
assert_eq!(s.next().await, None);
100+
}

0 commit comments

Comments
 (0)