Skip to content

[mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) #146745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4081,6 +4081,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {

namespace {

// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
// CreateMaskOp.
//
// Example:
//
// %mask = vector.create_mask %ub : vector<16xi1>
// %slice = vector.extract_strided_slice [%offset] [8] [1]
//
// to
//
// %new_ub = arith.subi %ub, %offset
// %mask = vector.create_mask %new_ub : vector<8xi1>
class StridedSliceCreateMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;

public:
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
Location loc = extractStridedSliceOp.getLoc();
// Return if 'extractStridedSliceOp' operand is not defined by a
// CreateMaskOp.
auto createMaskOp =
extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
if (!createMaskOp)
return failure();
// Return if 'extractStridedSliceOp' has non-unit strides.
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
// Gather strided slice offsets and sizes.
SmallVector<int64_t> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
sliceOffsets);
SmallVector<int64_t> sliceSizes;
populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

// Compute slice of vector mask region.
SmallVector<Value> sliceMaskDimSizes;
sliceMaskDimSizes.reserve(maskDimSizes.size());
// sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
// only iterate on the leading dim sizes. The tail accounts for the
// remaining dim sizes.
for (auto [maskDimSize, sliceOffset, sliceSize] :
llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
// No need to clamp on min/max values, because create_mask has clamping
// semantics, i.e. the sliceMaskDimSize is allowed to be negative or
// greater than the vector dim size.
IntegerAttr offsetAttr =
rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
Value sliceMaskDimSize =
rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
sliceMaskDimSizes.push_back(sliceMaskDimSize);
}
// Add unchanged dimensions.
llvm::append_range(
sliceMaskDimSizes,
llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
// Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
// region.
rewriter.replaceOpWithNewOp<CreateMaskOp>(
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
sliceMaskDimSizes);
return success();
}
};

// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
Expand All @@ -4102,14 +4171,14 @@ class StridedSliceConstantMaskFolder final
// Gather constant mask dimension sizes.
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
SmallVector<int64_t> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
sliceOffsets);
SmallVector<int64_t, 4> sliceSizes;
SmallVector<int64_t> sliceSizes;
populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

// Compute slice of vector mask region.
SmallVector<int64_t, 4> sliceMaskDimSizes;
SmallVector<int64_t> sliceMaskDimSizes;
sliceMaskDimSizes.reserve(maskDimSizes.size());
for (auto [maskDimSize, sliceOffset, sliceSize] :
llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
Expand Down Expand Up @@ -4279,9 +4348,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
context);
results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
StridedSliceBroadcast, StridedSliceSplat,
ContiguousExtractStridedSliceToExtract>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,41 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {

// -----

// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
%0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
%1 = vector.extract_strided_slice %0
{offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
: vector<4x3xi1> to vector<2x2xi1>
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
// CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
// CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
return %1 : vector<2x2xi1>
}

// -----

// CHECK-LABEL: func.func @extract_strided_slice_partial_of_create_mask
// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index)
func.func @extract_strided_slice_partial_of_create_mask(
%dim0: index, %dim1: index, %dim2 : index) -> (vector<2x2x8xi1>) {
%0 = vector.create_mask %dim0, %dim1, %dim2 : vector<4x3x8xi1>
%1 = vector.extract_strided_slice %0
{offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
: vector<4x3x8xi1> to vector<2x2x8xi1>
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
// CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
// CHECK: vector.create_mask %[[A]], %[[B]], %[[DIM2]] : vector<2x2x8xi1>
return %1 : vector<2x2x8xi1>
}

// -----

// CHECK-LABEL: extract_strided_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
Expand Down