-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) #146745
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) #146745
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Kunwar Grover (Groverkss)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/146745.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..66c2fe50529d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
+class StridedSliceCreateMaskFolder final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+ // Folds extract_strided_slice(create_mask) into a create_mask with offset-adjusted bounds.
+public:
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = extractStridedSliceOp.getLoc();
+ // Bail out unless the sliced vector is produced by a CreateMaskOp
+ // (dyn_cast_or_null also covers a null defining op).
+ auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+ auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+ if (!createMaskOp)
+ return failure();
+ // Only unit strides are handled; bail out on anything else.
+ if (extractStridedSliceOp.hasNonUnitStrides())
+ return failure();
+ // Gather the mask dimension sizes: the SSA operands of the create_mask.
+ SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+ // Gather the static slice offsets and sizes from the op's attributes.
+ SmallVector<int64_t, 4> sliceOffsets;
+ populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+ sliceOffsets);
+ SmallVector<int64_t, 4> sliceSizes;
+ populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+ // Compute the mask bounds of the sliced region. llvm::zip ends with its shortest
+ // range, so only dims that have a slice offset are adjusted; sliceSize is unused.
+ SmallVector<Value> sliceMaskDimSizes;
+ sliceMaskDimSizes.reserve(maskDimSizes.size());
+ for (auto [maskDimSize, sliceOffset, sliceSize] :
+ llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+ // No clamping on min/max values is needed, because create_mask has clamping
+ // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
+ // greater than the vector dim size.
+ IntegerAttr offsetAttr =
+ rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+ Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+ Value sliceMaskDimSize =
+ rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+ sliceMaskDimSizes.push_back(sliceMaskDimSize);
+ }
+ // Trailing dims not visited by the loop above keep their original mask bound.
+ if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
+ for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+ sliceMaskDimSizes.push_back(maskDimSizes[i]);
+ }
+ }
+ // Replace 'extractStridedSliceOp' with a CreateMaskOp that has the slice's
+ // result type and is built from the adjusted bounds.
+ rewriter.replaceOpWithNewOp<CreateMaskOp>(
+ extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
+ sliceMaskDimSizes);
+ return success();
+ }
+};
+
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
@@ -4279,9 +4335,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
- results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
- StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
- context);
+ results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
+ StridedSliceBroadcast, StridedSliceSplat,
+ ContiguousExtractStridedSliceToExtract>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..c09dab8232900 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -361,6 +361,23 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
// -----
+// Slicing a create_mask is expected to fold to a create_mask of the slice shape
+// whose bounds are the original bounds minus the slice offsets ([2, 1] here).
+// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
+ %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x2xi1>
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+ // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+ // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+ // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
+ return %1 : vector<2x2xi1>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks! Just some minor comments.
I'm curious: what is the current lowering of this without the canonicalization? A vector.shuffle extracting the slice from the mask?
Yes. Not only that: in a lot of cases you could actually fold the mask together with the maskedload into a plain load, but the extract_strided_slice was blocking that fold.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like a good idea to me
extract_strided_slice(create_mask) can be folded into create_mask by simply subtracting the offsets from the bounds.