@@ -16,11 +16,14 @@ use syntax::{
16
16
edit_in_place:: { AttrsOwnerEdit , Indent } ,
17
17
make, HasName ,
18
18
} ,
19
- ted , AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
19
+ AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
20
20
} ;
21
21
use text_edit:: TextRange ;
22
22
23
- use crate :: assist_context:: { AssistContext , Assists } ;
23
+ use crate :: {
24
+ assist_context:: { AssistContext , Assists } ,
25
+ utils,
26
+ } ;
24
27
25
28
// Assist: bool_to_enum
26
29
//
@@ -73,7 +76,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
73
76
74
77
let usages = definition. usages ( & ctx. sema ) . all ( ) ;
75
78
add_enum_def ( edit, ctx, & usages, target_node, & target_module) ;
76
- replace_usages ( edit, ctx, & usages, definition, & target_module) ;
79
+ replace_usages ( edit, ctx, usages, definition, & target_module) ;
77
80
} ,
78
81
)
79
82
}
@@ -169,8 +172,8 @@ fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {
169
172
170
173
/// Converts an expression of type `bool` to one of the new enum type.
171
174
fn bool_expr_to_enum_expr ( expr : ast:: Expr ) -> ast:: Expr {
172
- let true_expr = make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) . clone_for_update ( ) ;
173
- let false_expr = make:: expr_path ( make:: path_from_text ( "Bool::False" ) ) . clone_for_update ( ) ;
175
+ let true_expr = make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) ;
176
+ let false_expr = make:: expr_path ( make:: path_from_text ( "Bool::False" ) ) ;
174
177
175
178
if let ast:: Expr :: Literal ( literal) = & expr {
176
179
match literal. kind ( ) {
@@ -184,66 +187,62 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
184
187
make:: tail_only_block_expr ( true_expr) ,
185
188
Some ( ast:: ElseBranch :: Block ( make:: tail_only_block_expr ( false_expr) ) ) ,
186
189
)
187
- . clone_for_update ( )
188
190
}
189
191
}
190
192
191
193
/// Replaces all usages of the target identifier, both when read and written to.
192
194
fn replace_usages (
193
195
edit : & mut SourceChangeBuilder ,
194
196
ctx : & AssistContext < ' _ > ,
195
- usages : & UsageSearchResult ,
197
+ usages : UsageSearchResult ,
196
198
target_definition : Definition ,
197
199
target_module : & hir:: Module ,
198
200
) {
199
- for ( file_id, references) in usages. iter ( ) {
200
- edit. edit_file ( * file_id) ;
201
+ for ( file_id, references) in usages {
202
+ edit. edit_file ( file_id) ;
201
203
202
- let refs_with_imports =
203
- augment_references_with_imports ( edit, ctx, references, target_module) ;
204
+ let refs_with_imports = augment_references_with_imports ( ctx, references, target_module) ;
204
205
205
206
refs_with_imports. into_iter ( ) . rev ( ) . for_each (
206
- |FileReferenceWithImport { range, old_name , new_name , import_data } | {
207
+ |FileReferenceWithImport { range, name , import_data } | {
207
208
// replace the usages in patterns and expressions
208
- if let Some ( ident_pat) = old_name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast)
209
- {
209
+ if let Some ( ident_pat) = name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast) {
210
210
cov_mark:: hit!( replaces_record_pat_shorthand) ;
211
211
212
212
let definition = ctx. sema . to_def ( & ident_pat) . map ( Definition :: Local ) ;
213
213
if let Some ( def) = definition {
214
214
replace_usages (
215
215
edit,
216
216
ctx,
217
- & def. usages ( & ctx. sema ) . all ( ) ,
217
+ def. usages ( & ctx. sema ) . all ( ) ,
218
218
target_definition,
219
219
target_module,
220
220
)
221
221
}
222
- } else if let Some ( initializer) = find_assignment_usage ( & new_name ) {
222
+ } else if let Some ( initializer) = find_assignment_usage ( & name ) {
223
223
cov_mark:: hit!( replaces_assignment) ;
224
224
225
225
replace_bool_expr ( edit, initializer) ;
226
- } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & new_name ) {
226
+ } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & name ) {
227
227
cov_mark:: hit!( replaces_negation) ;
228
228
229
229
edit. replace (
230
230
prefix_expr. syntax ( ) . text_range ( ) ,
231
231
format ! ( "{} == Bool::False" , inner_expr) ,
232
232
) ;
233
- } else if let Some ( ( record_field, initializer) ) = old_name
233
+ } else if let Some ( ( record_field, initializer) ) = name
234
234
. as_name_ref ( )
235
235
. and_then ( ast:: RecordExprField :: for_field_name)
236
236
. and_then ( |record_field| ctx. sema . resolve_record_field ( & record_field) )
237
237
. and_then ( |( got_field, _, _) | {
238
- find_record_expr_usage ( & new_name , got_field, target_definition)
238
+ find_record_expr_usage ( & name , got_field, target_definition)
239
239
} )
240
240
{
241
241
cov_mark:: hit!( replaces_record_expr) ;
242
242
243
- let record_field = edit. make_mut ( record_field) ;
244
243
let enum_expr = bool_expr_to_enum_expr ( initializer) ;
245
- record_field . replace_expr ( enum_expr) ;
246
- } else if let Some ( pat) = find_record_pat_field_usage ( & old_name ) {
244
+ utils :: replace_record_field_expr ( ctx , edit , record_field , enum_expr) ;
245
+ } else if let Some ( pat) = find_record_pat_field_usage ( & name ) {
247
246
match pat {
248
247
ast:: Pat :: IdentPat ( ident_pat) => {
249
248
cov_mark:: hit!( replaces_record_pat) ;
@@ -253,7 +252,7 @@ fn replace_usages(
253
252
replace_usages (
254
253
edit,
255
254
ctx,
256
- & def. usages ( & ctx. sema ) . all ( ) ,
255
+ def. usages ( & ctx. sema ) . all ( ) ,
257
256
target_definition,
258
257
target_module,
259
258
)
@@ -270,40 +269,44 @@ fn replace_usages(
270
269
}
271
270
_ => ( ) ,
272
271
}
273
- } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274
- {
272
+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & name) {
275
273
edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276
274
replace_bool_expr ( edit, initializer) ;
277
- } else if let Some ( receiver) = find_method_call_expr_usage ( & new_name ) {
275
+ } else if let Some ( receiver) = find_method_call_expr_usage ( & name ) {
278
276
edit. replace (
279
277
receiver. syntax ( ) . text_range ( ) ,
280
278
format ! ( "({} == Bool::True)" , receiver) ,
281
279
) ;
282
- } else if new_name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
280
+ } else if name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
283
281
// for any other usage in an expression, replace it with a check that it is the true variant
284
- if let Some ( ( record_field, expr) ) = new_name
285
- . as_name_ref ( )
286
- . and_then ( ast:: RecordExprField :: for_field_name)
287
- . and_then ( |record_field| {
288
- record_field. expr ( ) . map ( |expr| ( record_field, expr) )
289
- } )
282
+ if let Some ( ( record_field, expr) ) =
283
+ name. as_name_ref ( ) . and_then ( ast:: RecordExprField :: for_field_name) . and_then (
284
+ |record_field| record_field. expr ( ) . map ( |expr| ( record_field, expr) ) ,
285
+ )
290
286
{
291
- record_field. replace_expr (
287
+ utils:: replace_record_field_expr (
288
+ ctx,
289
+ edit,
290
+ record_field,
292
291
make:: expr_bin_op (
293
292
expr,
294
293
ast:: BinaryOp :: CmpOp ( ast:: CmpOp :: Eq { negated : false } ) ,
295
294
make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) ,
296
- )
297
- . clone_for_update ( ) ,
295
+ ) ,
298
296
) ;
299
297
} else {
300
- edit. replace ( range, format ! ( "{} == Bool::True" , new_name . text( ) ) ) ;
298
+ edit. replace ( range, format ! ( "{} == Bool::True" , name . text( ) ) ) ;
301
299
}
302
300
}
303
301
304
302
// add imports across modules where needed
305
303
if let Some ( ( import_scope, path) ) = import_data {
306
- insert_use ( & import_scope, path, & ctx. config . insert_use ) ;
304
+ let scope = match import_scope. clone ( ) {
305
+ ImportScope :: File ( it) => ImportScope :: File ( edit. make_mut ( it) ) ,
306
+ ImportScope :: Module ( it) => ImportScope :: Module ( edit. make_mut ( it) ) ,
307
+ ImportScope :: Block ( it) => ImportScope :: Block ( edit. make_mut ( it) ) ,
308
+ } ;
309
+ insert_use ( & scope, path, & ctx. config . insert_use ) ;
307
310
}
308
311
} ,
309
312
)
@@ -312,37 +315,31 @@ fn replace_usages(
312
315
313
316
struct FileReferenceWithImport {
314
317
range : TextRange ,
315
- old_name : ast:: NameLike ,
316
- new_name : ast:: NameLike ,
318
+ name : ast:: NameLike ,
317
319
import_data : Option < ( ImportScope , ast:: Path ) > ,
318
320
}
319
321
320
322
fn augment_references_with_imports (
321
- edit : & mut SourceChangeBuilder ,
322
323
ctx : & AssistContext < ' _ > ,
323
- references : & [ FileReference ] ,
324
+ references : Vec < FileReference > ,
324
325
target_module : & hir:: Module ,
325
326
) -> Vec < FileReferenceWithImport > {
326
327
let mut visited_modules = FxHashSet :: default ( ) ;
327
328
328
329
references
329
- . iter ( )
330
+ . into_iter ( )
330
331
. filter_map ( |FileReference { range, name, .. } | {
331
332
let name = name. clone ( ) . into_name_like ( ) ?;
332
- ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( * range, name, scope. module ( ) ) )
333
+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( range, name, scope. module ( ) ) )
333
334
} )
334
335
. map ( |( range, name, ref_module) | {
335
- let old_name = name. clone ( ) ;
336
- let new_name = edit. make_mut ( name. clone ( ) ) ;
337
-
338
336
// if the referenced module is not the same as the target one and has not been seen before, add an import
339
337
let import_data = if ref_module. nearest_non_block_module ( ctx. db ( ) ) != * target_module
340
338
&& !visited_modules. contains ( & ref_module)
341
339
{
342
340
visited_modules. insert ( ref_module) ;
343
341
344
- let import_scope =
345
- ImportScope :: find_insert_use_container ( new_name. syntax ( ) , & ctx. sema ) ;
342
+ let import_scope = ImportScope :: find_insert_use_container ( name. syntax ( ) , & ctx. sema ) ;
346
343
let path = ref_module
347
344
. find_use_path_prefixed (
348
345
ctx. sema . db ,
@@ -360,7 +357,7 @@ fn augment_references_with_imports(
360
357
None
361
358
} ;
362
359
363
- FileReferenceWithImport { range, old_name , new_name , import_data }
360
+ FileReferenceWithImport { range, name , import_data }
364
361
} )
365
362
. collect ( )
366
363
}
@@ -405,13 +402,12 @@ fn find_record_expr_usage(
405
402
let record_field = ast:: RecordExprField :: for_field_name ( name_ref) ?;
406
403
let initializer = record_field. expr ( ) ?;
407
404
408
- if let Definition :: Field ( expected_field ) = target_definition {
409
- if got_field != expected_field {
410
- return None ;
405
+ match target_definition {
406
+ Definition :: Field ( expected_field ) if got_field == expected_field => {
407
+ Some ( ( record_field , initializer ) )
411
408
}
409
+ _ => None ,
412
410
}
413
-
414
- Some ( ( record_field, initializer) )
415
411
}
416
412
417
413
fn find_record_pat_field_usage ( name : & ast:: NameLike ) -> Option < ast:: Pat > {
@@ -466,12 +462,9 @@ fn add_enum_def(
466
462
let indent = IndentLevel :: from_node ( & insert_before) ;
467
463
enum_def. reindent_to ( indent) ;
468
464
469
- ted:: insert_all (
470
- ted:: Position :: before ( & edit. make_syntax_mut ( insert_before) ) ,
471
- vec ! [
472
- enum_def. syntax( ) . clone( ) . into( ) ,
473
- make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
474
- ] ,
465
+ edit. insert (
466
+ insert_before. text_range ( ) . start ( ) ,
467
+ format ! ( "{}\n \n {indent}" , enum_def. syntax( ) . text( ) ) ,
475
468
) ;
476
469
}
477
470
@@ -800,6 +793,78 @@ fn main() {
800
793
)
801
794
}
802
795
796
+ #[ test]
797
+ fn local_var_init_struct_usage ( ) {
798
+ check_assist (
799
+ bool_to_enum,
800
+ r#"
801
+ struct Foo {
802
+ foo: bool,
803
+ }
804
+
805
+ fn main() {
806
+ let $0foo = true;
807
+ let s = Foo { foo };
808
+ }
809
+ "# ,
810
+ r#"
811
+ struct Foo {
812
+ foo: bool,
813
+ }
814
+
815
+ #[derive(PartialEq, Eq)]
816
+ enum Bool { True, False }
817
+
818
+ fn main() {
819
+ let foo = Bool::True;
820
+ let s = Foo { foo: foo == Bool::True };
821
+ }
822
+ "# ,
823
+ )
824
+ }
825
+
826
+ #[ test]
827
+ fn local_var_init_struct_usage_in_macro ( ) {
828
+ check_assist (
829
+ bool_to_enum,
830
+ r#"
831
+ struct Struct {
832
+ boolean: bool,
833
+ }
834
+
835
+ macro_rules! identity {
836
+ ($body:expr) => {
837
+ $body
838
+ }
839
+ }
840
+
841
+ fn new() -> Struct {
842
+ let $0boolean = true;
843
+ identity![Struct { boolean }]
844
+ }
845
+ "# ,
846
+ r#"
847
+ struct Struct {
848
+ boolean: bool,
849
+ }
850
+
851
+ macro_rules! identity {
852
+ ($body:expr) => {
853
+ $body
854
+ }
855
+ }
856
+
857
+ #[derive(PartialEq, Eq)]
858
+ enum Bool { True, False }
859
+
860
+ fn new() -> Struct {
861
+ let boolean = Bool::True;
862
+ identity![Struct { boolean: boolean == Bool::True }]
863
+ }
864
+ "# ,
865
+ )
866
+ }
867
+
803
868
#[ test]
804
869
fn field_struct_basic ( ) {
805
870
cov_mark:: check!( replaces_record_expr) ;
@@ -1321,6 +1386,46 @@ fn main() {
1321
1386
)
1322
1387
}
1323
1388
1389
+ #[ test]
1390
+ fn field_in_macro ( ) {
1391
+ check_assist (
1392
+ bool_to_enum,
1393
+ r#"
1394
+ struct Struct {
1395
+ $0boolean: bool,
1396
+ }
1397
+
1398
+ fn boolean(x: Struct) {
1399
+ let Struct { boolean } = x;
1400
+ }
1401
+
1402
+ macro_rules! identity { ($body:expr) => { $body } }
1403
+
1404
+ fn new() -> Struct {
1405
+ identity!(Struct { boolean: true })
1406
+ }
1407
+ "# ,
1408
+ r#"
1409
+ #[derive(PartialEq, Eq)]
1410
+ enum Bool { True, False }
1411
+
1412
+ struct Struct {
1413
+ boolean: Bool,
1414
+ }
1415
+
1416
+ fn boolean(x: Struct) {
1417
+ let Struct { boolean } = x;
1418
+ }
1419
+
1420
+ macro_rules! identity { ($body:expr) => { $body } }
1421
+
1422
+ fn new() -> Struct {
1423
+ identity!(Struct { boolean: Bool::True })
1424
+ }
1425
+ "# ,
1426
+ )
1427
+ }
1428
+
1324
1429
#[ test]
1325
1430
fn field_non_bool ( ) {
1326
1431
cov_mark:: check!( not_applicable_non_bool_field) ;
0 commit comments