Skip to content

Commit 044d99e

Browse files
bors[bot]Veykril
andauthored
Merge #9816
9816: feat: Implement if_to_bool_then assist r=Veykril a=Veykril One half of #8413 Co-authored-by: Lukas Wirth <lukastw97@gmail.com>
2 parents bc084a6 + 3b7c713 commit 044d99e

File tree

5 files changed

+386
-5
lines changed

5 files changed

+386
-5
lines changed

crates/hir_expand/src/name.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ pub mod known {
200200
Range,
201201
Neg,
202202
Not,
203+
None,
203204
Index,
204205
// Components of known path (function name)
205206
filter_map,
Lines changed: 352 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,352 @@
1+
use hir::{known, Semantics};
2+
use ide_db::{
3+
helpers::{for_each_tail_expr, FamousDefs},
4+
RootDatabase,
5+
};
6+
use syntax::{
7+
ast::{self, make, ArgListOwner},
8+
ted, AstNode, SyntaxNode,
9+
};
10+
11+
use crate::{
12+
utils::{invert_boolean_expression, unwrap_trivial_block},
13+
AssistContext, AssistId, AssistKind, Assists,
14+
};
15+
16+
// Assist: convert_if_to_bool_then
17+
//
18+
// Converts an if expression into a corresponding `bool::then` call.
19+
//
20+
// ```
21+
// # //- minicore: option
22+
// fn main() {
23+
// if$0 cond {
24+
// Some(val)
25+
// } else {
26+
// None
27+
// }
28+
// }
29+
// ```
30+
// ->
31+
// ```
32+
// fn main() {
33+
// cond.then(|| val)
34+
// }
35+
// ```
36+
pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
37+
// todo, applies to match as well
38+
let expr = ctx.find_node_at_offset::<ast::IfExpr>()?;
39+
if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) {
40+
return None;
41+
}
42+
43+
let cond = expr.condition().filter(|cond| !cond.is_pattern_cond())?;
44+
let cond = cond.expr()?;
45+
let then = expr.then_branch()?;
46+
let else_ = match expr.else_branch()? {
47+
ast::ElseBranch::Block(b) => b,
48+
ast::ElseBranch::IfExpr(_) => {
49+
cov_mark::hit!(convert_if_to_bool_then_chain);
50+
return None;
51+
}
52+
};
53+
54+
let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?;
55+
56+
let (invert_cond, closure_body) = match (
57+
block_is_none_variant(&ctx.sema, &then, none_variant),
58+
block_is_none_variant(&ctx.sema, &else_, none_variant),
59+
) {
60+
(invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)),
61+
(invert @ false, true) => (invert, ast::Expr::BlockExpr(then)),
62+
_ => return None,
63+
};
64+
65+
if is_invalid_body(&ctx.sema, some_variant, &closure_body) {
66+
cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body);
67+
return None;
68+
}
69+
70+
let target = expr.syntax().text_range();
71+
acc.add(
72+
AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite),
73+
"Convert `if` expression to `bool::then` call",
74+
target,
75+
|builder| {
76+
let closure_body = closure_body.clone_for_update();
77+
// Rewrite all `Some(e)` in tail position to `e`
78+
for_each_tail_expr(&closure_body, &mut |e| {
79+
let e = match e {
80+
ast::Expr::BreakExpr(e) => e.expr(),
81+
e @ ast::Expr::CallExpr(_) => Some(e.clone()),
82+
_ => None,
83+
};
84+
if let Some(ast::Expr::CallExpr(call)) = e {
85+
if let Some(arg_list) = call.arg_list() {
86+
if let Some(arg) = arg_list.args().next() {
87+
ted::replace(call.syntax(), arg.syntax());
88+
}
89+
}
90+
}
91+
});
92+
let closure_body = match closure_body {
93+
ast::Expr::BlockExpr(block) => unwrap_trivial_block(block),
94+
e => e,
95+
};
96+
97+
let cond = if invert_cond { invert_boolean_expression(&ctx.sema, cond) } else { cond };
98+
let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body)));
99+
let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list);
100+
builder.replace(target, mcall.to_string());
101+
},
102+
)
103+
}
104+
105+
fn option_variants(
106+
sema: &Semantics<RootDatabase>,
107+
expr: &SyntaxNode,
108+
) -> Option<(hir::Variant, hir::Variant)> {
109+
let fam = FamousDefs(&sema, sema.scope(expr).krate());
110+
let option_variants = fam.core_option_Option()?.variants(sema.db);
111+
match &*option_variants {
112+
&[variant0, variant1] => Some(if variant0.name(sema.db) == known::None {
113+
(variant0, variant1)
114+
} else {
115+
(variant1, variant0)
116+
}),
117+
_ => None,
118+
}
119+
}
120+
121+
/// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression.
122+
/// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call.
123+
fn is_invalid_body(
124+
sema: &Semantics<RootDatabase>,
125+
some_variant: hir::Variant,
126+
expr: &ast::Expr,
127+
) -> bool {
128+
let mut invalid = false;
129+
expr.preorder(&mut |e| {
130+
invalid |=
131+
matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_)));
132+
invalid
133+
});
134+
if !invalid {
135+
for_each_tail_expr(&expr, &mut |e| {
136+
if invalid {
137+
return;
138+
}
139+
let e = match e {
140+
ast::Expr::BreakExpr(e) => e.expr(),
141+
e @ ast::Expr::CallExpr(_) => Some(e.clone()),
142+
_ => None,
143+
};
144+
if let Some(ast::Expr::CallExpr(call)) = e {
145+
if let Some(ast::Expr::PathExpr(p)) = call.expr() {
146+
let res = p.path().and_then(|p| sema.resolve_path(&p));
147+
if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res {
148+
return invalid |= v != some_variant;
149+
}
150+
}
151+
}
152+
invalid = true
153+
});
154+
}
155+
invalid
156+
}
157+
158+
fn block_is_none_variant(
159+
sema: &Semantics<RootDatabase>,
160+
block: &ast::BlockExpr,
161+
none_variant: hir::Variant,
162+
) -> bool {
163+
block.as_lone_tail().and_then(|e| match e {
164+
ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? {
165+
hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v),
166+
_ => None,
167+
},
168+
_ => None,
169+
}) == Some(none_variant)
170+
}
171+
172+
#[cfg(test)]
173+
mod tests {
174+
use crate::tests::{check_assist, check_assist_not_applicable};
175+
176+
use super::*;
177+
178+
#[test]
179+
fn convert_if_to_bool_then_simple() {
180+
check_assist(
181+
convert_if_to_bool_then,
182+
r"
183+
//- minicore:option
184+
fn main() {
185+
if$0 true {
186+
Some(15)
187+
} else {
188+
None
189+
}
190+
}
191+
",
192+
r"
193+
fn main() {
194+
true.then(|| 15)
195+
}
196+
",
197+
);
198+
}
199+
200+
#[test]
201+
fn convert_if_to_bool_then_invert() {
202+
check_assist(
203+
convert_if_to_bool_then,
204+
r"
205+
//- minicore:option
206+
fn main() {
207+
if$0 true {
208+
None
209+
} else {
210+
Some(15)
211+
}
212+
}
213+
",
214+
r"
215+
fn main() {
216+
false.then(|| 15)
217+
}
218+
",
219+
);
220+
}
221+
222+
#[test]
223+
fn convert_if_to_bool_then_none_none() {
224+
check_assist_not_applicable(
225+
convert_if_to_bool_then,
226+
r"
227+
//- minicore:option
228+
fn main() {
229+
if$0 true {
230+
None
231+
} else {
232+
None
233+
}
234+
}
235+
",
236+
);
237+
}
238+
239+
#[test]
240+
fn convert_if_to_bool_then_some_some() {
241+
check_assist_not_applicable(
242+
convert_if_to_bool_then,
243+
r"
244+
//- minicore:option
245+
fn main() {
246+
if$0 true {
247+
Some(15)
248+
} else {
249+
Some(15)
250+
}
251+
}
252+
",
253+
);
254+
}
255+
256+
#[test]
257+
fn convert_if_to_bool_then_mixed() {
258+
check_assist_not_applicable(
259+
convert_if_to_bool_then,
260+
r"
261+
//- minicore:option
262+
fn main() {
263+
if$0 true {
264+
if true {
265+
Some(15)
266+
} else {
267+
None
268+
}
269+
} else {
270+
None
271+
}
272+
}
273+
",
274+
);
275+
}
276+
277+
#[test]
278+
fn convert_if_to_bool_then_chain() {
279+
cov_mark::check!(convert_if_to_bool_then_chain);
280+
check_assist_not_applicable(
281+
convert_if_to_bool_then,
282+
r"
283+
//- minicore:option
284+
fn main() {
285+
if$0 true {
286+
Some(15)
287+
} else if true {
288+
None
289+
} else {
290+
None
291+
}
292+
}
293+
",
294+
);
295+
}
296+
297+
#[test]
298+
fn convert_if_to_bool_then_pattern_cond() {
299+
check_assist_not_applicable(
300+
convert_if_to_bool_then,
301+
r"
302+
//- minicore:option
303+
fn main() {
304+
if$0 let true = true {
305+
Some(15)
306+
} else {
307+
None
308+
}
309+
}
310+
",
311+
);
312+
}
313+
314+
#[test]
315+
fn convert_if_to_bool_then_pattern_invalid_body() {
316+
cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2);
317+
check_assist_not_applicable(
318+
convert_if_to_bool_then,
319+
r"
320+
//- minicore:option
321+
fn make_me_an_option() -> Option<i32> { None }
322+
fn main() {
323+
if$0 true {
324+
if true {
325+
make_me_an_option()
326+
} else {
327+
Some(15)
328+
}
329+
} else {
330+
None
331+
}
332+
}
333+
",
334+
);
335+
check_assist_not_applicable(
336+
convert_if_to_bool_then,
337+
r"
338+
//- minicore:option
339+
fn main() {
340+
if$0 true {
341+
if true {
342+
return;
343+
}
344+
Some(15)
345+
} else {
346+
None
347+
}
348+
}
349+
",
350+
);
351+
}
352+
}

0 commit comments

Comments
 (0)