Skip to content

Commit a4c41e3

Browse files
Googlercopybara-github
authored andcommitted
Ensure that printed arguments in verify_pred are evaluated exactly once by assigning intermediate values to variables.
PiperOrigin-RevId: 690026827
1 parent f605d54 commit a4c41e3

File tree

2 files changed

+125
-34
lines changed

2 files changed

+125
-34
lines changed

googletest/tests/assertions_test.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,57 @@ mod verify_pred {
6767
)
6868
}
6969

70+
#[test]
71+
fn evaluates_functions_and_arguments_exactly_once() -> Result<()> {
72+
let mut a = 0;
73+
let mut foo = |_b: u32| {
74+
a += 1;
75+
false
76+
};
77+
let mut b = 0;
78+
let mut bar = || {
79+
b += 10;
80+
b
81+
};
82+
83+
let res = verify_pred!(foo(bar()));
84+
verify_that!(
85+
res,
86+
err(displays_as(contains_substring(indoc! {"
87+
foo(bar()) was false with
88+
bar() = 10,
89+
at"
90+
})))
91+
)?;
92+
93+
verify_that!((a, b), eq((1, 10)))
94+
}
95+
96+
#[test]
97+
fn evaluates_methods_and_arguments_exactly_once() -> Result<()> {
98+
struct Apple(u32);
99+
impl Apple {
100+
fn c(&mut self, _b: bool) -> bool {
101+
self.0 += 1;
102+
false
103+
}
104+
}
105+
let mut a = Apple(0);
106+
let mut b = Apple(10);
107+
108+
let res = verify_pred!(a.c(b.c(false)));
109+
verify_that!(
110+
res,
111+
err(displays_as(contains_substring(indoc! {"
112+
a.c(b.c(false)) was false with
113+
b.c(false) = false,
114+
at"
115+
})))
116+
)?;
117+
118+
verify_that!((a.0, b.0), eq((1, 11)))
119+
}
120+
70121
#[test]
71122
fn supports_chained_method_calls() -> Result<()> {
72123
struct Foo;

googletest_macro/src/verify_pred.rs

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
// limitations under the License.
1414

1515
use quote::quote;
16-
use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Expr, Ident};
16+
use syn::{parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Expr, Ident};
1717

1818
struct AccumulatePartsState {
1919
error_message_ident: Ident,
20+
var_defs: Vec<proc_macro2::TokenStream>,
2021
formats: Vec<proc_macro2::TokenStream>,
2122
}
2223

@@ -31,54 +32,92 @@ impl AccumulatePartsState {
3132
"__googletest__verify_pred__error_message",
3233
::proc_macro2::Span::call_site(),
3334
),
35+
var_defs: vec![],
3436
formats: vec![],
3537
}
3638
}
3739

38-
/// Accumulates error message formating parts for various parts of the
39-
/// expression.
40-
fn accumulate_parts(&mut self, expr: &Expr) {
41-
let expr_string = expr_to_string(expr);
42-
match expr {
43-
Expr::Group(group) => {
40+
/// Takes an expression with chained field accesses and method calls and
41+
/// accumulates intermediate expressions used for computing `verify_pred!`'s
42+
/// expression, including intermediate variable assignments to evaluate
43+
/// parts of the expression exactly once, and the format string used to
44+
/// output intermediate values on condition failure. It returns the new form
45+
/// of the input expression with parts of it potentially replaced by the
46+
/// intermediate variables.
47+
fn accumulate_parts(&mut self, expr: Expr) -> Expr {
48+
let expr_string = expr_to_string(&expr);
49+
let new_expr = match expr {
50+
Expr::Group(mut group) => {
4451
// This is an invisible group added for correct precedence in the AST. Just pass
4552
// through without having a separate printing result.
46-
return self.accumulate_parts(&group.expr);
53+
*group.expr = self.accumulate_parts(*group.expr);
54+
return Expr::Group(group);
4755
}
48-
Expr::Call(call) => {
49-
// Format the args into the error message.
50-
self.format_args(&call.args);
56+
Expr::Call(mut call) => {
57+
// Cache args into intermediate variables.
58+
call.args = self.define_variables_for_args(call.args);
59+
// Cache function value into an intermediate variable.
60+
self.define_variable(&Expr::Call(call))
5161
}
52-
Expr::MethodCall(method_call) => {
53-
// Format the args into the error message.
54-
self.format_args(&method_call.args);
62+
Expr::MethodCall(mut method_call) => {
63+
// Cache args into intermediate variables.
64+
method_call.args = self.define_variables_for_args(method_call.args);
65+
// Cache method value into an intermediate variable.
66+
self.define_variable(&Expr::MethodCall(method_call))
5567
}
56-
_ => {}
57-
}
68+
// By default, assume it's some expression that needs to be cached to avoid
69+
// double-evaluation.
70+
_ => self.define_variable(&expr),
71+
};
5872
let error_message_ident = &self.error_message_ident;
5973
self.formats.push(quote! {
6074
::googletest::fmt::internal::__googletest__write_expr_value!(
6175
&mut #error_message_ident,
6276
#expr_string,
63-
#expr,
77+
#new_expr,
6478
);
6579
});
80+
new_expr
6681
}
6782

68-
// Formats each argument expression into the error message.
69-
fn format_args(&mut self, args: &Punctuated<Expr, Comma>) {
70-
for pair in args.pairs() {
71-
let error_message_ident = &self.error_message_ident;
72-
let expr_string = expr_to_string(pair.value());
73-
let expr = pair.value();
74-
self.formats.push(quote! {
75-
::googletest::fmt::internal::__googletest__write_expr_value!(
76-
&mut #error_message_ident,
77-
#expr_string,
78-
#expr,
79-
);
80-
});
81-
}
83+
// Defines a variable for each argument expression so that it's evaluated
84+
// exactly once.
85+
fn define_variables_for_args(
86+
&mut self,
87+
args: Punctuated<Expr, Comma>,
88+
) -> Punctuated<Expr, Comma> {
89+
args.into_pairs()
90+
.map(|mut pair| {
91+
let var_expr = self.define_variable(pair.value());
92+
let error_message_ident = &self.error_message_ident;
93+
let expr_string = expr_to_string(pair.value());
94+
self.formats.push(quote! {
95+
::googletest::fmt::internal::__googletest__write_expr_value!(
96+
&mut #error_message_ident,
97+
#expr_string,
98+
#var_expr,
99+
);
100+
});
101+
102+
*pair.value_mut() = var_expr;
103+
pair
104+
})
105+
.collect()
106+
}
107+
108+
/// Defines a new variable assigned to the expression and returns the
109+
/// variable as an expression to be used in place of the passed-in
110+
/// expression.
111+
fn define_variable(&mut self, value: &Expr) -> Expr {
112+
let var_name = Ident::new(
113+
&format!("__googletest__verify_pred__var{}", self.var_defs.len()),
114+
value.span(),
115+
);
116+
self.var_defs.push(quote! {
117+
#[allow(non_snake_case)]
118+
let mut #var_name = #value;
119+
});
120+
syn::parse::<Expr>(quote!(#var_name).into()).unwrap()
82121
}
83122
}
84123

@@ -87,13 +126,14 @@ pub fn verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStre
87126
let error_message = quote!(#parsed).to_string() + " was false with";
88127

89128
let mut state = AccumulatePartsState::new();
90-
state.accumulate_parts(&parsed);
91-
let AccumulatePartsState { error_message_ident, mut formats, .. } = state;
129+
let pred_value = state.accumulate_parts(parsed);
130+
let AccumulatePartsState { error_message_ident, var_defs, mut formats, .. } = state;
92131

93132
let _ = formats.pop(); // The last one is the full expression itself.
94133
quote! {
95134
{
96-
if (#parsed) {
135+
#(#var_defs)*
136+
if (#pred_value) {
97137
Ok(())
98138
} else {
99139
let mut #error_message_ident = #error_message.to_string();

0 commit comments

Comments
 (0)