@@ -16,7 +16,7 @@ use syntax::{
16
16
edit_in_place:: { AttrsOwnerEdit , Indent } ,
17
17
make, HasName ,
18
18
} ,
19
- ted , AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
19
+ AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
20
20
} ;
21
21
use text_edit:: TextRange ;
22
22
@@ -73,7 +73,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
73
73
74
74
let usages = definition. usages ( & ctx. sema ) . all ( ) ;
75
75
add_enum_def ( edit, ctx, & usages, target_node, & target_module) ;
76
- replace_usages ( edit, ctx, & usages, definition, & target_module) ;
76
+ replace_usages ( edit, ctx, usages, definition, & target_module) ;
77
77
} ,
78
78
)
79
79
}
@@ -192,58 +192,55 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
192
192
fn replace_usages (
193
193
edit : & mut SourceChangeBuilder ,
194
194
ctx : & AssistContext < ' _ > ,
195
- usages : & UsageSearchResult ,
195
+ usages : UsageSearchResult ,
196
196
target_definition : Definition ,
197
197
target_module : & hir:: Module ,
198
198
) {
199
- for ( file_id, references) in usages. iter ( ) {
200
- edit. edit_file ( * file_id) ;
199
+ for ( file_id, references) in usages {
200
+ edit. edit_file ( file_id) ;
201
201
202
- let refs_with_imports =
203
- augment_references_with_imports ( edit, ctx, references, target_module) ;
202
+ let refs_with_imports = augment_references_with_imports ( ctx, references, target_module) ;
204
203
205
204
refs_with_imports. into_iter ( ) . rev ( ) . for_each (
206
- |FileReferenceWithImport { range, old_name , new_name , import_data } | {
205
+ |FileReferenceWithImport { range, name , import_data } | {
207
206
// replace the usages in patterns and expressions
208
- if let Some ( ident_pat) = old_name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast)
209
- {
207
+ if let Some ( ident_pat) = name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast) {
210
208
cov_mark:: hit!( replaces_record_pat_shorthand) ;
211
209
212
210
let definition = ctx. sema . to_def ( & ident_pat) . map ( Definition :: Local ) ;
213
211
if let Some ( def) = definition {
214
212
replace_usages (
215
213
edit,
216
214
ctx,
217
- & def. usages ( & ctx. sema ) . all ( ) ,
215
+ def. usages ( & ctx. sema ) . all ( ) ,
218
216
target_definition,
219
217
target_module,
220
218
)
221
219
}
222
- } else if let Some ( initializer) = find_assignment_usage ( & new_name ) {
220
+ } else if let Some ( initializer) = find_assignment_usage ( & name ) {
223
221
cov_mark:: hit!( replaces_assignment) ;
224
222
225
223
replace_bool_expr ( edit, initializer) ;
226
- } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & new_name ) {
224
+ } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & name ) {
227
225
cov_mark:: hit!( replaces_negation) ;
228
226
229
227
edit. replace (
230
228
prefix_expr. syntax ( ) . text_range ( ) ,
231
229
format ! ( "{} == Bool::False" , inner_expr) ,
232
230
) ;
233
- } else if let Some ( ( record_field, initializer) ) = old_name
231
+ } else if let Some ( ( record_field, initializer) ) = name
234
232
. as_name_ref ( )
235
233
. and_then ( ast:: RecordExprField :: for_field_name)
236
234
. and_then ( |record_field| ctx. sema . resolve_record_field ( & record_field) )
237
235
. and_then ( |( got_field, _, _) | {
238
- find_record_expr_usage ( & new_name , got_field, target_definition)
236
+ find_record_expr_usage ( & name , got_field, target_definition)
239
237
} )
240
238
{
241
239
cov_mark:: hit!( replaces_record_expr) ;
242
240
243
- let record_field = edit. make_mut ( record_field) ;
244
241
let enum_expr = bool_expr_to_enum_expr ( initializer) ;
245
- record_field . replace_expr ( enum_expr) ;
246
- } else if let Some ( pat) = find_record_pat_field_usage ( & old_name ) {
242
+ replace_record_field_expr ( edit , record_field , enum_expr) ;
243
+ } else if let Some ( pat) = find_record_pat_field_usage ( & name ) {
247
244
match pat {
248
245
ast:: Pat :: IdentPat ( ident_pat) => {
249
246
cov_mark:: hit!( replaces_record_pat) ;
@@ -253,7 +250,7 @@ fn replace_usages(
253
250
replace_usages (
254
251
edit,
255
252
ctx,
256
- & def. usages ( & ctx. sema ) . all ( ) ,
253
+ def. usages ( & ctx. sema ) . all ( ) ,
257
254
target_definition,
258
255
target_module,
259
256
)
@@ -270,79 +267,94 @@ fn replace_usages(
270
267
}
271
268
_ => ( ) ,
272
269
}
273
- } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274
- {
270
+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & name) {
275
271
edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276
272
replace_bool_expr ( edit, initializer) ;
277
- } else if let Some ( receiver) = find_method_call_expr_usage ( & new_name ) {
273
+ } else if let Some ( receiver) = find_method_call_expr_usage ( & name ) {
278
274
edit. replace (
279
275
receiver. syntax ( ) . text_range ( ) ,
280
276
format ! ( "({} == Bool::True)" , receiver) ,
281
277
) ;
282
- } else if new_name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
278
+ } else if name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
283
279
// for any other usage in an expression, replace it with a check that it is the true variant
284
- if let Some ( ( record_field, expr) ) = new_name
285
- . as_name_ref ( )
286
- . and_then ( ast:: RecordExprField :: for_field_name)
287
- . and_then ( |record_field| {
288
- record_field. expr ( ) . map ( |expr| ( record_field, expr) )
289
- } )
280
+ if let Some ( ( record_field, expr) ) =
281
+ name. as_name_ref ( ) . and_then ( ast:: RecordExprField :: for_field_name) . and_then (
282
+ |record_field| record_field. expr ( ) . map ( |expr| ( record_field, expr) ) ,
283
+ )
290
284
{
291
- record_field. replace_expr (
285
+ replace_record_field_expr (
286
+ edit,
287
+ record_field,
292
288
make:: expr_bin_op (
293
289
expr,
294
290
ast:: BinaryOp :: CmpOp ( ast:: CmpOp :: Eq { negated : false } ) ,
295
291
make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) ,
296
- )
297
- . clone_for_update ( ) ,
292
+ ) ,
298
293
) ;
299
294
} else {
300
- edit. replace ( range, format ! ( "{} == Bool::True" , new_name . text( ) ) ) ;
295
+ edit. replace ( range, format ! ( "{} == Bool::True" , name . text( ) ) ) ;
301
296
}
302
297
}
303
298
304
299
// add imports across modules where needed
305
300
if let Some ( ( import_scope, path) ) = import_data {
306
- insert_use ( & import_scope, path, & ctx. config . insert_use ) ;
301
+ let scope = match import_scope. clone ( ) {
302
+ ImportScope :: File ( it) => ImportScope :: File ( edit. make_mut ( it) ) ,
303
+ ImportScope :: Module ( it) => ImportScope :: Module ( edit. make_mut ( it) ) ,
304
+ ImportScope :: Block ( it) => ImportScope :: Block ( edit. make_mut ( it) ) ,
305
+ } ;
306
+ insert_use ( & scope, path, & ctx. config . insert_use ) ;
307
307
}
308
308
} ,
309
309
)
310
310
}
311
311
}
312
312
313
+ /// Replaces the record expression, handling field shorthands.
314
+ fn replace_record_field_expr (
315
+ edit : & mut SourceChangeBuilder ,
316
+ record_field : ast:: RecordExprField ,
317
+ initializer : ast:: Expr ,
318
+ ) {
319
+ if let Some ( ast:: Expr :: PathExpr ( path_expr) ) = record_field. expr ( ) {
320
+ // replace field shorthand
321
+ edit. insert (
322
+ path_expr. syntax ( ) . text_range ( ) . end ( ) ,
323
+ format ! ( ": {}" , initializer. syntax( ) . text( ) ) ,
324
+ )
325
+ } else if let Some ( expr) = record_field. expr ( ) {
326
+ // just replace expr
327
+ edit. replace_ast ( expr, initializer) ;
328
+ }
329
+ }
330
+
313
331
struct FileReferenceWithImport {
314
332
range : TextRange ,
315
- old_name : ast:: NameLike ,
316
- new_name : ast:: NameLike ,
333
+ name : ast:: NameLike ,
317
334
import_data : Option < ( ImportScope , ast:: Path ) > ,
318
335
}
319
336
320
337
fn augment_references_with_imports (
321
- edit : & mut SourceChangeBuilder ,
322
338
ctx : & AssistContext < ' _ > ,
323
- references : & [ FileReference ] ,
339
+ references : Vec < FileReference > ,
324
340
target_module : & hir:: Module ,
325
341
) -> Vec < FileReferenceWithImport > {
326
342
let mut visited_modules = FxHashSet :: default ( ) ;
327
343
328
344
references
329
- . iter ( )
345
+ . into_iter ( )
330
346
. filter_map ( |FileReference { range, name, .. } | {
331
347
let name = name. clone ( ) . into_name_like ( ) ?;
332
- ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( * range, name, scope. module ( ) ) )
348
+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( range, name, scope. module ( ) ) )
333
349
} )
334
350
. map ( |( range, name, ref_module) | {
335
- let old_name = name. clone ( ) ;
336
- let new_name = edit. make_mut ( name. clone ( ) ) ;
337
-
338
351
// if the referenced module is not the same as the target one and has not been seen before, add an import
339
352
let import_data = if ref_module. nearest_non_block_module ( ctx. db ( ) ) != * target_module
340
353
&& !visited_modules. contains ( & ref_module)
341
354
{
342
355
visited_modules. insert ( ref_module) ;
343
356
344
- let import_scope =
345
- ImportScope :: find_insert_use_container ( new_name. syntax ( ) , & ctx. sema ) ;
357
+ let import_scope = ImportScope :: find_insert_use_container ( name. syntax ( ) , & ctx. sema ) ;
346
358
let path = ref_module
347
359
. find_use_path_prefixed (
348
360
ctx. sema . db ,
@@ -360,7 +372,7 @@ fn augment_references_with_imports(
360
372
None
361
373
} ;
362
374
363
- FileReferenceWithImport { range, old_name , new_name , import_data }
375
+ FileReferenceWithImport { range, name , import_data }
364
376
} )
365
377
. collect ( )
366
378
}
@@ -465,12 +477,9 @@ fn add_enum_def(
465
477
let indent = IndentLevel :: from_node ( & insert_before) ;
466
478
enum_def. reindent_to ( indent) ;
467
479
468
- ted:: insert_all (
469
- ted:: Position :: before ( & edit. make_syntax_mut ( insert_before) ) ,
470
- vec ! [
471
- enum_def. syntax( ) . clone( ) . into( ) ,
472
- make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
473
- ] ,
480
+ edit. insert (
481
+ insert_before. text_range ( ) . start ( ) ,
482
+ format ! ( "{}\n \n {indent}" , enum_def. syntax( ) . text( ) ) ,
474
483
) ;
475
484
}
476
485
0 commit comments