Skip to content

Commit 0227aef

Browse files
authored
[mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (#146745)
extract_strided_slice(create_mask) can be folded into create_mask by simply subtracting the offsets from the bounds.
1 parent a2c0ac0 commit 0227aef

File tree

2 files changed

+110
-6
lines changed

2 files changed

+110
-6
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4096,6 +4096,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
40964096

40974097
namespace {
40984098

4099+
// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4100+
// CreateMaskOp.
4101+
//
4102+
// Example:
4103+
//
4104+
// %mask = vector.create_mask %ub : vector<16xi1>
4105+
// %slice = vector.extract_strided_slice [%offset] [8] [1]
4106+
//
4107+
// to
4108+
//
4109+
// %new_ub = arith.subi %ub, %offset
4110+
// %mask = vector.create_mask %new_ub : vector<8xi1>
4111+
class StridedSliceCreateMaskFolder final
4112+
: public OpRewritePattern<ExtractStridedSliceOp> {
4113+
using OpRewritePattern::OpRewritePattern;
4114+
4115+
public:
4116+
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4117+
PatternRewriter &rewriter) const override {
4118+
Location loc = extractStridedSliceOp.getLoc();
4119+
// Return if 'extractStridedSliceOp' operand is not defined by a
4120+
// CreateMaskOp.
4121+
auto createMaskOp =
4122+
extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
4123+
if (!createMaskOp)
4124+
return failure();
4125+
// Return if 'extractStridedSliceOp' has non-unit strides.
4126+
if (extractStridedSliceOp.hasNonUnitStrides())
4127+
return failure();
4128+
// Gather constant mask dimension sizes.
4129+
SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4130+
// Gather strided slice offsets and sizes.
4131+
SmallVector<int64_t> sliceOffsets;
4132+
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4133+
sliceOffsets);
4134+
SmallVector<int64_t> sliceSizes;
4135+
populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4136+
4137+
// Compute slice of vector mask region.
4138+
SmallVector<Value> sliceMaskDimSizes;
4139+
sliceMaskDimSizes.reserve(maskDimSizes.size());
4140+
// sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4141+
// only iterate on the leading dim sizes. The tail accounts for the
4142+
// remaining dim sizes.
4143+
for (auto [maskDimSize, sliceOffset, sliceSize] :
4144+
llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4145+
// No need to clamp on min/max values, because create_mask has clamping
4146+
// semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4147+
// greater than the vector dim size.
4148+
IntegerAttr offsetAttr =
4149+
rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4150+
Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
4151+
Value sliceMaskDimSize =
4152+
rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
4153+
sliceMaskDimSizes.push_back(sliceMaskDimSize);
4154+
}
4155+
// Add unchanged dimensions.
4156+
llvm::append_range(
4157+
sliceMaskDimSizes,
4158+
llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4159+
// Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4160+
// region.
4161+
rewriter.replaceOpWithNewOp<CreateMaskOp>(
4162+
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4163+
sliceMaskDimSizes);
4164+
return success();
4165+
}
4166+
};
4167+
40994168
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
41004169
// ConstantMaskOp.
41014170
class StridedSliceConstantMaskFolder final
@@ -4117,14 +4186,14 @@ class StridedSliceConstantMaskFolder final
41174186
// Gather constant mask dimension sizes.
41184187
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
41194188
// Gather strided slice offsets and sizes.
4120-
SmallVector<int64_t, 4> sliceOffsets;
4189+
SmallVector<int64_t> sliceOffsets;
41214190
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
41224191
sliceOffsets);
4123-
SmallVector<int64_t, 4> sliceSizes;
4192+
SmallVector<int64_t> sliceSizes;
41244193
populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
41254194

41264195
// Compute slice of vector mask region.
4127-
SmallVector<int64_t, 4> sliceMaskDimSizes;
4196+
SmallVector<int64_t> sliceMaskDimSizes;
41284197
sliceMaskDimSizes.reserve(maskDimSizes.size());
41294198
for (auto [maskDimSize, sliceOffset, sliceSize] :
41304199
llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
@@ -4294,9 +4363,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
42944363
RewritePatternSet &results, MLIRContext *context) {
42954364
// Patterns to rewrite an ExtractStridedSliceOp(CreateMaskOp) -> CreateMaskOp,
42964365
// ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4297-
results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4298-
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4299-
context);
4366+
results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4367+
StridedSliceBroadcast, StridedSliceSplat,
4368+
ContiguousExtractStridedSliceToExtract>(context);
43004369
}
43014370

43024371
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,41 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
385385

386386
// -----
387387

388+
// The slice below covers both dimensions of a 4x3 create_mask with offsets
// [2, 1]. The expected output subtracts each offset from the matching mask
// bound (%dim0 - 2, %dim1 - 1) and builds a create_mask of the 2x2 slice type.
// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
389+
// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
390+
func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
391+
%0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
392+
%1 = vector.extract_strided_slice %0
393+
{offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
394+
: vector<4x3xi1> to vector<2x2xi1>
395+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
396+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
397+
// CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
398+
// CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
399+
// CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
400+
return %1 : vector<2x2xi1>
401+
}
402+
403+
// -----
404+
405+
// The slice below only specifies offsets/sizes for the two leading dimensions
// of a 4x3x8 create_mask. The expected output adjusts the first two bounds
// (%dim0 - 2, %dim1 - 1) and passes the trailing bound %dim2 through unchanged.
// CHECK-LABEL: func.func @extract_strided_slice_partial_of_create_mask
405+
// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index)
406+
func.func @extract_strided_slice_partial_of_create_mask(
407+
%dim0: index, %dim1: index, %dim2 : index) -> (vector<2x2x8xi1>) {
408+
%0 = vector.create_mask %dim0, %dim1, %dim2 : vector<4x3x8xi1>
409+
%1 = vector.extract_strided_slice %0
410+
{offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
411+
: vector<4x3x8xi1> to vector<2x2x8xi1>
412+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
413+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
414+
// CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
415+
// CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
416+
// CHECK: vector.create_mask %[[A]], %[[B]], %[[DIM2]] : vector<2x2x8xi1>
417+
return %1 : vector<2x2x8xi1>
418+
}
420+
421+
// -----
422+
388423
// CHECK-LABEL: extract_strided_fold
389424
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
390425
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>

0 commit comments

Comments
 (0)