Skip to content

Commit bce4be9

Browse files
committed
fix: make bool_to_enum assist create enum at top-level
1 parent d3cc3bc commit bce4be9

File tree

2 files changed

+163
-36
lines changed

2 files changed

+163
-36
lines changed

crates/ide-assists/src/handlers/bool_to_enum.rs

Lines changed: 160 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use syntax::{
1616
edit_in_place::{AttrsOwnerEdit, Indent},
1717
make, HasName,
1818
},
19-
ted, AstNode, NodeOrToken, SyntaxNode, T,
19+
match_ast, ted, AstNode, NodeOrToken, SyntaxNode, T,
2020
};
2121
use text_edit::TextRange;
2222

@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
4040
// ```
4141
// ->
4242
// ```
43-
// fn main() {
44-
// #[derive(PartialEq, Eq)]
45-
// enum Bool { True, False }
43+
// #[derive(PartialEq, Eq)]
44+
// enum Bool { True, False }
4645
//
46+
// fn main() {
4747
// let bool = Bool::True;
4848
//
4949
// if bool == Bool::True {
@@ -270,6 +270,10 @@ fn replace_usages(
270270
}
271271
_ => (),
272272
}
273+
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
274+
{
275+
edit.replace(ty_annotation.syntax().text_range(), "Bool");
276+
replace_bool_expr(edit, initializer);
273277
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
274278
// for any other usage in an expression, replace it with a check that it is the true variant
275279
if let Some((record_field, expr)) = new_name
@@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
413417
}
414418
}
415419

420+
fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
421+
let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
422+
if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
423+
return None;
424+
}
425+
426+
Some((const_.ty()?, const_.body()?))
427+
}
428+
416429
/// Adds the definition of the new enum before the target node.
417430
fn add_enum_def(
418431
edit: &mut SourceChangeBuilder,
@@ -430,18 +443,48 @@ fn add_enum_def(
430443
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
431444
let enum_def = make_bool_enum(make_enum_pub);
432445

433-
let indent = IndentLevel::from_node(&target_node);
446+
let insert_before = node_to_insert_before(target_node);
447+
let indent = IndentLevel::from_node(&insert_before);
434448
enum_def.reindent_to(indent);
435449

436450
ted::insert_all(
437-
ted::Position::before(&edit.make_syntax_mut(target_node)),
451+
ted::Position::before(&edit.make_syntax_mut(insert_before)),
438452
vec![
439453
enum_def.syntax().clone().into(),
440454
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
441455
],
442456
);
443457
}
444458

459+
/// Finds where to put the new enum definition, at the nearest module or at top-level.
460+
fn node_to_insert_before(mut target_node: SyntaxNode) -> SyntaxNode {
461+
let mut ancestors = target_node.ancestors();
462+
463+
while let Some(ancestor) = ancestors.next() {
464+
match_ast! {
465+
match ancestor {
466+
ast::Item(item) => {
467+
if item
468+
.syntax()
469+
.parent()
470+
.and_then(|item_list| item_list.parent())
471+
.and_then(ast::Module::cast)
472+
.is_some()
473+
{
474+
return ancestor;
475+
}
476+
},
477+
ast::SourceFile(_) => break,
478+
_ => (),
479+
}
480+
}
481+
482+
target_node = ancestor;
483+
}
484+
485+
target_node
486+
}
487+
445488
fn make_bool_enum(make_pub: bool) -> ast::Enum {
446489
let enum_def = make::enum_(
447490
if make_pub { Some(make::visibility_pub()) } else { None },
@@ -491,10 +534,10 @@ fn main() {
491534
}
492535
"#,
493536
r#"
494-
fn main() {
495-
#[derive(PartialEq, Eq)]
496-
enum Bool { True, False }
537+
#[derive(PartialEq, Eq)]
538+
enum Bool { True, False }
497539
540+
fn main() {
498541
let foo = Bool::True;
499542
500543
if foo == Bool::True {
@@ -520,10 +563,10 @@ fn main() {
520563
}
521564
"#,
522565
r#"
523-
fn main() {
524-
#[derive(PartialEq, Eq)]
525-
enum Bool { True, False }
566+
#[derive(PartialEq, Eq)]
567+
enum Bool { True, False }
526568
569+
fn main() {
527570
let foo = Bool::True;
528571
529572
if foo == Bool::False {
@@ -545,10 +588,10 @@ fn main() {
545588
}
546589
"#,
547590
r#"
548-
fn main() {
549-
#[derive(PartialEq, Eq)]
550-
enum Bool { True, False }
591+
#[derive(PartialEq, Eq)]
592+
enum Bool { True, False }
551593
594+
fn main() {
552595
let foo: Bool = Bool::False;
553596
}
554597
"#,
@@ -565,10 +608,10 @@ fn main() {
565608
}
566609
"#,
567610
r#"
568-
fn main() {
569-
#[derive(PartialEq, Eq)]
570-
enum Bool { True, False }
611+
#[derive(PartialEq, Eq)]
612+
enum Bool { True, False }
571613
614+
fn main() {
572615
let foo = if 1 == 2 { Bool::True } else { Bool::False };
573616
}
574617
"#,
@@ -590,10 +633,10 @@ fn main() {
590633
}
591634
"#,
592635
r#"
593-
fn main() {
594-
#[derive(PartialEq, Eq)]
595-
enum Bool { True, False }
636+
#[derive(PartialEq, Eq)]
637+
enum Bool { True, False }
596638
639+
fn main() {
597640
let foo = Bool::False;
598641
let bar = true;
599642
@@ -619,10 +662,10 @@ fn main() {
619662
}
620663
"#,
621664
r#"
622-
fn main() {
623-
#[derive(PartialEq, Eq)]
624-
enum Bool { True, False }
665+
#[derive(PartialEq, Eq)]
666+
enum Bool { True, False }
625667
668+
fn main() {
626669
let foo = Bool::True;
627670
628671
if *&foo == Bool::True {
@@ -645,10 +688,10 @@ fn main() {
645688
}
646689
"#,
647690
r#"
648-
fn main() {
649-
#[derive(PartialEq, Eq)]
650-
enum Bool { True, False }
691+
#[derive(PartialEq, Eq)]
692+
enum Bool { True, False }
651693
694+
fn main() {
652695
let foo: Bool;
653696
foo = Bool::True;
654697
}
@@ -671,10 +714,10 @@ fn main() {
671714
}
672715
"#,
673716
r#"
674-
fn main() {
675-
#[derive(PartialEq, Eq)]
676-
enum Bool { True, False }
717+
#[derive(PartialEq, Eq)]
718+
enum Bool { True, False }
677719
720+
fn main() {
678721
let foo = Bool::True;
679722
let bar = foo == Bool::False;
680723
@@ -702,11 +745,11 @@ fn main() {
702745
}
703746
"#,
704747
r#"
748+
#[derive(PartialEq, Eq)]
749+
enum Bool { True, False }
750+
705751
fn main() {
706752
if !"foo".chars().any(|c| {
707-
#[derive(PartialEq, Eq)]
708-
enum Bool { True, False }
709-
710753
let foo = Bool::True;
711754
foo == Bool::True
712755
}) {
@@ -1445,6 +1488,90 @@ pub mod bar {
14451488
)
14461489
}
14471490

1491+
#[test]
1492+
fn const_in_impl_cross_file() {
1493+
check_assist(
1494+
bool_to_enum,
1495+
r#"
1496+
//- /main.rs
1497+
mod foo;
1498+
1499+
struct Foo;
1500+
1501+
impl Foo {
1502+
pub const $0BOOL: bool = true;
1503+
}
1504+
1505+
//- /foo.rs
1506+
use crate::Foo;
1507+
1508+
fn foo() -> bool {
1509+
Foo::BOOL
1510+
}
1511+
"#,
1512+
r#"
1513+
//- /main.rs
1514+
mod foo;
1515+
1516+
struct Foo;
1517+
1518+
#[derive(PartialEq, Eq)]
1519+
pub enum Bool { True, False }
1520+
1521+
impl Foo {
1522+
pub const BOOL: Bool = Bool::True;
1523+
}
1524+
1525+
//- /foo.rs
1526+
use crate::{Foo, Bool};
1527+
1528+
fn foo() -> bool {
1529+
Foo::BOOL == Bool::True
1530+
}
1531+
"#,
1532+
)
1533+
}
1534+
1535+
#[test]
1536+
fn const_in_trait() {
1537+
check_assist(
1538+
bool_to_enum,
1539+
r#"
1540+
trait Foo {
1541+
const $0BOOL: bool;
1542+
}
1543+
1544+
impl Foo for usize {
1545+
const BOOL: bool = true;
1546+
}
1547+
1548+
fn main() {
1549+
if <usize as Foo>::BOOL {
1550+
println!("foo");
1551+
}
1552+
}
1553+
"#,
1554+
r#"
1555+
#[derive(PartialEq, Eq)]
1556+
enum Bool { True, False }
1557+
1558+
trait Foo {
1559+
const BOOL: Bool;
1560+
}
1561+
1562+
impl Foo for usize {
1563+
const BOOL: Bool = Bool::True;
1564+
}
1565+
1566+
fn main() {
1567+
if <usize as Foo>::BOOL == Bool::True {
1568+
println!("foo");
1569+
}
1570+
}
1571+
"#,
1572+
)
1573+
}
1574+
14481575
#[test]
14491576
fn const_non_bool() {
14501577
cov_mark::check!(not_applicable_non_bool_const);

crates/ide-assists/src/tests/generated.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,10 @@ fn main() {
294294
}
295295
"#####,
296296
r#####"
297-
fn main() {
298-
#[derive(PartialEq, Eq)]
299-
enum Bool { True, False }
297+
#[derive(PartialEq, Eq)]
298+
enum Bool { True, False }
300299
300+
fn main() {
301301
let bool = Bool::True;
302302
303303
if bool == Bool::True {

0 commit comments

Comments
 (0)