|
| 1 | +use std::mem::replace; |
| 2 | + |
1 | 3 | use crate::syn_utils::{Arg, Args};
|
2 |
| -use proc_macro2::TokenStream; |
| 4 | +use proc_macro2::{Span, TokenStream}; |
3 | 5 | use quote::{quote, ToTokens};
|
4 | 6 | use syn::{
|
5 | 7 | parse2, parse_quote, parse_str, spanned::Spanned, token, Block, Expr, Field, FieldMutability,
|
6 |
| - FnArg, Ident, ItemFn, LitStr, Pat, Result, Visibility, |
| 8 | + FnArg, Ident, ItemFn, LitStr, Pat, Result, ReturnType, Visibility, |
7 | 9 | };
|
8 | 10 |
|
9 | 11 | pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStream> {
|
@@ -35,11 +37,23 @@ pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStr
|
35 | 37 | if item_fn.sig.asyncness.is_none() {
|
36 | 38 | attr_args.r#async = None;
|
37 | 39 | }
|
| 40 | + let output = replace(&mut item_fn.sig.output, ReturnType::Default); |
38 | 41 | let block = if let Some(a) = attr_args.r#async {
|
39 | 42 | item_fn.sig.asyncness = None;
|
40 |
| - a.apply(block) |
| 43 | + a.apply(block, output) |
41 | 44 | } else {
|
42 |
| - quote!(#block) |
| 45 | + match output { |
| 46 | + ReturnType::Default => quote!(#block), |
| 47 | + ReturnType::Type(_, ty) => { |
| 48 | + let f = Ident::new("__test_body", Span::mixed_site()); |
| 49 | + quote!({ |
| 50 | + let #f = move || -> #ty { |
| 51 | + #block |
| 52 | + }; |
| 53 | + ::std::result::Result::map_err(#f(), ::std::convert::Into::<TestCaseError>::into)?; |
| 54 | + }) |
| 55 | + } |
| 56 | + } |
43 | 57 | };
|
44 | 58 | let block = quote! {
|
45 | 59 | {
|
@@ -136,28 +150,43 @@ enum Async {
|
136 | 150 | Expr(Expr),
|
137 | 151 | }
|
138 | 152 | impl Async {
|
139 |
| - fn apply(&self, block: &Block) -> TokenStream { |
| 153 | + fn apply(&self, block: &Block, output: ReturnType) -> TokenStream { |
| 154 | + let body; |
| 155 | + let output_type; |
| 156 | + let ret_expr; |
| 157 | + match output { |
| 158 | + ReturnType::Default => { |
| 159 | + body = quote! { |
| 160 | + #block |
| 161 | + Ok(()) |
| 162 | + }; |
| 163 | + output_type = |
| 164 | + quote!(::core::result::Result<_, ::proptest::test_runner::TestCaseError> ); |
| 165 | + ret_expr = quote! { ret? }; |
| 166 | + } |
| 167 | + ReturnType::Type(_, ty) => { |
| 168 | + body = quote! { #block }; |
| 169 | + output_type = quote!(#ty); |
| 170 | + ret_expr = quote! { |
| 171 | + ::std::result::Result::map_err(ret, ::std::convert::Into::<TestCaseError>::into)? |
| 172 | + }; |
| 173 | + } |
| 174 | + } |
140 | 175 | match self {
|
141 | 176 | Async::Tokio => {
|
142 | 177 | quote! {
|
143 |
| - let ret: ::core::result::Result<_, ::proptest::test_runner::TestCaseError> = |
| 178 | + let ret: #output_type = |
144 | 179 | tokio::runtime::Runtime::new()
|
145 | 180 | .unwrap()
|
146 |
| - .block_on(async move { |
147 |
| - #block |
148 |
| - Ok(()) |
149 |
| - }); |
150 |
| - ret?; |
| 181 | + .block_on(async move { #body }); |
| 182 | + #ret_expr; |
151 | 183 | }
|
152 | 184 | }
|
153 | 185 | Async::Expr(expr) => {
|
154 | 186 | quote! {
|
155 |
| - let ret: ::core::result::Result<(), ::proptest::test_runner::TestCaseError> = |
156 |
| - (#expr)(async move { |
157 |
| - #block |
158 |
| - Ok(()) |
159 |
| - }); |
160 |
| - ret?; |
| 187 | + let ret: #output_type = |
| 188 | + (#expr)(async move { #body }); |
| 189 | + #ret_expr; |
161 | 190 | }
|
162 | 191 | }
|
163 | 192 | }
|
|
0 commit comments