@@ -16,7 +16,7 @@ use syntax::{
16
16
edit_in_place:: { AttrsOwnerEdit , Indent } ,
17
17
make, HasName ,
18
18
} ,
19
- ted, AstNode , NodeOrToken , SyntaxNode , T ,
19
+ match_ast , ted, AstNode , NodeOrToken , SyntaxNode , T ,
20
20
} ;
21
21
use text_edit:: TextRange ;
22
22
@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
40
40
// ```
41
41
// ->
42
42
// ```
43
- // fn main() {
44
- // #[derive(PartialEq, Eq)]
45
- // enum Bool { True, False }
43
+ // #[derive(PartialEq, Eq)]
44
+ // enum Bool { True, False }
46
45
//
46
+ // fn main() {
47
47
// let bool = Bool::True;
48
48
//
49
49
// if bool == Bool::True {
@@ -270,6 +270,10 @@ fn replace_usages(
270
270
}
271
271
_ => ( ) ,
272
272
}
273
+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274
+ {
275
+ edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276
+ replace_bool_expr ( edit, initializer) ;
273
277
} else if new_name. syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
274
278
// for any other usage in an expression, replace it with a check that it is the true variant
275
279
if let Some ( ( record_field, expr) ) = new_name
@@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
413
417
}
414
418
}
415
419
420
+ fn find_assoc_const_usage ( name : & ast:: NameLike ) -> Option < ( ast:: Type , ast:: Expr ) > {
421
+ let const_ = name. syntax ( ) . parent ( ) . and_then ( ast:: Const :: cast) ?;
422
+ if const_. syntax ( ) . parent ( ) . and_then ( ast:: AssocItemList :: cast) . is_none ( ) {
423
+ return None ;
424
+ }
425
+
426
+ Some ( ( const_. ty ( ) ?, const_. body ( ) ?) )
427
+ }
428
+
416
429
/// Adds the definition of the new enum before the target node.
417
430
fn add_enum_def (
418
431
edit : & mut SourceChangeBuilder ,
@@ -430,18 +443,48 @@ fn add_enum_def(
430
443
. any ( |module| module. nearest_non_block_module ( ctx. db ( ) ) != * target_module) ;
431
444
let enum_def = make_bool_enum ( make_enum_pub) ;
432
445
433
- let indent = IndentLevel :: from_node ( & target_node) ;
446
+ let insert_before = node_to_insert_before ( target_node) ;
447
+ let indent = IndentLevel :: from_node ( & insert_before) ;
434
448
enum_def. reindent_to ( indent) ;
435
449
436
450
ted:: insert_all (
437
- ted:: Position :: before ( & edit. make_syntax_mut ( target_node ) ) ,
451
+ ted:: Position :: before ( & edit. make_syntax_mut ( insert_before ) ) ,
438
452
vec ! [
439
453
enum_def. syntax( ) . clone( ) . into( ) ,
440
454
make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
441
455
] ,
442
456
) ;
443
457
}
444
458
459
+ /// Finds where to put the new enum definition, at the nearest module or at top-level.
460
+ fn node_to_insert_before ( mut target_node : SyntaxNode ) -> SyntaxNode {
461
+ let mut ancestors = target_node. ancestors ( ) ;
462
+
463
+ while let Some ( ancestor) = ancestors. next ( ) {
464
+ match_ast ! {
465
+ match ancestor {
466
+ ast:: Item ( item) => {
467
+ if item
468
+ . syntax( )
469
+ . parent( )
470
+ . and_then( |item_list| item_list. parent( ) )
471
+ . and_then( ast:: Module :: cast)
472
+ . is_some( )
473
+ {
474
+ return ancestor;
475
+ }
476
+ } ,
477
+ ast:: SourceFile ( _) => break ,
478
+ _ => ( ) ,
479
+ }
480
+ }
481
+
482
+ target_node = ancestor;
483
+ }
484
+
485
+ target_node
486
+ }
487
+
445
488
fn make_bool_enum ( make_pub : bool ) -> ast:: Enum {
446
489
let enum_def = make:: enum_ (
447
490
if make_pub { Some ( make:: visibility_pub ( ) ) } else { None } ,
@@ -491,10 +534,10 @@ fn main() {
491
534
}
492
535
"# ,
493
536
r#"
494
- fn main() {
495
- #[derive(PartialEq, Eq)]
496
- enum Bool { True, False }
537
+ #[derive(PartialEq, Eq)]
538
+ enum Bool { True, False }
497
539
540
+ fn main() {
498
541
let foo = Bool::True;
499
542
500
543
if foo == Bool::True {
@@ -520,10 +563,10 @@ fn main() {
520
563
}
521
564
"# ,
522
565
r#"
523
- fn main() {
524
- #[derive(PartialEq, Eq)]
525
- enum Bool { True, False }
566
+ #[derive(PartialEq, Eq)]
567
+ enum Bool { True, False }
526
568
569
+ fn main() {
527
570
let foo = Bool::True;
528
571
529
572
if foo == Bool::False {
@@ -545,10 +588,10 @@ fn main() {
545
588
}
546
589
"# ,
547
590
r#"
548
- fn main() {
549
- #[derive(PartialEq, Eq)]
550
- enum Bool { True, False }
591
+ #[derive(PartialEq, Eq)]
592
+ enum Bool { True, False }
551
593
594
+ fn main() {
552
595
let foo: Bool = Bool::False;
553
596
}
554
597
"# ,
@@ -565,10 +608,10 @@ fn main() {
565
608
}
566
609
"# ,
567
610
r#"
568
- fn main() {
569
- #[derive(PartialEq, Eq)]
570
- enum Bool { True, False }
611
+ #[derive(PartialEq, Eq)]
612
+ enum Bool { True, False }
571
613
614
+ fn main() {
572
615
let foo = if 1 == 2 { Bool::True } else { Bool::False };
573
616
}
574
617
"# ,
@@ -590,10 +633,10 @@ fn main() {
590
633
}
591
634
"# ,
592
635
r#"
593
- fn main() {
594
- #[derive(PartialEq, Eq)]
595
- enum Bool { True, False }
636
+ #[derive(PartialEq, Eq)]
637
+ enum Bool { True, False }
596
638
639
+ fn main() {
597
640
let foo = Bool::False;
598
641
let bar = true;
599
642
@@ -619,10 +662,10 @@ fn main() {
619
662
}
620
663
"# ,
621
664
r#"
622
- fn main() {
623
- #[derive(PartialEq, Eq)]
624
- enum Bool { True, False }
665
+ #[derive(PartialEq, Eq)]
666
+ enum Bool { True, False }
625
667
668
+ fn main() {
626
669
let foo = Bool::True;
627
670
628
671
if *&foo == Bool::True {
@@ -645,10 +688,10 @@ fn main() {
645
688
}
646
689
"# ,
647
690
r#"
648
- fn main() {
649
- #[derive(PartialEq, Eq)]
650
- enum Bool { True, False }
691
+ #[derive(PartialEq, Eq)]
692
+ enum Bool { True, False }
651
693
694
+ fn main() {
652
695
let foo: Bool;
653
696
foo = Bool::True;
654
697
}
@@ -671,10 +714,10 @@ fn main() {
671
714
}
672
715
"# ,
673
716
r#"
674
- fn main() {
675
- #[derive(PartialEq, Eq)]
676
- enum Bool { True, False }
717
+ #[derive(PartialEq, Eq)]
718
+ enum Bool { True, False }
677
719
720
+ fn main() {
678
721
let foo = Bool::True;
679
722
let bar = foo == Bool::False;
680
723
@@ -702,11 +745,11 @@ fn main() {
702
745
}
703
746
"# ,
704
747
r#"
748
+ #[derive(PartialEq, Eq)]
749
+ enum Bool { True, False }
750
+
705
751
fn main() {
706
752
if !"foo".chars().any(|c| {
707
- #[derive(PartialEq, Eq)]
708
- enum Bool { True, False }
709
-
710
753
let foo = Bool::True;
711
754
foo == Bool::True
712
755
}) {
@@ -1445,6 +1488,90 @@ pub mod bar {
1445
1488
)
1446
1489
}
1447
1490
1491
+ #[ test]
1492
+ fn const_in_impl_cross_file ( ) {
1493
+ check_assist (
1494
+ bool_to_enum,
1495
+ r#"
1496
+ //- /main.rs
1497
+ mod foo;
1498
+
1499
+ struct Foo;
1500
+
1501
+ impl Foo {
1502
+ pub const $0BOOL: bool = true;
1503
+ }
1504
+
1505
+ //- /foo.rs
1506
+ use crate::Foo;
1507
+
1508
+ fn foo() -> bool {
1509
+ Foo::BOOL
1510
+ }
1511
+ "# ,
1512
+ r#"
1513
+ //- /main.rs
1514
+ mod foo;
1515
+
1516
+ struct Foo;
1517
+
1518
+ #[derive(PartialEq, Eq)]
1519
+ pub enum Bool { True, False }
1520
+
1521
+ impl Foo {
1522
+ pub const BOOL: Bool = Bool::True;
1523
+ }
1524
+
1525
+ //- /foo.rs
1526
+ use crate::{Foo, Bool};
1527
+
1528
+ fn foo() -> bool {
1529
+ Foo::BOOL == Bool::True
1530
+ }
1531
+ "# ,
1532
+ )
1533
+ }
1534
+
1535
+ #[ test]
1536
+ fn const_in_trait ( ) {
1537
+ check_assist (
1538
+ bool_to_enum,
1539
+ r#"
1540
+ trait Foo {
1541
+ const $0BOOL: bool;
1542
+ }
1543
+
1544
+ impl Foo for usize {
1545
+ const BOOL: bool = true;
1546
+ }
1547
+
1548
+ fn main() {
1549
+ if <usize as Foo>::BOOL {
1550
+ println!("foo");
1551
+ }
1552
+ }
1553
+ "# ,
1554
+ r#"
1555
+ #[derive(PartialEq, Eq)]
1556
+ enum Bool { True, False }
1557
+
1558
+ trait Foo {
1559
+ const BOOL: Bool;
1560
+ }
1561
+
1562
+ impl Foo for usize {
1563
+ const BOOL: Bool = Bool::True;
1564
+ }
1565
+
1566
+ fn main() {
1567
+ if <usize as Foo>::BOOL == Bool::True {
1568
+ println!("foo");
1569
+ }
1570
+ }
1571
+ "# ,
1572
+ )
1573
+ }
1574
+
1448
1575
#[ test]
1449
1576
fn const_non_bool ( ) {
1450
1577
cov_mark:: check!( not_applicable_non_bool_const) ;
0 commit comments