Skip to content

Commit ba01c0a

Browse files
committed
Auto merge of #17467 - winstxnhdw:bool-to-enum, r=Veykril
feat: add bool_to_enum assist for parameters ## Summary This PR adds parameter support for `bool_to_enum` assists. Essentially, the assist can now transform this: ```rs fn function($0foo: bool) { if foo { println!("foo"); } } ``` To this, ```rs #[derive(PartialEq, Eq)] enum Bool { True, False } fn function(foo: Bool) { if foo == Bool::True { println!("foo"); } } ``` Thanks to `@/davidbarsky` for the test skeleton (: Closes #17400
2 parents 079ee28 + a456692 commit ba01c0a

File tree

1 file changed

+112
-17
lines changed

1 file changed

+112
-17
lines changed

src/tools/rust-analyzer/crates/ide-assists/src/handlers/bool_to_enum.rs

Lines changed: 112 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use either::Either;
12
use hir::{ImportPathConfig, ModuleDef};
23
use ide_db::{
34
assists::{AssistId, AssistKind},
@@ -97,27 +98,30 @@ struct BoolNodeData {
9798
fn find_bool_node(ctx: &AssistContext<'_>) -> Option<BoolNodeData> {
9899
let name: ast::Name = ctx.find_node_at_offset()?;
99100

100-
if let Some(let_stmt) = name.syntax().ancestors().find_map(ast::LetStmt::cast) {
101-
let bind_pat = match let_stmt.pat()? {
102-
ast::Pat::IdentPat(pat) => pat,
103-
_ => {
104-
cov_mark::hit!(not_applicable_in_non_ident_pat);
105-
return None;
106-
}
107-
};
108-
let def = ctx.sema.to_def(&bind_pat)?;
101+
if let Some(ident_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) {
102+
let def = ctx.sema.to_def(&ident_pat)?;
109103
if !def.ty(ctx.db()).is_bool() {
110104
cov_mark::hit!(not_applicable_non_bool_local);
111105
return None;
112106
}
113107

114-
Some(BoolNodeData {
115-
target_node: let_stmt.syntax().clone(),
116-
name,
117-
ty_annotation: let_stmt.ty(),
118-
initializer: let_stmt.initializer(),
119-
definition: Definition::Local(def),
120-
})
108+
let local_definition = Definition::Local(def);
109+
match ident_pat.syntax().parent().and_then(Either::<ast::Param, ast::LetStmt>::cast)? {
110+
Either::Left(param) => Some(BoolNodeData {
111+
target_node: param.syntax().clone(),
112+
name,
113+
ty_annotation: param.ty(),
114+
initializer: None,
115+
definition: local_definition,
116+
}),
117+
Either::Right(let_stmt) => Some(BoolNodeData {
118+
target_node: let_stmt.syntax().clone(),
119+
name,
120+
ty_annotation: let_stmt.ty(),
121+
initializer: let_stmt.initializer(),
122+
definition: local_definition,
123+
}),
124+
}
121125
} else if let Some(const_) = name.syntax().parent().and_then(ast::Const::cast) {
122126
let def = ctx.sema.to_def(&const_)?;
123127
if !def.ty(ctx.db()).is_bool() {
@@ -524,6 +528,98 @@ mod tests {
524528

525529
use crate::tests::{check_assist, check_assist_not_applicable};
526530

531+
#[test]
532+
fn parameter_with_first_param_usage() {
533+
check_assist(
534+
bool_to_enum,
535+
r#"
536+
fn function($0foo: bool, bar: bool) {
537+
if foo {
538+
println!("foo");
539+
}
540+
}
541+
"#,
542+
r#"
543+
#[derive(PartialEq, Eq)]
544+
enum Bool { True, False }
545+
546+
fn function(foo: Bool, bar: bool) {
547+
if foo == Bool::True {
548+
println!("foo");
549+
}
550+
}
551+
"#,
552+
)
553+
}
554+
555+
#[test]
556+
fn parameter_with_last_param_usage() {
557+
check_assist(
558+
bool_to_enum,
559+
r#"
560+
fn function(foo: bool, $0bar: bool) {
561+
if bar {
562+
println!("bar");
563+
}
564+
}
565+
"#,
566+
r#"
567+
#[derive(PartialEq, Eq)]
568+
enum Bool { True, False }
569+
570+
fn function(foo: bool, bar: Bool) {
571+
if bar == Bool::True {
572+
println!("bar");
573+
}
574+
}
575+
"#,
576+
)
577+
}
578+
579+
#[test]
580+
fn parameter_with_middle_param_usage() {
581+
check_assist(
582+
bool_to_enum,
583+
r#"
584+
fn function(foo: bool, $0bar: bool, baz: bool) {
585+
if bar {
586+
println!("bar");
587+
}
588+
}
589+
"#,
590+
r#"
591+
#[derive(PartialEq, Eq)]
592+
enum Bool { True, False }
593+
594+
fn function(foo: bool, bar: Bool, baz: bool) {
595+
if bar == Bool::True {
596+
println!("bar");
597+
}
598+
}
599+
"#,
600+
)
601+
}
602+
603+
#[test]
604+
fn parameter_with_closure_usage() {
605+
check_assist(
606+
bool_to_enum,
607+
r#"
608+
fn main() {
609+
let foo = |$0bar: bool| bar;
610+
}
611+
"#,
612+
r#"
613+
#[derive(PartialEq, Eq)]
614+
enum Bool { True, False }
615+
616+
fn main() {
617+
let foo = |bar: Bool| bar == Bool::True;
618+
}
619+
"#,
620+
)
621+
}
622+
527623
#[test]
528624
fn local_variable_with_usage() {
529625
check_assist(
@@ -791,7 +887,6 @@ fn main() {
791887

792888
#[test]
793889
fn local_variable_non_ident_pat() {
794-
cov_mark::check!(not_applicable_in_non_ident_pat);
795890
check_assist_not_applicable(
796891
bool_to_enum,
797892
r#"

0 commit comments

Comments
 (0)