Skip to content

Commit 5eb195f

Browse files
authored
[mlir][Vector] Fold vector.constant_mask to SplatElementsAttr (#146724)
Adds a folder to vector.constant_mask to fold to SplatElementsAttr when possible
1 parent 0b4941a commit 5eb195f

File tree

5 files changed

+58
-11
lines changed

5 files changed

+58
-11
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2492,6 +2492,7 @@ def Vector_ConstantMaskOp :
24922492

24932493
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
24942494
let hasVerifier = 1;
2495+
let hasFolder = 1;
24952496
}
24962497

24972498
def Vector_CreateMaskOp :

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6594,6 +6594,28 @@ bool ConstantMaskOp::isAllOnesMask() {
65946594
return true;
65956595
}
65966596

6597+
OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
6598+
ArrayRef<int64_t> bounds = getMaskDimSizes();
6599+
ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
6600+
6601+
auto createBoolSplat = [&](bool x) {
6602+
return SplatElementsAttr::get(getVectorType(),
6603+
BoolAttr::get(getContext(), x));
6604+
};
6605+
6606+
// Check the corner case of 0-D vectors first.
6607+
if (vectorSizes.empty()) {
6608+
assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
6609+
return createBoolSplat(bounds[0] == 1);
6610+
}
6611+
// Fold vector.constant_mask to splat if possible.
6612+
if (bounds == vectorSizes)
6613+
return createBoolSplat(true);
6614+
if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
6615+
return createBoolSplat(false);
6616+
return OpFoldResult();
6617+
}
6618+
65976619
//===----------------------------------------------------------------------===//
65986620
// CreateMaskOp
65996621
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
2424
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
2525
// CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index
26+
// CHECK-DAG: %[[mask:.*]] = arith.constant dense<true> : vector<16xi1>
2627
// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref<?xi32>
2728
// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64
2829
// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index
@@ -31,7 +32,6 @@
3132
// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index
3233
// CHECK: %[[boundary:.*]] = affine.apply #[[$map0]]()[%[[q]], %[[s]]]
3334
// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[boundary]] step %[[c16]] {
34-
// CHECK: %[[mask:.*]] = vector.constant_mask [16] : vector<16xi1>
3535
// CHECK: %[[li:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xi32>, vector<16xi32>
3636
// CHECK: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64>
3737
// CHECK: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ func.func @create_vector_mask_to_constant_mask() -> (vector<4x3xi1>) {
1414
// CHECK-LABEL: create_scalable_vector_mask_to_constant_mask
1515
func.func @create_scalable_vector_mask_to_constant_mask() -> (vector<[8]xi1>) {
1616
%c-1 = arith.constant -1 : index
17-
// CHECK: vector.constant_mask [0] : vector<[8]xi1>
17+
// CHECK: arith.constant dense<false> : vector<[8]xi1>
1818
%0 = vector.create_mask %c-1 : vector<[8]xi1>
1919
return %0 : vector<[8]xi1>
2020
}
@@ -36,7 +36,7 @@ func.func @create_vector_mask_to_constant_mask_truncation() -> (vector<4x3xi1>)
3636
func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi1>) {
3737
%cneg2 = arith.constant -2 : index
3838
%c5 = arith.constant 5 : index
39-
// CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
39+
// CHECK: arith.constant dense<false> : vector<4x3xi1>
4040
%0 = vector.create_mask %c5, %cneg2 : vector<4x3xi1>
4141
return %0 : vector<4x3xi1>
4242
}
@@ -47,7 +47,7 @@ func.func @create_vector_mask_to_constant_mask_truncation_neg() -> (vector<4x3xi
4747
func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3xi1>) {
4848
%c2 = arith.constant 2 : index
4949
%c0 = arith.constant 0 : index
50-
// CHECK: vector.constant_mask [0, 0] : vector<4x3xi1>
50+
// CHECK: arith.constant dense<false> : vector<4x3xi1>
5151
%0 = vector.create_mask %c0, %c2 : vector<4x3xi1>
5252
return %0 : vector<4x3xi1>
5353
}
@@ -60,7 +60,7 @@ func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x
6060
%c16 = arith.constant 16 : index
6161
%0 = vector.vscale
6262
%1 = arith.muli %0, %c16 : index
63-
// CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
63+
// CHECK: arith.constant dense<true> : vector<8x[16]xi1>
6464
%10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
6565
return %10 : vector<8x[16]xi1>
6666
}
@@ -272,6 +272,30 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
272272

273273
// -----
274274

275+
// CHECK-LABEL: constant_mask_to_true_splat
276+
func.func @constant_mask_to_true_splat() -> vector<2x4xi1> {
277+
// CHECK: arith.constant dense<true>
278+
// CHECK-NOT: vector.constant_mask
279+
%0 = vector.constant_mask [2, 4] : vector<2x4xi1>
280+
return %0 : vector<2x4xi1>
281+
}
282+
283+
// CHECK-LABEL: constant_mask_to_false_splat
284+
func.func @constant_mask_to_false_splat() -> vector<2x4xi1> {
285+
// CHECK: arith.constant dense<false>
286+
// CHECK-NOT: vector.constant_mask
287+
%0 = vector.constant_mask [0, 0] : vector<2x4xi1>
288+
return %0 : vector<2x4xi1>
289+
}
290+
291+
// CHECK-LABEL: constant_mask_to_true_splat_0d
292+
func.func @constant_mask_to_true_splat_0d() -> vector<i1> {
293+
// CHECK: arith.constant dense<true>
294+
// CHECK-NOT: vector.constant_mask
295+
%0 = vector.constant_mask [1] : vector<i1>
296+
return %0 : vector<i1>
297+
}
298+
275299
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
276300
func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
277301
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
@@ -289,7 +313,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
289313
%1 = vector.extract_strided_slice %0
290314
{offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
291315
: vector<4x3xi1> to vector<2x2xi1>
292-
// CHECK: vector.constant_mask [2, 2] : vector<2x2xi1>
316+
// CHECK: arith.constant dense<true> : vector<2x2xi1>
293317
return %1 : vector<2x2xi1>
294318
}
295319

@@ -322,7 +346,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
322346
%1 = vector.extract_strided_slice %0
323347
{offsets = [2, 0], sizes = [2, 2], strides = [1, 1]}
324348
: vector<4x3xi1> to vector<2x2xi1>
325-
// CHECK: vector.constant_mask [0, 0] : vector<2x2xi1>
349+
// CHECK: arith.constant dense<false> : vector<2x2xi1>
326350
return %1 : vector<2x2xi1>
327351
}
328352

@@ -333,7 +357,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
333357
%1 = vector.extract_strided_slice %0
334358
{offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
335359
: vector<4x3xi1> to vector<2x1xi1>
336-
// CHECK: vector.constant_mask [0, 0] : vector<2x1xi1>
360+
// CHECK: arith.constant dense<false> : vector<2x1xi1>
337361
return %1 : vector<2x1xi1>
338362
}
339363

@@ -344,7 +368,7 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
344368
%1 = vector.extract_strided_slice %0
345369
{offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
346370
: vector<4x3xi1> to vector<2x1xi1>
347-
// CHECK: vector.constant_mask [2, 1] : vector<2x1xi1>
371+
// CHECK: arith.constant dense<true> : vector<2x1xi1>
348372
return %1 : vector<2x1xi1>
349373
}
350374

mlir/test/Dialect/Vector/vector-mem-transforms.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
8383
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
8484
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
8585
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
86-
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
86+
// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
8787
// CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
8888
// CHECK-NEXT: return %[[G]] : vector<16xf32>
8989
func.func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
@@ -112,7 +112,7 @@ func.func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru:
112112
// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
113113
// CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
114114
// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
115-
// CHECK-NEXT: %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
115+
// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
116116
// CHECK-NEXT: vector.scatter %[[A0]][%[[C]]] [%[[A1]]], %[[M]], %[[A2]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
117117
// CHECK-NEXT: return
118118
func.func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {

0 commit comments

Comments
 (0)