@@ -5,17 +5,17 @@ use hir::{Module, ModuleDef, Name, Variant};
5
5
use ide_db:: {
6
6
defs:: Definition ,
7
7
helpers:: {
8
- insert_use:: { insert_use, ImportScope } ,
8
+ insert_use:: { insert_use, ImportScope , InsertUseConfig } ,
9
9
mod_path_to_ast,
10
10
} ,
11
11
search:: FileReference ,
12
12
RootDatabase ,
13
13
} ;
14
14
use rustc_hash:: FxHashSet ;
15
15
use syntax:: {
16
- algo:: { find_node_at_offset, SyntaxRewriter } ,
17
- ast:: { self , edit :: IndentLevel , make, AstNode , NameOwner , VisibilityOwner } ,
18
- SourceFile , SyntaxElement , SyntaxNode , T ,
16
+ algo:: find_node_at_offset,
17
+ ast:: { self , make, AstNode , NameOwner , VisibilityOwner } ,
18
+ ted , SyntaxNode , T ,
19
19
} ;
20
20
21
21
use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -62,40 +62,50 @@ pub(crate) fn extract_struct_from_enum_variant(
62
62
let mut visited_modules_set = FxHashSet :: default ( ) ;
63
63
let current_module = enum_hir. module ( ctx. db ( ) ) ;
64
64
visited_modules_set. insert ( current_module) ;
65
- let mut def_rewriter = None ;
65
+ // record file references of the file the def resides in, we only want to swap to the edited file in the builder once
66
+ let mut def_file_references = None ;
66
67
for ( file_id, references) in usages {
67
- let mut rewriter = SyntaxRewriter :: default ( ) ;
68
- let source_file = ctx. sema . parse ( file_id) ;
69
- for reference in references {
70
- update_reference (
71
- ctx,
72
- & mut rewriter,
73
- reference,
74
- & source_file,
75
- & enum_module_def,
76
- & variant_hir_name,
77
- & mut visited_modules_set,
78
- ) ;
79
- }
80
68
if file_id == ctx. frange . file_id {
81
- def_rewriter = Some ( rewriter ) ;
69
+ def_file_references = Some ( references ) ;
82
70
continue ;
83
71
}
84
72
builder. edit_file ( file_id) ;
85
- builder. rewrite ( rewriter) ;
73
+ let source_file = builder. make_ast_mut ( ctx. sema . parse ( file_id) ) ;
74
+ let processed = process_references (
75
+ ctx,
76
+ & mut visited_modules_set,
77
+ source_file. syntax ( ) ,
78
+ & enum_module_def,
79
+ & variant_hir_name,
80
+ references,
81
+ ) ;
82
+ processed. into_iter ( ) . for_each ( |( path, node, import) | {
83
+ apply_references ( ctx. config . insert_use , path, node, import)
84
+ } ) ;
86
85
}
87
- let mut rewriter = def_rewriter. unwrap_or_default ( ) ;
88
- update_variant ( & mut rewriter, & variant) ;
89
- extract_struct_def (
90
- & mut rewriter,
91
- & enum_ast,
92
- variant_name. clone ( ) ,
93
- & field_list,
94
- & variant. parent_enum ( ) . syntax ( ) . clone ( ) . into ( ) ,
95
- enum_ast. visibility ( ) ,
96
- ) ;
97
86
builder. edit_file ( ctx. frange . file_id ) ;
98
- builder. rewrite ( rewriter) ;
87
+ let source_file = builder. make_ast_mut ( ctx. sema . parse ( ctx. frange . file_id ) ) ;
88
+ let variant = builder. make_ast_mut ( variant. clone ( ) ) ;
89
+ if let Some ( references) = def_file_references {
90
+ let processed = process_references (
91
+ ctx,
92
+ & mut visited_modules_set,
93
+ source_file. syntax ( ) ,
94
+ & enum_module_def,
95
+ & variant_hir_name,
96
+ references,
97
+ ) ;
98
+ processed. into_iter ( ) . for_each ( |( path, node, import) | {
99
+ apply_references ( ctx. config . insert_use , path, node, import)
100
+ } ) ;
101
+ }
102
+
103
+ let def = create_struct_def ( variant_name. clone ( ) , & field_list, enum_ast. visibility ( ) ) ;
104
+ let start_offset = & variant. parent_enum ( ) . syntax ( ) . clone ( ) ;
105
+ ted:: insert_raw ( ted:: Position :: before ( start_offset) , def. syntax ( ) ) ;
106
+ ted:: insert_raw ( ted:: Position :: before ( start_offset) , & make:: tokens:: blank_line ( ) ) ;
107
+
108
+ update_variant ( & variant) ;
99
109
} ,
100
110
)
101
111
}
@@ -136,34 +146,11 @@ fn existing_definition(db: &RootDatabase, variant_name: &ast::Name, variant: &Va
136
146
. any ( |( name, _) | name. to_string ( ) == variant_name. to_string ( ) )
137
147
}
138
148
139
- fn insert_import (
140
- ctx : & AssistContext ,
141
- rewriter : & mut SyntaxRewriter ,
142
- scope_node : & SyntaxNode ,
143
- module : & Module ,
144
- enum_module_def : & ModuleDef ,
145
- variant_hir_name : & Name ,
146
- ) -> Option < ( ) > {
147
- let db = ctx. db ( ) ;
148
- let mod_path =
149
- module. find_use_path_prefixed ( db, * enum_module_def, ctx. config . insert_use . prefix_kind ) ;
150
- if let Some ( mut mod_path) = mod_path {
151
- mod_path. pop_segment ( ) ;
152
- mod_path. push_segment ( variant_hir_name. clone ( ) ) ;
153
- let scope = ImportScope :: find_insert_use_container ( scope_node, & ctx. sema ) ?;
154
- * rewriter += insert_use ( & scope, mod_path_to_ast ( & mod_path) , ctx. config . insert_use ) ;
155
- }
156
- Some ( ( ) )
157
- }
158
-
159
- fn extract_struct_def (
160
- rewriter : & mut SyntaxRewriter ,
161
- enum_ : & ast:: Enum ,
149
+ fn create_struct_def (
162
150
variant_name : ast:: Name ,
163
151
field_list : & Either < ast:: RecordFieldList , ast:: TupleFieldList > ,
164
- start_offset : & SyntaxElement ,
165
152
visibility : Option < ast:: Visibility > ,
166
- ) -> Option < ( ) > {
153
+ ) -> ast :: Struct {
167
154
let pub_vis = Some ( make:: visibility_pub ( ) ) ;
168
155
let field_list = match field_list {
169
156
Either :: Left ( field_list) => {
@@ -180,65 +167,90 @@ fn extract_struct_def(
180
167
. into ( ) ,
181
168
} ;
182
169
183
- rewriter. insert_before (
184
- start_offset,
185
- make:: struct_ ( visibility, variant_name, None , field_list) . syntax ( ) ,
186
- ) ;
187
- rewriter. insert_before ( start_offset, & make:: tokens:: blank_line ( ) ) ;
188
-
189
- if let indent_level @ 1 ..=usize:: MAX = IndentLevel :: from_node ( enum_. syntax ( ) ) . 0 as usize {
190
- rewriter
191
- . insert_before ( start_offset, & make:: tokens:: whitespace ( & " " . repeat ( 4 * indent_level) ) ) ;
192
- }
193
- Some ( ( ) )
170
+ make:: struct_ ( visibility, variant_name, None , field_list) . clone_for_update ( )
194
171
}
195
172
196
- fn update_variant ( rewriter : & mut SyntaxRewriter , variant : & ast:: Variant ) -> Option < ( ) > {
173
+ fn update_variant ( variant : & ast:: Variant ) -> Option < ( ) > {
197
174
let name = variant. name ( ) ?;
198
175
let tuple_field = make:: tuple_field ( None , make:: ty ( & name. text ( ) ) ) ;
199
176
let replacement = make:: variant (
200
177
name,
201
178
Some ( ast:: FieldList :: TupleFieldList ( make:: tuple_field_list ( iter:: once ( tuple_field) ) ) ) ,
202
- ) ;
203
- rewriter. replace ( variant. syntax ( ) , replacement. syntax ( ) ) ;
179
+ )
180
+ . clone_for_update ( ) ;
181
+ ted:: replace ( variant. syntax ( ) , replacement. syntax ( ) ) ;
204
182
Some ( ( ) )
205
183
}
206
184
207
- fn update_reference (
185
+ fn apply_references (
186
+ insert_use_cfg : InsertUseConfig ,
187
+ segment : ast:: PathSegment ,
188
+ node : SyntaxNode ,
189
+ import : Option < ( ImportScope , hir:: ModPath ) > ,
190
+ ) {
191
+ if let Some ( ( scope, path) ) = import {
192
+ insert_use ( & scope, mod_path_to_ast ( & path) , insert_use_cfg) ;
193
+ }
194
+ ted:: insert_raw (
195
+ ted:: Position :: before ( segment. syntax ( ) ) ,
196
+ make:: path_from_text ( & format ! ( "{}" , segment) ) . clone_for_update ( ) . syntax ( ) ,
197
+ ) ;
198
+ ted:: insert_raw ( ted:: Position :: before ( segment. syntax ( ) ) , make:: token ( T ! [ '(' ] ) ) ;
199
+ ted:: insert_raw ( ted:: Position :: after ( & node) , make:: token ( T ! [ ')' ] ) ) ;
200
+ }
201
+
202
+ fn process_references (
208
203
ctx : & AssistContext ,
209
- rewriter : & mut SyntaxRewriter ,
210
- reference : FileReference ,
211
- source_file : & SourceFile ,
204
+ visited_modules : & mut FxHashSet < Module > ,
205
+ source_file : & SyntaxNode ,
212
206
enum_module_def : & ModuleDef ,
213
207
variant_hir_name : & Name ,
214
- visited_modules_set : & mut FxHashSet < Module > ,
215
- ) -> Option < ( ) > {
208
+ refs : Vec < FileReference > ,
209
+ ) -> Vec < ( ast:: PathSegment , SyntaxNode , Option < ( ImportScope , hir:: ModPath ) > ) > {
210
+ // we have to recollect here eagerly as we are about to edit the tree we need to calculate the changes
211
+ // and corresponding nodes up front
212
+ refs. into_iter ( )
213
+ . flat_map ( |reference| {
214
+ let ( segment, scope_node, module) =
215
+ reference_to_node ( & ctx. sema , source_file, reference) ?;
216
+ if !visited_modules. contains ( & module) {
217
+ let mod_path = module. find_use_path_prefixed (
218
+ ctx. sema . db ,
219
+ * enum_module_def,
220
+ ctx. config . insert_use . prefix_kind ,
221
+ ) ;
222
+ if let Some ( mut mod_path) = mod_path {
223
+ mod_path. pop_segment ( ) ;
224
+ mod_path. push_segment ( variant_hir_name. clone ( ) ) ;
225
+ let scope = ImportScope :: find_insert_use_container ( & scope_node) ?;
226
+ visited_modules. insert ( module) ;
227
+ return Some ( ( segment, scope_node, Some ( ( scope, mod_path) ) ) ) ;
228
+ }
229
+ }
230
+ Some ( ( segment, scope_node, None ) )
231
+ } )
232
+ . collect ( )
233
+ }
234
+
235
+ fn reference_to_node (
236
+ sema : & hir:: Semantics < RootDatabase > ,
237
+ source_file : & SyntaxNode ,
238
+ reference : FileReference ,
239
+ ) -> Option < ( ast:: PathSegment , SyntaxNode , hir:: Module ) > {
216
240
let offset = reference. range . start ( ) ;
217
- let ( segment, expr) = if let Some ( path_expr) =
218
- find_node_at_offset :: < ast:: PathExpr > ( source_file. syntax ( ) , offset)
219
- {
241
+ if let Some ( path_expr) = find_node_at_offset :: < ast:: PathExpr > ( source_file, offset) {
220
242
// tuple variant
221
- ( path_expr. path ( ) ?. segment ( ) ?, path_expr. syntax ( ) . parent ( ) ?)
222
- } else if let Some ( record_expr) =
223
- find_node_at_offset :: < ast:: RecordExpr > ( source_file. syntax ( ) , offset)
224
- {
243
+ Some ( ( path_expr. path ( ) ?. segment ( ) ?, path_expr. syntax ( ) . parent ( ) ?) )
244
+ } else if let Some ( record_expr) = find_node_at_offset :: < ast:: RecordExpr > ( source_file, offset) {
225
245
// record variant
226
- ( record_expr. path ( ) ?. segment ( ) ?, record_expr. syntax ( ) . clone ( ) )
246
+ Some ( ( record_expr. path ( ) ?. segment ( ) ?, record_expr. syntax ( ) . clone ( ) ) )
227
247
} else {
228
- return None ;
229
- } ;
230
-
231
- let module = ctx. sema . scope ( & expr) . module ( ) ?;
232
- if !visited_modules_set. contains ( & module) {
233
- if insert_import ( ctx, rewriter, & expr, & module, enum_module_def, variant_hir_name) . is_some ( )
234
- {
235
- visited_modules_set. insert ( module) ;
236
- }
248
+ None
237
249
}
238
- rewriter . insert_after ( segment . syntax ( ) , & make :: token ( T ! [ '(' ] ) ) ;
239
- rewriter . insert_after ( segment . syntax ( ) , segment . syntax ( ) ) ;
240
- rewriter . insert_after ( & expr, & make :: token ( T ! [ ')' ] ) ) ;
241
- Some ( ( ) )
250
+ . and_then ( | ( segment , expr ) | {
251
+ let module = sema . scope ( & expr ) . module ( ) ? ;
252
+ Some ( ( segment , expr, module ) )
253
+ } )
242
254
}
243
255
244
256
#[ cfg( test) ]
@@ -345,7 +357,7 @@ mod my_mod {
345
357
346
358
pub struct MyField(pub u8, pub u8);
347
359
348
- pub enum MyEnum {
360
+ pub enum MyEnum {
349
361
MyField(MyField),
350
362
}
351
363
}
0 commit comments