@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
14
14
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
15
15
func.func @create_scalable_vector_mask_to_constant_mask () -> (vector <[8 ]xi1 >) {
16
16
%c -1 = arith.constant -1 : index
17
- // CHECK: vector.constant_mask [0] : vector<[8]xi1>
17
+ // CHECK: arith.constant dense<false> : vector<[8]xi1>
18
18
%0 = vector.create_mask %c -1 : vector <[8 ]xi1 >
19
19
return %0 : vector <[8 ]xi1 >
20
20
}
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
36
36
func.func @create_vector_mask_to_constant_mask_truncation_neg () -> (vector <4 x3 xi1 >) {
37
37
%cneg2 = arith.constant -2 : index
38
38
%c5 = arith.constant 5 : index
39
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
39
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
40
40
%0 = vector.create_mask %c5 , %cneg2 : vector <4 x3 xi1 >
41
41
return %0 : vector <4 x3 xi1 >
42
42
}
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
47
47
func.func @create_vector_mask_to_constant_mask_truncation_zero () -> (vector <4 x3 xi1 >) {
48
48
%c2 = arith.constant 2 : index
49
49
%c0 = arith.constant 0 : index
50
- // CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
50
+ // CHECK: arith.constant dense<false> : vector<4x3xi1>
51
51
%0 = vector.create_mask %c0 , %c2 : vector <4 x3 xi1 >
52
52
return %0 : vector <4 x3 xi1 >
53
53
}
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
60
60
%c16 = arith.constant 16 : index
61
61
%0 = vector.vscale
62
62
%1 = arith.muli %0 , %c16 : index
63
- // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
63
+ // CHECK: arith.constant dense<true> : vector<8x[16]xi1>
64
64
%10 = vector.create_mask %c8 , %1 : vector <8 x[16 ]xi1 >
65
65
return %10 : vector <8 x[16 ]xi1 >
66
66
}
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
272
272
273
273
// -----
274
274
275
+ // CHECK-LABEL: constant_mask_to_true_splat
276
+ func.func @constant_mask_to_true_splat () -> vector <2 x4 xi1 > {
277
+ // CHECK: arith.constant dense<true>
278
+ // CHECK-NOT: vector.constant_mask
279
+ %0 = vector.constant_mask [2 , 4 ] : vector <2 x4 xi1 >
280
+ return %0 : vector <2 x4 xi1 >
281
+ }
282
+
283
+ // CHECK-LABEL: constant_mask_to_false_splat
284
+ func.func @constant_mask_to_false_splat () -> vector <2 x4 xi1 > {
285
+ // CHECK: arith.constant dense<false>
286
+ // CHECK-NOT: vector.constant_mask
287
+ %0 = vector.constant_mask [0 , 0 ] : vector <2 x4 xi1 >
288
+ return %0 : vector <2 x4 xi1 >
289
+ }
290
+
291
+ // CHECK-LABEL: constant_mask_to_true_splat_0d
292
+ func.func @constant_mask_to_true_splat_0d () -> vector <i1 > {
293
+ // CHECK: arith.constant dense<true>
294
+ // CHECK-NOT: vector.constant_mask
295
+ %0 = vector.constant_mask [1 ] : vector <i1 >
296
+ return %0 : vector <i1 >
297
+ }
298
+
275
299
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
276
300
func.func @constant_mask_transpose_to_transposed_constant_mask () -> (vector <2 x3 x4 xi1 >, vector <4 x2 x3 xi1 >) {
277
301
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
289
313
%1 = vector.extract_strided_slice %0
290
314
{offsets = [0 , 0 ], sizes = [2 , 2 ], strides = [1 , 1 ]}
291
315
: vector <4 x3 xi1 > to vector <2 x2 xi1 >
292
- // CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
316
+ // CHECK: arith.constant dense<true> : vector<2x2xi1>
293
317
return %1 : vector <2 x2 xi1 >
294
318
}
295
319
@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
322
346
%1 = vector.extract_strided_slice %0
323
347
{offsets = [2 , 0 ], sizes = [2 , 2 ], strides = [1 , 1 ]}
324
348
: vector <4 x3 xi1 > to vector <2 x2 xi1 >
325
- // CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
349
+ // CHECK: arith.constant dense<false> : vector<2x2xi1>
326
350
return %1 : vector <2 x2 xi1 >
327
351
}
328
352
@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
333
357
%1 = vector.extract_strided_slice %0
334
358
{offsets = [0 , 2 ], sizes = [2 , 1 ], strides = [1 , 1 ]}
335
359
: vector <4 x3 xi1 > to vector <2 x1 xi1 >
336
- // CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
360
+ // CHECK: arith.constant dense<false> : vector<2x1xi1>
337
361
return %1 : vector <2 x1 xi1 >
338
362
}
339
363
@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
344
368
%1 = vector.extract_strided_slice %0
345
369
{offsets = [0 , 1 ], sizes = [2 , 1 ], strides = [1 , 1 ]}
346
370
: vector <4 x3 xi1 > to vector <2 x1 xi1 >
347
- // CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
371
+ // CHECK: arith.constant dense<true> : vector<2x1xi1>
348
372
return %1 : vector <2 x1 xi1 >
349
373
}
350
374
0 commit comments