@@ -15,7 +15,7 @@ use rustc_hash::FxHashSet;
15
15
use syntax:: {
16
16
algo:: find_node_at_offset,
17
17
ast:: { self , make, AstNode , NameOwner , VisibilityOwner } ,
18
- ted, SourceFile , SyntaxElement , SyntaxNode , T ,
18
+ ted, SyntaxNode , T ,
19
19
} ;
20
20
21
21
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -62,6 +62,7 @@ pub(crate) fn extract_struct_from_enum_variant(
62
62
let mut visited_modules_set = FxHashSet :: default ( ) ;
63
63
let current_module = enum_hir. module ( ctx. db ( ) ) ;
64
64
visited_modules_set. insert ( current_module) ;
65
+ // record file references of the file the def resides in, we only want to swap to the edited file in the builder once
65
66
let mut def_file_references = None ;
66
67
for ( file_id, references) in usages {
67
68
if file_id == ctx. frange . file_id {
@@ -70,36 +71,57 @@ pub(crate) fn extract_struct_from_enum_variant(
70
71
}
71
72
builder. edit_file ( file_id) ;
72
73
let source_file = builder. make_ast_mut ( ctx. sema . parse ( file_id) ) ;
73
- for reference in references {
74
- update_reference (
75
- ctx,
76
- reference,
77
- & source_file,
78
- & enum_module_def,
79
- & variant_hir_name,
80
- & mut visited_modules_set,
74
+ let processed = process_references (
75
+ ctx,
76
+ & mut visited_modules_set,
77
+ source_file. syntax ( ) ,
78
+ & enum_module_def,
79
+ & variant_hir_name,
80
+ references,
81
+ ) ;
82
+ processed. into_iter ( ) . for_each ( |( segment, node, import) | {
83
+ if let Some ( ( scope, path) ) = import {
84
+ insert_use ( & scope, mod_path_to_ast ( & path) , ctx. config . insert_use ) ;
85
+ }
86
+ ted:: insert_raw (
87
+ ted:: Position :: before ( segment. syntax ( ) ) ,
88
+ make:: path_from_text ( & format ! ( "{}" , segment) ) . clone_for_update ( ) . syntax ( ) ,
81
89
) ;
82
- }
90
+ ted:: insert_raw ( ted:: Position :: before ( segment. syntax ( ) ) , make:: token ( T ! [ '(' ] ) ) ;
91
+ ted:: insert_raw ( ted:: Position :: after ( & node) , make:: token ( T ! [ ')' ] ) ) ;
92
+ } ) ;
83
93
}
84
94
builder. edit_file ( ctx. frange . file_id ) ;
85
- let variant = builder. make_ast_mut ( variant. clone ( ) ) ;
86
95
let source_file = builder. make_ast_mut ( ctx. sema . parse ( ctx. frange . file_id ) ) ;
87
- for reference in def_file_references. into_iter ( ) . flatten ( ) {
88
- update_reference (
96
+ let variant = builder. make_ast_mut ( variant. clone ( ) ) ;
97
+ if let Some ( references) = def_file_references {
98
+ let processed = process_references (
89
99
ctx,
90
- reference ,
91
- & source_file,
100
+ & mut visited_modules_set ,
101
+ source_file. syntax ( ) ,
92
102
& enum_module_def,
93
103
& variant_hir_name,
94
- & mut visited_modules_set ,
104
+ references ,
95
105
) ;
106
+ processed. into_iter ( ) . for_each ( |( segment, node, import) | {
107
+ if let Some ( ( scope, path) ) = import {
108
+ insert_use ( & scope, mod_path_to_ast ( & path) , ctx. config . insert_use ) ;
109
+ }
110
+ ted:: insert_raw (
111
+ ted:: Position :: before ( segment. syntax ( ) ) ,
112
+ make:: path_from_text ( & format ! ( "{}" , segment) ) . clone_for_update ( ) . syntax ( ) ,
113
+ ) ;
114
+ ted:: insert_raw ( ted:: Position :: before ( segment. syntax ( ) ) , make:: token ( T ! [ '(' ] ) ) ;
115
+ ted:: insert_raw ( ted:: Position :: after ( & node) , make:: token ( T ! [ ')' ] ) ) ;
116
+ } ) ;
96
117
}
97
- extract_struct_def (
98
- variant_name. clone ( ) ,
99
- & field_list,
100
- & variant. parent_enum ( ) . syntax ( ) . clone ( ) . into ( ) ,
101
- enum_ast. visibility ( ) ,
102
- ) ;
118
+
119
+ let def = create_struct_def ( variant_name. clone ( ) , & field_list, enum_ast. visibility ( ) )
120
+ . unwrap ( ) ;
121
+ let start_offset = & variant. parent_enum ( ) . syntax ( ) . clone ( ) ;
122
+ ted:: insert_raw ( ted:: Position :: before ( start_offset) , def. syntax ( ) ) ;
123
+ ted:: insert_raw ( ted:: Position :: before ( start_offset) , & make:: tokens:: blank_line ( ) ) ;
124
+
103
125
update_variant ( & variant) ;
104
126
} ,
105
127
)
@@ -141,31 +163,11 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
141
163
. any ( |( name, _) | name. to_string ( ) == variant_name. to_string ( ) )
142
164
}
143
165
144
- fn insert_import (
145
- ctx : & AssistContext ,
146
- scope_node : & SyntaxNode ,
147
- module : & Module ,
148
- enum_module_def : & ModuleDef ,
149
- variant_hir_name : & Name ,
150
- ) -> Option < ( ) > {
151
- let db = ctx. db ( ) ;
152
- let mod_path =
153
- module. find_use_path_prefixed ( db, * enum_module_def, ctx. config . insert_use . prefix_kind ) ;
154
- if let Some ( mut mod_path) = mod_path {
155
- mod_path. pop_segment ( ) ;
156
- mod_path. push_segment ( variant_hir_name. clone ( ) ) ;
157
- let scope = ImportScope :: find_insert_use_container ( scope_node, & ctx. sema ) ?;
158
- insert_use ( & scope, mod_path_to_ast ( & mod_path) , ctx. config . insert_use ) ;
159
- }
160
- Some ( ( ) )
161
- }
162
-
163
- fn extract_struct_def (
166
+ fn create_struct_def (
164
167
variant_name : ast:: Name ,
165
168
field_list : & Either < ast:: RecordFieldList , ast:: TupleFieldList > ,
166
- start_offset : & SyntaxElement ,
167
169
visibility : Option < ast:: Visibility > ,
168
- ) -> Option < ( ) > {
170
+ ) -> Option < ast :: Struct > {
169
171
let pub_vis = Some ( make:: visibility_pub ( ) ) ;
170
172
let field_list = match field_list {
171
173
Either :: Left ( field_list) => {
@@ -182,18 +184,7 @@ fn extract_struct_def(
182
184
. into ( ) ,
183
185
} ;
184
186
185
- ted:: insert_raw (
186
- ted:: Position :: before ( start_offset) ,
187
- make:: struct_ ( visibility, variant_name, None , field_list) . clone_for_update ( ) . syntax ( ) ,
188
- ) ;
189
- ted:: insert_raw ( ted:: Position :: before ( start_offset) , & make:: tokens:: blank_line ( ) ) ;
190
-
191
- // if let indent_level @ 1..=usize::MAX = IndentLevel::from_node(enum_.syntax()).0 as usize {
192
- // ted::insert(ted::Position::before(start_offset), &make::tokens::blank_line());
193
- // rewriter
194
- // .insert_before(start_offset, &make::tokens::whitespace(&" ".repeat(4 * indent_level)));
195
- // }
196
- Some ( ( ) )
187
+ Some ( make:: struct_ ( visibility, variant_name, None , field_list) . clone_for_update ( ) )
197
188
}
198
189
199
190
fn update_variant ( variant : & ast:: Variant ) -> Option < ( ) > {
@@ -208,42 +199,57 @@ fn update_variant(variant: &ast::Variant) -> Option<()> {
208
199
Some ( ( ) )
209
200
}
210
201
211
- fn update_reference (
202
+ fn process_references (
212
203
ctx : & AssistContext ,
213
- reference : FileReference ,
214
- source_file : & SourceFile ,
204
+ visited_modules : & mut FxHashSet < Module > ,
205
+ source_file : & SyntaxNode ,
215
206
enum_module_def : & ModuleDef ,
216
207
variant_hir_name : & Name ,
217
- visited_modules_set : & mut FxHashSet < Module > ,
218
- ) -> Option < ( ) > {
208
+ refs : Vec < FileReference > ,
209
+ ) -> Vec < ( ast:: PathSegment , SyntaxNode , Option < ( ImportScope , hir:: ModPath ) > ) > {
210
+ refs. into_iter ( )
211
+ . flat_map ( |reference| {
212
+ let ( segment, scope_node, module) =
213
+ reference_to_node ( & ctx. sema , source_file, reference) ?;
214
+ if !visited_modules. contains ( & module) {
215
+ let mod_path = module. find_use_path_prefixed (
216
+ ctx. sema . db ,
217
+ * enum_module_def,
218
+ ctx. config . insert_use . prefix_kind ,
219
+ ) ;
220
+ if let Some ( mut mod_path) = mod_path {
221
+ mod_path. pop_segment ( ) ;
222
+ mod_path. push_segment ( variant_hir_name. clone ( ) ) ;
223
+ // uuuh this wont properly work, find_insert_use_container ascends macros so we might a get new syntax node???
224
+ let scope = ImportScope :: find_insert_use_container ( & scope_node, & ctx. sema ) ?;
225
+ visited_modules. insert ( module) ;
226
+ return Some ( ( segment, scope_node, Some ( ( scope, mod_path) ) ) ) ;
227
+ }
228
+ }
229
+ Some ( ( segment, scope_node, None ) )
230
+ } )
231
+ . collect ( )
232
+ }
233
+
234
+ fn reference_to_node (
235
+ sema : & hir:: Semantics < RootDatabase > ,
236
+ source_file : & SyntaxNode ,
237
+ reference : FileReference ,
238
+ ) -> Option < ( ast:: PathSegment , SyntaxNode , hir:: Module ) > {
219
239
let offset = reference. range . start ( ) ;
220
- let ( segment, expr) = if let Some ( path_expr) =
221
- find_node_at_offset :: < ast:: PathExpr > ( source_file. syntax ( ) , offset)
222
- {
240
+ if let Some ( path_expr) = find_node_at_offset :: < ast:: PathExpr > ( source_file, offset) {
223
241
// tuple variant
224
- ( path_expr. path ( ) ?. segment ( ) ?, path_expr. syntax ( ) . parent ( ) ?)
225
- } else if let Some ( record_expr) =
226
- find_node_at_offset :: < ast:: RecordExpr > ( source_file. syntax ( ) , offset)
227
- {
242
+ Some ( ( path_expr. path ( ) ?. segment ( ) ?, path_expr. syntax ( ) . parent ( ) ?) )
243
+ } else if let Some ( record_expr) = find_node_at_offset :: < ast:: RecordExpr > ( source_file, offset) {
228
244
// record variant
229
- ( record_expr. path ( ) ?. segment ( ) ?, record_expr. syntax ( ) . clone ( ) )
245
+ Some ( ( record_expr. path ( ) ?. segment ( ) ?, record_expr. syntax ( ) . clone ( ) ) )
230
246
} else {
231
- return None ;
232
- } ;
233
-
234
- let module = ctx. sema . scope ( & expr) . module ( ) ?;
235
- if !visited_modules_set. contains ( & module) {
236
- if insert_import ( ctx, & expr, & module, enum_module_def, variant_hir_name) . is_some ( ) {
237
- visited_modules_set. insert ( module) ;
238
- }
247
+ None
239
248
}
240
- ted:: insert_raw (
241
- ted:: Position :: before ( segment. syntax ( ) ) ,
242
- make:: path_from_text ( & format ! ( "{}" , segment) ) . clone_for_update ( ) . syntax ( ) ,
243
- ) ;
244
- ted:: insert_raw ( ted:: Position :: before ( segment. syntax ( ) ) , make:: token ( T ! [ '(' ] ) ) ;
245
- ted:: insert_raw ( ted:: Position :: after ( & expr) , make:: token ( T ! [ ')' ] ) ) ;
246
- Some ( ( ) )
249
+ . and_then ( |( segment, expr) | {
250
+ let module = sema. scope ( & expr) . module ( ) ?;
251
+ Some ( ( segment, expr, module) )
252
+ } )
247
253
}
248
254
249
255
#[ cfg( test) ]
@@ -350,7 +356,7 @@ mod my_mod {
350
356
351
357
pub struct MyField(pub u8, pub u8);
352
358
353
- pub enum MyEnum {
359
+ pub enum MyEnum {
354
360
MyField(MyField),
355
361
}
356
362
}
0 commit comments