@@ -4096,6 +4096,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4096
4096
4097
4097
namespace {
4098
4098
4099
+ // Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4100
+ // CreateMaskOp.
4101
+ //
4102
+ // Example:
4103
+ //
4104
+ // %mask = vector.create_mask %ub : vector<16xi1>
4105
+ // %slice = vector.extract_strided_slice [%offset] [8] [1]
4106
+ //
4107
+ // to
4108
+ //
4109
+ // %new_ub = arith.subi %ub, %offset
4110
+ // %mask = vector.create_mask %new_ub : vector<8xi1>
4111
+ class StridedSliceCreateMaskFolder final
4112
+ : public OpRewritePattern<ExtractStridedSliceOp> {
4113
+ using OpRewritePattern::OpRewritePattern;
4114
+
4115
+ public:
4116
+ LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
4117
+ PatternRewriter &rewriter) const override {
4118
+ Location loc = extractStridedSliceOp.getLoc ();
4119
+ // Return if 'extractStridedSliceOp' operand is not defined by a
4120
+ // CreateMaskOp.
4121
+ auto createMaskOp =
4122
+ extractStridedSliceOp.getVector ().getDefiningOp <CreateMaskOp>();
4123
+ if (!createMaskOp)
4124
+ return failure ();
4125
+ // Return if 'extractStridedSliceOp' has non-unit strides.
4126
+ if (extractStridedSliceOp.hasNonUnitStrides ())
4127
+ return failure ();
4128
+ // Gather constant mask dimension sizes.
4129
+ SmallVector<Value> maskDimSizes (createMaskOp.getOperands ());
4130
+ // Gather strided slice offsets and sizes.
4131
+ SmallVector<int64_t > sliceOffsets;
4132
+ populateFromInt64AttrArray (extractStridedSliceOp.getOffsets (),
4133
+ sliceOffsets);
4134
+ SmallVector<int64_t > sliceSizes;
4135
+ populateFromInt64AttrArray (extractStridedSliceOp.getSizes (), sliceSizes);
4136
+
4137
+ // Compute slice of vector mask region.
4138
+ SmallVector<Value> sliceMaskDimSizes;
4139
+ sliceMaskDimSizes.reserve (maskDimSizes.size ());
4140
+ // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4141
+ // only iterate on the leading dim sizes. The tail accounts for the
4142
+ // remaining dim sizes.
4143
+ for (auto [maskDimSize, sliceOffset, sliceSize] :
4144
+ llvm::zip (maskDimSizes, sliceOffsets, sliceSizes)) {
4145
+ // No need to clamp on min/max values, because create_mask has clamping
4146
+ // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4147
+ // greater than the vector dim size.
4148
+ IntegerAttr offsetAttr =
4149
+ rewriter.getIntegerAttr (maskDimSize.getType (), sliceOffset);
4150
+ Value offset = rewriter.create <arith::ConstantOp>(loc, offsetAttr);
4151
+ Value sliceMaskDimSize =
4152
+ rewriter.create <arith::SubIOp>(loc, maskDimSize, offset);
4153
+ sliceMaskDimSizes.push_back (sliceMaskDimSize);
4154
+ }
4155
+ // Add unchanged dimensions.
4156
+ llvm::append_range (
4157
+ sliceMaskDimSizes,
4158
+ llvm::drop_begin (maskDimSizes, sliceMaskDimSizes.size ()));
4159
+ // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4160
+ // region.
4161
+ rewriter.replaceOpWithNewOp <CreateMaskOp>(
4162
+ extractStridedSliceOp, extractStridedSliceOp.getResult ().getType (),
4163
+ sliceMaskDimSizes);
4164
+ return success ();
4165
+ }
4166
+ };
4167
+
4099
4168
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4100
4169
// ConstantMaskOp.
4101
4170
class StridedSliceConstantMaskFolder final
@@ -4117,14 +4186,14 @@ class StridedSliceConstantMaskFolder final
4117
4186
// Gather constant mask dimension sizes.
4118
4187
ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
4119
4188
// Gather strided slice offsets and sizes.
4120
- SmallVector<int64_t , 4 > sliceOffsets;
4189
+ SmallVector<int64_t > sliceOffsets;
4121
4190
populateFromInt64AttrArray (extractStridedSliceOp.getOffsets (),
4122
4191
sliceOffsets);
4123
- SmallVector<int64_t , 4 > sliceSizes;
4192
+ SmallVector<int64_t > sliceSizes;
4124
4193
populateFromInt64AttrArray (extractStridedSliceOp.getSizes (), sliceSizes);
4125
4194
4126
4195
// Compute slice of vector mask region.
4127
- SmallVector<int64_t , 4 > sliceMaskDimSizes;
4196
+ SmallVector<int64_t > sliceMaskDimSizes;
4128
4197
sliceMaskDimSizes.reserve (maskDimSizes.size ());
4129
4198
for (auto [maskDimSize, sliceOffset, sliceSize] :
4130
4199
llvm::zip (maskDimSizes, sliceOffsets, sliceSizes)) {
@@ -4294,9 +4363,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
4294
4363
RewritePatternSet &results, MLIRContext *context) {
4295
4364
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
4296
4365
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4297
- results.add <StridedSliceConstantMaskFolder, StridedSliceBroadcast ,
4298
- StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4299
- context);
4366
+ results.add <StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder ,
4367
+ StridedSliceBroadcast, StridedSliceSplat,
4368
+ ContiguousExtractStridedSliceToExtract>( context);
4300
4369
}
4301
4370
4302
4371
// ===----------------------------------------------------------------------===//
0 commit comments