13
13
// limitations under the License.
14
14
15
15
use quote:: quote;
16
- use syn:: { parse_macro_input, punctuated:: Punctuated , token:: Comma , Expr , Ident } ;
16
+ use syn:: { parse_macro_input, punctuated:: Punctuated , spanned :: Spanned , token:: Comma , Expr , Ident } ;
17
17
18
18
struct AccumulatePartsState {
19
19
error_message_ident : Ident ,
20
+ var_defs : Vec < proc_macro2:: TokenStream > ,
20
21
formats : Vec < proc_macro2:: TokenStream > ,
21
22
}
22
23
@@ -31,54 +32,92 @@ impl AccumulatePartsState {
31
32
"__googletest__verify_pred__error_message" ,
32
33
:: proc_macro2:: Span :: call_site ( ) ,
33
34
) ,
35
+ var_defs : vec ! [ ] ,
34
36
formats : vec ! [ ] ,
35
37
}
36
38
}
37
39
38
- /// Accumulates error message formating parts for various parts of the
39
- /// expression.
40
- fn accumulate_parts ( & mut self , expr : & Expr ) {
41
- let expr_string = expr_to_string ( expr) ;
42
- match expr {
43
- Expr :: Group ( group) => {
40
+ /// Takes an expression with chained field accesses and method calls and
41
+ /// accumulates intermediate expressions used for computing `verify_pred!`'s
42
+ /// expression, including intermediate variable assignments to evaluate
43
+ /// parts of the expression exactly once, and the format string used to
44
+ /// output intermediate values on condition failure. It returns the new form
45
+ /// of the input expression with parts of it potentially replaced by the
46
+ /// intermediate variables.
47
+ fn accumulate_parts ( & mut self , expr : Expr ) -> Expr {
48
+ let expr_string = expr_to_string ( & expr) ;
49
+ let new_expr = match expr {
50
+ Expr :: Group ( mut group) => {
44
51
// This is an invisible group added for correct precedence in the AST. Just pass
45
52
// through without having a separate printing result.
46
- return self . accumulate_parts ( & group. expr ) ;
53
+ * group. expr = self . accumulate_parts ( * group. expr ) ;
54
+ return Expr :: Group ( group) ;
47
55
}
48
- Expr :: Call ( call) => {
49
- // Format the args into the error message.
50
- self . format_args ( & call. args ) ;
56
+ Expr :: Call ( mut call) => {
57
+ // Cache args into intermediate variables.
58
+ call. args = self . define_variables_for_args ( call. args ) ;
59
+ // Cache function value into an intermediate variable.
60
+ self . define_variable ( & Expr :: Call ( call) )
51
61
}
52
- Expr :: MethodCall ( method_call) => {
53
- // Format the args into the error message.
54
- self . format_args ( & method_call. args ) ;
62
+ Expr :: MethodCall ( mut method_call) => {
63
+ // Cache args into intermediate variables.
64
+ method_call. args = self . define_variables_for_args ( method_call. args ) ;
65
+ // Cache method value into an intermediate variable.
66
+ self . define_variable ( & Expr :: MethodCall ( method_call) )
55
67
}
56
- _ => { }
57
- }
68
+ // By default, assume it's some expression that needs to be cached to avoid
69
+ // double-evaluation.
70
+ _ => self . define_variable ( & expr) ,
71
+ } ;
58
72
let error_message_ident = & self . error_message_ident ;
59
73
self . formats . push ( quote ! {
60
74
:: googletest:: fmt:: internal:: __googletest__write_expr_value!(
61
75
& mut #error_message_ident,
62
76
#expr_string,
63
- #expr ,
77
+ #new_expr ,
64
78
) ;
65
79
} ) ;
80
+ new_expr
66
81
}
67
82
68
- // Formats each argument expression into the error message.
69
- fn format_args ( & mut self , args : & Punctuated < Expr , Comma > ) {
70
- for pair in args. pairs ( ) {
71
- let error_message_ident = & self . error_message_ident ;
72
- let expr_string = expr_to_string ( pair. value ( ) ) ;
73
- let expr = pair. value ( ) ;
74
- self . formats . push ( quote ! {
75
- :: googletest:: fmt:: internal:: __googletest__write_expr_value!(
76
- & mut #error_message_ident,
77
- #expr_string,
78
- #expr,
79
- ) ;
80
- } ) ;
81
- }
83
+ // Defines a variable for each argument expression so that it's evaluated
84
+ // exactly once.
85
+ fn define_variables_for_args (
86
+ & mut self ,
87
+ args : Punctuated < Expr , Comma > ,
88
+ ) -> Punctuated < Expr , Comma > {
89
+ args. into_pairs ( )
90
+ . map ( |mut pair| {
91
+ let var_expr = self . define_variable ( pair. value ( ) ) ;
92
+ let error_message_ident = & self . error_message_ident ;
93
+ let expr_string = expr_to_string ( pair. value ( ) ) ;
94
+ self . formats . push ( quote ! {
95
+ :: googletest:: fmt:: internal:: __googletest__write_expr_value!(
96
+ & mut #error_message_ident,
97
+ #expr_string,
98
+ #var_expr,
99
+ ) ;
100
+ } ) ;
101
+
102
+ * pair. value_mut ( ) = var_expr;
103
+ pair
104
+ } )
105
+ . collect ( )
106
+ }
107
+
108
+ /// Defines a new variable assigned to the expression and returns the
109
+ /// variable as an expression to be used in place of the passed-in
110
+ /// expression.
111
+ fn define_variable ( & mut self , value : & Expr ) -> Expr {
112
+ let var_name = Ident :: new (
113
+ & format ! ( "__googletest__verify_pred__var{}" , self . var_defs. len( ) ) ,
114
+ value. span ( ) ,
115
+ ) ;
116
+ self . var_defs . push ( quote ! {
117
+ #[ allow( non_snake_case) ]
118
+ let mut #var_name = #value;
119
+ } ) ;
120
+ syn:: parse :: < Expr > ( quote ! ( #var_name) . into ( ) ) . unwrap ( )
82
121
}
83
122
}
84
123
@@ -87,13 +126,14 @@ pub fn verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStre
87
126
let error_message = quote ! ( #parsed) . to_string ( ) + " was false with" ;
88
127
89
128
let mut state = AccumulatePartsState :: new ( ) ;
90
- state. accumulate_parts ( & parsed) ;
91
- let AccumulatePartsState { error_message_ident, mut formats, .. } = state;
129
+ let pred_value = state. accumulate_parts ( parsed) ;
130
+ let AccumulatePartsState { error_message_ident, var_defs , mut formats, .. } = state;
92
131
93
132
let _ = formats. pop ( ) ; // The last one is the full expression itself.
94
133
quote ! {
95
134
{
96
- if ( #parsed) {
135
+ #( #var_defs) *
136
+ if ( #pred_value) {
97
137
Ok ( ( ) )
98
138
} else {
99
139
let mut #error_message_ident = #error_message. to_string( ) ;
0 commit comments