Skip to content

Commit 2459c8d

Browse files
committed
Support #[proptest] attribute for functions with explicit return types.
1 parent ac9e736 commit 2459c8d

File tree

2 files changed

+107
-17
lines changed

2 files changed

+107
-17
lines changed

src/proptest_fn.rs

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
use std::mem::replace;
2+
13
use crate::syn_utils::{Arg, Args};
2-
use proc_macro2::TokenStream;
4+
use proc_macro2::{Span, TokenStream};
35
use quote::{quote, ToTokens};
46
use syn::{
57
parse2, parse_quote, parse_str, spanned::Spanned, token, Block, Expr, Field, FieldMutability,
6-
FnArg, Ident, ItemFn, LitStr, Pat, Result, Visibility,
8+
FnArg, Ident, ItemFn, LitStr, Pat, Result, ReturnType, Visibility,
79
};
810

911
pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStream> {
@@ -35,11 +37,23 @@ pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result<TokenStr
3537
if item_fn.sig.asyncness.is_none() {
3638
attr_args.r#async = None;
3739
}
40+
let output = replace(&mut item_fn.sig.output, ReturnType::Default);
3841
let block = if let Some(a) = attr_args.r#async {
3942
item_fn.sig.asyncness = None;
40-
a.apply(block)
43+
a.apply(block, output)
4144
} else {
42-
quote!(#block)
45+
match output {
46+
ReturnType::Default => quote!(#block),
47+
ReturnType::Type(_, ty) => {
48+
let f = Ident::new("__test_body", Span::mixed_site());
49+
quote!({
50+
let #f = move || -> #ty {
51+
#block
52+
};
53+
::std::result::Result::map_err(#f(), ::std::convert::Into::<TestCaseError>::into)?;
54+
})
55+
}
56+
}
4357
};
4458
let block = quote! {
4559
{
@@ -136,28 +150,43 @@ enum Async {
136150
Expr(Expr),
137151
}
138152
impl Async {
139-
fn apply(&self, block: &Block) -> TokenStream {
153+
fn apply(&self, block: &Block, output: ReturnType) -> TokenStream {
154+
let body;
155+
let output_type;
156+
let ret_expr;
157+
match output {
158+
ReturnType::Default => {
159+
body = quote! {
160+
#block
161+
Ok(())
162+
};
163+
output_type =
164+
quote!(::core::result::Result<_, ::proptest::test_runner::TestCaseError> );
165+
ret_expr = quote! { ret? };
166+
}
167+
ReturnType::Type(_, ty) => {
168+
body = quote! { #block };
169+
output_type = quote!(#ty);
170+
ret_expr = quote! {
171+
::std::result::Result::map_err(ret, ::std::convert::Into::<TestCaseError>::into)?
172+
};
173+
}
174+
}
140175
match self {
141176
Async::Tokio => {
142177
quote! {
143-
let ret: ::core::result::Result<_, ::proptest::test_runner::TestCaseError> =
178+
let ret: #output_type =
144179
tokio::runtime::Runtime::new()
145180
.unwrap()
146-
.block_on(async move {
147-
#block
148-
Ok(())
149-
});
150-
ret?;
181+
.block_on(async move { #body });
182+
#ret_expr;
151183
}
152184
}
153185
Async::Expr(expr) => {
154186
quote! {
155-
let ret: ::core::result::Result<(), ::proptest::test_runner::TestCaseError> =
156-
(#expr)(async move {
157-
#block
158-
Ok(())
159-
});
160-
ret?;
187+
let ret: #output_type =
188+
(#expr)(async move { #body });
189+
#ret_expr;
161190
}
162191
}
163192
}

tests/proptest_fn.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,64 @@ async fn async_expr_test_prop_assert() {
116116
async fn async_expr_test_prop_assert_false() {
117117
prop_assert!(false);
118118
}
119+
120+
struct NotImplError;
121+
122+
#[derive(Debug)]
123+
struct CustomError(#[allow(dead_code)] TestCaseError);
124+
125+
impl std::error::Error for CustomError {}
126+
impl std::fmt::Display for CustomError {
127+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128+
write!(f, "CustomError")
129+
}
130+
}
131+
impl From<TestCaseError> for CustomError {
132+
fn from(e: TestCaseError) -> Self {
133+
CustomError(e)
134+
}
135+
}
136+
impl From<NotImplError> for CustomError {
137+
fn from(_: NotImplError) -> Self {
138+
CustomError(TestCaseError::fail("Custom"))
139+
}
140+
}
141+
142+
#[proptest]
143+
fn custom_error_ok(#[strategy(1..10u8)] x: u8) -> Result<(), CustomError> {
144+
assert!(x < 10);
145+
not_impl_error_ok()?;
146+
Ok(())
147+
}
148+
149+
#[proptest]
150+
#[should_panic]
151+
fn custom_error_err(#[strategy(1..10u8)] x: u8) -> Result<(), CustomError> {
152+
assert!(x < 10);
153+
not_impl_error_err()?;
154+
Ok(())
155+
}
156+
157+
#[proptest(async = "tokio")]
158+
async fn custom_error_async_ok(#[strategy(1..10u8)] x: u8) -> Result<(), CustomError> {
159+
assert!(x < 10);
160+
not_impl_error_ok()?;
161+
yield_now().await;
162+
Ok(())
163+
}
164+
165+
#[proptest(async = "tokio")]
166+
#[should_panic]
167+
async fn custom_error_async_err(#[strategy(1..10u8)] x: u8) -> Result<(), CustomError> {
168+
assert!(x < 10);
169+
not_impl_error_err()?;
170+
yield_now().await;
171+
Ok(())
172+
}
173+
174+
fn not_impl_error_ok() -> Result<(), NotImplError> {
175+
Ok(())
176+
}
177+
fn not_impl_error_err() -> Result<(), NotImplError> {
178+
Err(NotImplError)
179+
}

0 commit comments

Comments
 (0)