1
1
//! FIXME: write short doc here
2
2
3
- use std:: iter;
3
+ use std:: { collections :: LinkedList , iter} ;
4
4
5
5
use hir:: { Adt , HasSource , Semantics } ;
6
6
use ra_ide_db:: RootDatabase ;
7
- use ra_syntax:: ast:: { self , edit:: IndentLevel , make, AstNode , NameOwner } ;
8
7
9
8
use crate :: { Assist , AssistCtx , AssistId } ;
9
+ use ra_syntax:: {
10
+ ast:: { self , edit:: IndentLevel , make, AstNode , NameOwner } ,
11
+ SyntaxKind , SyntaxNode ,
12
+ } ;
13
+
14
+ use ast:: { MatchArm , MatchGuard , Pat } ;
10
15
11
16
// Assist: fill_match_arms
12
17
//
@@ -36,16 +41,6 @@ pub(crate) fn fill_match_arms(ctx: AssistCtx) -> Option<Assist> {
36
41
let match_expr = ctx. find_node_at_offset :: < ast:: MatchExpr > ( ) ?;
37
42
let match_arm_list = match_expr. match_arm_list ( ) ?;
38
43
39
- // We already have some match arms, so we don't provide any assists.
40
- // Unless if there is only one trivial match arm possibly created
41
- // by match postfix complete. Trivial match arm is the catch all arm.
42
- let mut existing_arms = match_arm_list. arms ( ) ;
43
- if let Some ( arm) = existing_arms. next ( ) {
44
- if !is_trivial ( & arm) || existing_arms. next ( ) . is_some ( ) {
45
- return None ;
46
- }
47
- } ;
48
-
49
44
let expr = match_expr. expr ( ) ?;
50
45
let enum_def = resolve_enum_def ( & ctx. sema , & expr) ?;
51
46
let module = ctx. sema . scope ( expr. syntax ( ) ) . module ( ) ?;
@@ -56,29 +51,113 @@ pub(crate) fn fill_match_arms(ctx: AssistCtx) -> Option<Assist> {
56
51
}
57
52
58
53
let db = ctx. db ;
59
-
60
54
ctx. add_assist ( AssistId ( "fill_match_arms" ) , "Fill match arms" , |edit| {
61
- let indent_level = IndentLevel :: from_node ( match_arm_list. syntax ( ) ) ;
55
+ let mut arms: Vec < MatchArm > = match_arm_list. arms ( ) . collect ( ) ;
56
+ if arms. len ( ) == 1 {
57
+ if let Some ( Pat :: PlaceholderPat ( ..) ) = arms[ 0 ] . pat ( ) {
58
+ arms. clear ( ) ;
59
+ }
60
+ }
62
61
63
- let new_arm_list = {
64
- let arms = variants
65
- . into_iter ( )
66
- . filter_map ( |variant| build_pat ( db, module, variant) )
67
- . map ( |pat| make:: match_arm ( iter:: once ( pat) , make:: expr_unit ( ) ) ) ;
68
- indent_level. increase_indent ( make:: match_arm_list ( arms) )
69
- } ;
62
+ let mut has_partial_match = false ;
63
+ let variants: Vec < MatchArm > = variants
64
+ . into_iter ( )
65
+ . filter_map ( |variant| build_pat ( db, module, variant) )
66
+ . filter ( |variant_pat| {
67
+ !arms. iter ( ) . filter_map ( |arm| arm. pat ( ) . map ( |_| arm) ) . any ( |arm| {
68
+ let pat = arm. pat ( ) . unwrap ( ) ;
69
+
70
+ // Special casee OrPat as separate top-level pats
71
+ let pats: Vec < Pat > = match Pat :: from ( pat. clone ( ) ) {
72
+ Pat :: OrPat ( pats) => pats. pats ( ) . collect :: < Vec < _ > > ( ) ,
73
+ _ => vec ! [ pat] ,
74
+ } ;
75
+
76
+ pats. iter ( ) . any ( |pat| {
77
+ match does_arm_pat_match_variant ( pat, arm. guard ( ) , variant_pat) {
78
+ ArmMatch :: Yes => true ,
79
+ ArmMatch :: No => false ,
80
+ ArmMatch :: Partial => {
81
+ has_partial_match = true ;
82
+ true
83
+ }
84
+ }
85
+ } )
86
+ } )
87
+ } )
88
+ . map ( |pat| make:: match_arm ( iter:: once ( pat) , make:: expr_unit ( ) ) )
89
+ . collect ( ) ;
90
+
91
+ arms. extend ( variants) ;
92
+ if has_partial_match {
93
+ arms. push ( make:: match_arm (
94
+ iter:: once ( make:: placeholder_pat ( ) . into ( ) ) ,
95
+ make:: expr_unit ( ) ,
96
+ ) ) ;
97
+ }
98
+
99
+ let indent_level = IndentLevel :: from_node ( match_arm_list. syntax ( ) ) ;
100
+ let new_arm_list = indent_level. increase_indent ( make:: match_arm_list ( arms) ) ;
70
101
71
102
edit. target ( match_expr. syntax ( ) . text_range ( ) ) ;
72
103
edit. set_cursor ( expr. syntax ( ) . text_range ( ) . start ( ) ) ;
73
104
edit. replace_ast ( match_arm_list, new_arm_list) ;
74
105
} )
75
106
}
76
107
77
- fn is_trivial ( arm : & ast:: MatchArm ) -> bool {
78
- match arm. pat ( ) {
79
- Some ( ast:: Pat :: PlaceholderPat ( ..) ) => true ,
80
- _ => false ,
108
+ enum ArmMatch {
109
+ Yes ,
110
+ No ,
111
+ Partial ,
112
+ }
113
+
114
+ fn does_arm_pat_match_variant ( arm : & Pat , arm_guard : Option < MatchGuard > , var : & Pat ) -> ArmMatch {
115
+ let arm = flatten_pats ( arm. clone ( ) ) ;
116
+ let var = flatten_pats ( var. clone ( ) ) ;
117
+ let mut arm = arm. iter ( ) ;
118
+ let mut var = var. iter ( ) ;
119
+
120
+ // If the first part of the Pat don't match, there's no match
121
+ match ( arm. next ( ) , var. next ( ) ) {
122
+ ( Some ( arm) , Some ( var) ) if arm. text ( ) == var. text ( ) => { }
123
+ _ => return ArmMatch :: No ,
124
+ }
125
+
126
+ // If we have a guard we automatically know we have a partial match
127
+ if arm_guard. is_some ( ) {
128
+ return ArmMatch :: Partial ;
129
+ }
130
+
131
+ if arm. clone ( ) . count ( ) != var. clone ( ) . count ( ) {
132
+ return ArmMatch :: Partial ;
133
+ }
134
+
135
+ let direct_match = arm. zip ( var) . all ( |( arm, var) | {
136
+ if arm. text ( ) == var. text ( ) {
137
+ return true ;
138
+ }
139
+ match ( arm. kind ( ) , var. kind ( ) ) {
140
+ ( SyntaxKind :: PLACEHOLDER_PAT , SyntaxKind :: PLACEHOLDER_PAT ) => true ,
141
+ ( SyntaxKind :: DOT_DOT_PAT , SyntaxKind :: PLACEHOLDER_PAT ) => true ,
142
+ ( SyntaxKind :: BIND_PAT , SyntaxKind :: PLACEHOLDER_PAT ) => true ,
143
+ _ => false ,
144
+ }
145
+ } ) ;
146
+
147
+ match direct_match {
148
+ true => ArmMatch :: Yes ,
149
+ false => ArmMatch :: Partial ,
150
+ }
151
+ }
152
+
153
+ fn flatten_pats ( pat : Pat ) -> Vec < SyntaxNode > {
154
+ let mut pats: LinkedList < SyntaxNode > = pat. syntax ( ) . children ( ) . collect ( ) ;
155
+ let mut out: Vec < SyntaxNode > = vec ! [ ] ;
156
+ while let Some ( p) = pats. pop_front ( ) {
157
+ pats. extend ( p. children ( ) ) ;
158
+ out. push ( p) ;
81
159
}
160
+ out
82
161
}
83
162
84
163
fn resolve_enum_def ( sema : & Semantics < RootDatabase > , expr : & ast:: Expr ) -> Option < hir:: Enum > {
@@ -114,6 +193,183 @@ mod tests {
114
193
115
194
use super :: fill_match_arms;
116
195
196
+ #[ test]
197
+ fn partial_fill_multi ( ) {
198
+ check_assist (
199
+ fill_match_arms,
200
+ r#"
201
+ enum A {
202
+ As,
203
+ Bs(i32, Option<i32>)
204
+ }
205
+ fn main() {
206
+ match A::As<|> {
207
+ A::Bs(_, Some(_)) => (),
208
+ }
209
+ }
210
+ "# ,
211
+ r#"
212
+ enum A {
213
+ As,
214
+ Bs(i32, Option<i32>)
215
+ }
216
+ fn main() {
217
+ match <|>A::As {
218
+ A::Bs(_, Some(_)) => (),
219
+ A::As => (),
220
+ _ => (),
221
+ }
222
+ }
223
+ "# ,
224
+ ) ;
225
+ }
226
+
227
+ #[ test]
228
+ fn partial_fill_record ( ) {
229
+ check_assist (
230
+ fill_match_arms,
231
+ r#"
232
+ enum A {
233
+ As,
234
+ Bs{x:i32, y:Option<i32>},
235
+ }
236
+ fn main() {
237
+ match A::As<|> {
238
+ A::Bs{x,y:Some(_)} => (),
239
+ }
240
+ }
241
+ "# ,
242
+ r#"
243
+ enum A {
244
+ As,
245
+ Bs{x:i32, y:Option<i32>},
246
+ }
247
+ fn main() {
248
+ match <|>A::As {
249
+ A::Bs{x,y:Some(_)} => (),
250
+ A::As => (),
251
+ _ => (),
252
+ }
253
+ }
254
+ "# ,
255
+ ) ;
256
+ }
257
+
258
+ #[ test]
259
+ fn partial_fill_or_pat ( ) {
260
+ check_assist (
261
+ fill_match_arms,
262
+ r#"
263
+ enum A {
264
+ As,
265
+ Bs,
266
+ Cs(Option<i32>),
267
+ }
268
+ fn main() {
269
+ match A::As<|> {
270
+ A::Cs(_) | A::Bs => (),
271
+ }
272
+ }
273
+ "# ,
274
+ r#"
275
+ enum A {
276
+ As,
277
+ Bs,
278
+ Cs(Option<i32>),
279
+ }
280
+ fn main() {
281
+ match <|>A::As {
282
+ A::Cs(_) | A::Bs => (),
283
+ A::As => (),
284
+ }
285
+ }
286
+ "# ,
287
+ ) ;
288
+ }
289
+
290
+ #[ test]
291
+ fn partial_fill_or_pat2 ( ) {
292
+ check_assist (
293
+ fill_match_arms,
294
+ r#"
295
+ enum A {
296
+ As,
297
+ Bs,
298
+ Cs(Option<i32>),
299
+ }
300
+ fn main() {
301
+ match A::As<|> {
302
+ A::Cs(Some(_)) | A::Bs => (),
303
+ }
304
+ }
305
+ "# ,
306
+ r#"
307
+ enum A {
308
+ As,
309
+ Bs,
310
+ Cs(Option<i32>),
311
+ }
312
+ fn main() {
313
+ match <|>A::As {
314
+ A::Cs(Some(_)) | A::Bs => (),
315
+ A::As => (),
316
+ _ => (),
317
+ }
318
+ }
319
+ "# ,
320
+ ) ;
321
+ }
322
+
323
+ #[ test]
324
+ fn partial_fill ( ) {
325
+ check_assist (
326
+ fill_match_arms,
327
+ r#"
328
+ enum A {
329
+ As,
330
+ Bs,
331
+ Cs,
332
+ Ds(String),
333
+ Es(B),
334
+ }
335
+ enum B {
336
+ Xs,
337
+ Ys,
338
+ }
339
+ fn main() {
340
+ match A::As<|> {
341
+ A::Bs if 0 < 1 => (),
342
+ A::Ds(_value) => (),
343
+ A::Es(B::Xs) => (),
344
+ }
345
+ }
346
+ "# ,
347
+ r#"
348
+ enum A {
349
+ As,
350
+ Bs,
351
+ Cs,
352
+ Ds(String),
353
+ Es(B),
354
+ }
355
+ enum B {
356
+ Xs,
357
+ Ys,
358
+ }
359
+ fn main() {
360
+ match <|>A::As {
361
+ A::Bs if 0 < 1 => (),
362
+ A::Ds(_value) => (),
363
+ A::Es(B::Xs) => (),
364
+ A::As => (),
365
+ A::Cs => (),
366
+ _ => (),
367
+ }
368
+ }
369
+ "# ,
370
+ ) ;
371
+ }
372
+
117
373
#[ test]
118
374
fn fill_match_arms_empty_body ( ) {
119
375
check_assist (
0 commit comments