@@ -11,14 +11,19 @@ use ide_db::{
11
11
search:: FileReference ,
12
12
RootDatabase ,
13
13
} ;
14
+ use itertools:: Itertools ;
14
15
use rustc_hash:: FxHashSet ;
15
16
use syntax:: {
16
- algo:: find_node_at_offset,
17
- ast:: { self , make, AstNode , NameOwner , VisibilityOwner } ,
18
- ted, SyntaxNode , T ,
17
+ ast:: {
18
+ self , make, AstNode , AttrsOwner , GenericParamsOwner , NameOwner , TypeBoundsOwner ,
19
+ VisibilityOwner ,
20
+ } ,
21
+ match_ast,
22
+ ted:: { self , Position } ,
23
+ SyntaxNode , T ,
19
24
} ;
20
25
21
- use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
26
+ use crate :: { assist_context :: AssistBuilder , AssistContext , AssistId , AssistKind , Assists } ;
22
27
23
28
// Assist: extract_struct_from_enum_variant
24
29
//
@@ -70,11 +75,10 @@ pub(crate) fn extract_struct_from_enum_variant(
70
75
continue ;
71
76
}
72
77
builder. edit_file ( file_id) ;
73
- let source_file = builder. make_mut ( ctx. sema . parse ( file_id) ) ;
74
78
let processed = process_references (
75
79
ctx,
80
+ builder,
76
81
& mut visited_modules_set,
77
- source_file. syntax ( ) ,
78
82
& enum_module_def,
79
83
& variant_hir_name,
80
84
references,
@@ -84,13 +88,12 @@ pub(crate) fn extract_struct_from_enum_variant(
84
88
} ) ;
85
89
}
86
90
builder. edit_file ( ctx. frange . file_id ) ;
87
- let source_file = builder. make_mut ( ctx. sema . parse ( ctx. frange . file_id ) ) ;
88
91
let variant = builder. make_mut ( variant. clone ( ) ) ;
89
92
if let Some ( references) = def_file_references {
90
93
let processed = process_references (
91
94
ctx,
95
+ builder,
92
96
& mut visited_modules_set,
93
- source_file. syntax ( ) ,
94
97
& enum_module_def,
95
98
& variant_hir_name,
96
99
references,
@@ -100,12 +103,12 @@ pub(crate) fn extract_struct_from_enum_variant(
100
103
} ) ;
101
104
}
102
105
103
- let def = create_struct_def ( variant_name. clone ( ) , & field_list, enum_ast. visibility ( ) ) ;
106
+ let def = create_struct_def ( variant_name. clone ( ) , & field_list, & enum_ast) ;
104
107
let start_offset = & variant. parent_enum ( ) . syntax ( ) . clone ( ) ;
105
108
ted:: insert_raw ( ted:: Position :: before ( start_offset) , def. syntax ( ) ) ;
106
109
ted:: insert_raw ( ted:: Position :: before ( start_offset) , & make:: tokens:: blank_line ( ) ) ;
107
110
108
- update_variant ( & variant) ;
111
+ update_variant ( & variant, enum_ast . generic_param_list ( ) ) ;
109
112
} ,
110
113
)
111
114
}
@@ -149,7 +152,7 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
149
152
fn create_struct_def (
150
153
variant_name : ast:: Name ,
151
154
field_list : & Either < ast:: RecordFieldList , ast:: TupleFieldList > ,
152
- visibility : Option < ast:: Visibility > ,
155
+ enum_ : & ast:: Enum ,
153
156
) -> ast:: Struct {
154
157
let pub_vis = make:: visibility_pub ( ) ;
155
158
@@ -184,12 +187,38 @@ fn create_struct_def(
184
187
}
185
188
} ;
186
189
187
- make:: struct_ ( visibility, variant_name, None , field_list) . clone_for_update ( )
190
+ // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
191
+ let strukt =
192
+ make:: struct_ ( enum_. visibility ( ) , variant_name, enum_. generic_param_list ( ) , field_list)
193
+ . clone_for_update ( ) ;
194
+
195
+ // copy attributes
196
+ ted:: insert_all (
197
+ Position :: first_child_of ( strukt. syntax ( ) ) ,
198
+ enum_. attrs ( ) . map ( |it| it. syntax ( ) . clone_for_update ( ) . into ( ) ) . collect ( ) ,
199
+ ) ;
200
+ strukt
188
201
}
189
202
190
- fn update_variant ( variant : & ast:: Variant ) -> Option < ( ) > {
203
+ fn update_variant ( variant : & ast:: Variant , generic : Option < ast :: GenericParamList > ) -> Option < ( ) > {
191
204
let name = variant. name ( ) ?;
192
- let tuple_field = make:: tuple_field ( None , make:: ty ( & name. text ( ) ) ) ;
205
+ let ty = match generic {
206
+ // FIXME: This uses all the generic params of the enum, but the variant might not use all of them.
207
+ Some ( gpl) => {
208
+ let gpl = gpl. clone_for_update ( ) ;
209
+ gpl. generic_params ( ) . for_each ( |gp| {
210
+ match gp {
211
+ ast:: GenericParam :: LifetimeParam ( it) => it. type_bound_list ( ) ,
212
+ ast:: GenericParam :: TypeParam ( it) => it. type_bound_list ( ) ,
213
+ ast:: GenericParam :: ConstParam ( _) => return ,
214
+ }
215
+ . map ( |it| it. remove ( ) ) ;
216
+ } ) ;
217
+ make:: ty ( & format ! ( "{}<{}>" , name. text( ) , gpl. generic_params( ) . join( ", " ) ) )
218
+ }
219
+ None => make:: ty ( & name. text ( ) ) ,
220
+ } ;
221
+ let tuple_field = make:: tuple_field ( None , ty) ;
193
222
let replacement = make:: variant (
194
223
name,
195
224
Some ( ast:: FieldList :: TupleFieldList ( make:: tuple_field_list ( iter:: once ( tuple_field) ) ) ) ,
@@ -208,18 +237,17 @@ fn apply_references(
208
237
if let Some ( ( scope, path) ) = import {
209
238
insert_use ( & scope, mod_path_to_ast ( & path) , insert_use_cfg) ;
210
239
}
211
- ted:: insert_raw (
212
- ted:: Position :: before ( segment. syntax ( ) ) ,
213
- make:: path_from_text ( & format ! ( "{}" , segment) ) . clone_for_update ( ) . syntax ( ) ,
214
- ) ;
240
+ // deep clone to prevent cycle
241
+ let path = make:: path_from_segments ( iter:: once ( segment. clone_subtree ( ) ) , false ) ;
242
+ ted:: insert_raw ( ted:: Position :: before ( segment. syntax ( ) ) , path. clone_for_update ( ) . syntax ( ) ) ;
215
243
ted:: insert_raw ( ted:: Position :: before ( segment. syntax ( ) ) , make:: token ( T ! [ '(' ] ) ) ;
216
244
ted:: insert_raw ( ted:: Position :: after ( & node) , make:: token ( T ! [ ')' ] ) ) ;
217
245
}
218
246
219
247
fn process_references (
220
248
ctx : & AssistContext ,
249
+ builder : & mut AssistBuilder ,
221
250
visited_modules : & mut FxHashSet < Module > ,
222
- source_file : & SyntaxNode ,
223
251
enum_module_def : & ModuleDef ,
224
252
variant_hir_name : & Name ,
225
253
refs : Vec < FileReference > ,
@@ -228,8 +256,9 @@ fn process_references(
228
256
// and corresponding nodes up front
229
257
refs. into_iter ( )
230
258
. flat_map ( |reference| {
231
- let ( segment, scope_node, module) =
232
- reference_to_node ( & ctx. sema , source_file, reference) ?;
259
+ let ( segment, scope_node, module) = reference_to_node ( & ctx. sema , reference) ?;
260
+ let segment = builder. make_mut ( segment) ;
261
+ let scope_node = builder. make_syntax_mut ( scope_node) ;
233
262
if !visited_modules. contains ( & module) {
234
263
let mod_path = module. find_use_path_prefixed (
235
264
ctx. sema . db ,
@@ -251,23 +280,22 @@ fn process_references(
251
280
252
281
fn reference_to_node (
253
282
sema : & hir:: Semantics < RootDatabase > ,
254
- source_file : & SyntaxNode ,
255
283
reference : FileReference ,
256
284
) -> Option < ( ast:: PathSegment , SyntaxNode , hir:: Module ) > {
257
- let offset = reference . range . start ( ) ;
258
- if let Some ( path_expr ) = find_node_at_offset :: < ast :: PathExpr > ( source_file , offset ) {
259
- // tuple variant
260
- Some ( ( path_expr . path ( ) ? . segment ( ) ? , path_expr . syntax ( ) . parent ( ) ? ) )
261
- } else if let Some ( record_expr ) = find_node_at_offset :: < ast :: RecordExpr > ( source_file , offset ) {
262
- // record variant
263
- Some ( ( record_expr . path ( ) ? . segment ( ) ? , record_expr . syntax ( ) . clone ( ) ) )
264
- } else {
265
- None
266
- }
267
- . and_then ( | ( segment , expr ) | {
268
- let module = sema . scope ( & expr ) . module ( ) ? ;
269
- Some ( ( segment , expr , module) )
270
- } )
285
+ let segment =
286
+ reference . name . as_name_ref ( ) ? . syntax ( ) . parent ( ) . and_then ( ast :: PathSegment :: cast ) ? ;
287
+ let parent = segment . parent_path ( ) . syntax ( ) . parent ( ) ? ;
288
+ let expr_or_pat = match_ast ! {
289
+ match parent {
290
+ ast :: PathExpr ( _it ) => parent . parent ( ) ? ,
291
+ ast :: RecordExpr ( _it ) => parent ,
292
+ ast :: TupleStructPat ( _it ) => parent ,
293
+ ast :: RecordPat ( _it ) => parent ,
294
+ _ => return None ,
295
+ }
296
+ } ;
297
+ let module = sema . scope ( & expr_or_pat ) . module ( ) ? ;
298
+ Some ( ( segment , expr_or_pat , module ) )
271
299
}
272
300
273
301
#[ cfg( test) ]
@@ -278,6 +306,12 @@ mod tests {
278
306
279
307
use super :: * ;
280
308
309
+ fn check_not_applicable ( ra_fixture : & str ) {
310
+ let fixture =
311
+ format ! ( "//- /main.rs crate:main deps:core\n {}\n {}" , ra_fixture, FamousDefs :: FIXTURE ) ;
312
+ check_assist_not_applicable ( extract_struct_from_enum_variant, & fixture)
313
+ }
314
+
281
315
#[ test]
282
316
fn test_extract_struct_several_fields_tuple ( ) {
283
317
check_assist (
@@ -311,6 +345,32 @@ enum A { One(One) }"#,
311
345
) ;
312
346
}
313
347
348
+ #[ test]
349
+ fn test_extract_struct_carries_over_generics ( ) {
350
+ check_assist (
351
+ extract_struct_from_enum_variant,
352
+ r"enum En<T> { Var { a: T$0 } }" ,
353
+ r#"struct Var<T>{ pub a: T }
354
+
355
+ enum En<T> { Var(Var<T>) }"# ,
356
+ ) ;
357
+ }
358
+
359
+ #[ test]
360
+ fn test_extract_struct_carries_over_attributes ( ) {
361
+ check_assist (
362
+ extract_struct_from_enum_variant,
363
+ r#"#[derive(Debug)]
364
+ #[derive(Clone)]
365
+ enum Enum { Variant{ field: u32$0 } }"# ,
366
+ r#"#[derive(Debug)]#[derive(Clone)] struct Variant{ pub field: u32 }
367
+
368
+ #[derive(Debug)]
369
+ #[derive(Clone)]
370
+ enum Enum { Variant(Variant) }"# ,
371
+ ) ;
372
+ }
373
+
314
374
#[ test]
315
375
fn test_extract_struct_keep_comments_and_attrs_one_field_named ( ) {
316
376
check_assist (
@@ -496,7 +556,7 @@ enum E {
496
556
}
497
557
498
558
fn f() {
499
- let e = E::V { i: 9, j: 2 };
559
+ let E::V { i, j } = E::V { i: 9, j: 2 };
500
560
}
501
561
"# ,
502
562
r#"
@@ -507,7 +567,34 @@ enum E {
507
567
}
508
568
509
569
fn f() {
510
- let e = E::V(V { i: 9, j: 2 });
570
+ let E::V(V { i, j }) = E::V(V { i: 9, j: 2 });
571
+ }
572
+ "# ,
573
+ )
574
+ }
575
+
576
+ #[ test]
577
+ fn extract_record_fix_references2 ( ) {
578
+ check_assist (
579
+ extract_struct_from_enum_variant,
580
+ r#"
581
+ enum E {
582
+ $0V(i32, i32)
583
+ }
584
+
585
+ fn f() {
586
+ let E::V(i, j) = E::V(9, 2);
587
+ }
588
+ "# ,
589
+ r#"
590
+ struct V(pub i32, pub i32);
591
+
592
+ enum E {
593
+ V(V)
594
+ }
595
+
596
+ fn f() {
597
+ let E::V(V(i, j)) = E::V(V(9, 2));
511
598
}
512
599
"# ,
513
600
)
@@ -610,12 +697,6 @@ fn foo() {
610
697
) ;
611
698
}
612
699
613
- fn check_not_applicable ( ra_fixture : & str ) {
614
- let fixture =
615
- format ! ( "//- /main.rs crate:main deps:core\n {}\n {}" , ra_fixture, FamousDefs :: FIXTURE ) ;
616
- check_assist_not_applicable ( extract_struct_from_enum_variant, & fixture)
617
- }
618
-
619
700
#[ test]
620
701
fn test_extract_enum_not_applicable_for_element_with_no_fields ( ) {
621
702
check_not_applicable ( "enum A { $0One }" ) ;
0 commit comments