[mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) #146745

Merged (3 commits) on Jul 10, 2025


@Groverkss (Member) commented Jul 2, 2025

extract_strided_slice(create_mask) can be folded into create_mask by simply subtracting the offsets from the bounds.
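
To see why subtracting the offsets is enough, here is a minimal sketch of the identity for one dimension. It is a plain C++ model, not the MLIR implementation. The helper names (`createMask`, `extractSlice`, `foldIsEquivalent`, `checkAllBounds`) are hypothetical and exist only for this illustration. The model assumes unit strides and `offset + sliceSize <= maskSize`, and treats bounds as signed values with clamping, so a negative bound gives an all-false mask.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Model of a 1-D vector.create_mask: lane i is set iff i < bound.
// A negative bound yields all-false; a bound >= size yields all-true.
std::vector<bool> createMask(int64_t bound, int64_t size) {
  std::vector<bool> mask(size);
  for (int64_t i = 0; i < size; ++i)
    mask[i] = i < bound;
  return mask;
}

// Model of a unit-stride 1-D vector.extract_strided_slice.
// Assumes offset + size <= v.size().
std::vector<bool> extractSlice(const std::vector<bool> &v, int64_t offset,
                               int64_t size) {
  return std::vector<bool>(v.begin() + offset, v.begin() + offset + size);
}

// Lane j of the slice is lane (j + offset) of the mask, which is set iff
// j + offset < bound, i.e. iff j < bound - offset.
bool foldIsEquivalent(int64_t bound, int64_t maskSize, int64_t offset,
                      int64_t sliceSize) {
  return extractSlice(createMask(bound, maskSize), offset, sliceSize) ==
         createMask(bound - offset, sliceSize);
}

// Checks the identity for every bound from below zero to past the mask size.
bool checkAllBounds(int64_t maskSize, int64_t offset, int64_t sliceSize) {
  for (int64_t bound = -2; bound <= maskSize + 2; ++bound)
    if (!foldIsEquivalent(bound, maskSize, offset, sliceSize))
      return false;
  return true;
}
```

A multi-dimensional mask is set where every per-dimension predicate holds, so the same argument applies to each dimension independently. For the shapes in this PR's test, that corresponds to `checkAllBounds(4, 2, 2)` for dimension 0 and `checkAllBounds(3, 1, 2)` for dimension 1.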

@llvmbot (Member) commented Jul 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/146745.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+59-3)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+17)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..66c2fe50529d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+class StridedSliceCreateMaskFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = extractStridedSliceOp.getLoc();
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // CreateMaskOp.
+    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+    if (!createMaskOp)
+      return failure();
+    // Return if 'extractStridedSliceOp' has non-unit strides.
+    if (extractStridedSliceOp.hasNonUnitStrides())
+      return failure();
+    // Gather mask dimension sizes (SSA values, not necessarily constant).
+    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+    // Gather strided slice offsets and sizes.
+    SmallVector<int64_t, 4> sliceOffsets;
+    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+                               sliceOffsets);
+    SmallVector<int64_t, 4> sliceSizes;
+    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+
+    // Compute slice of vector mask region.
+    SmallVector<Value> sliceMaskDimSizes;
+    sliceMaskDimSizes.reserve(maskDimSizes.size());
+    for (auto [maskDimSize, sliceOffset, sliceSize] :
+         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+      // No need to clamp on min/max values, because create_mask has clamping
+      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
+      // greater than the vector dim size.
+      IntegerAttr offsetAttr =
+          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+      Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+      Value sliceMaskDimSize =
+          rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+      sliceMaskDimSizes.push_back(sliceMaskDimSize);
+    }
+    // Add unchanged dimensions.
+    if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
+      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+        sliceMaskDimSizes.push_back(maskDimSizes[i]);
+      }
+    }
+    // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
+    // region.
+    rewriter.replaceOpWithNewOp<CreateMaskOp>(
+        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
+        sliceMaskDimSizes);
+    return success();
+  }
+};
+
 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
 // ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
@@ -4279,9 +4335,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
-              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
-      context);
+  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
+              StridedSliceBroadcast, StridedSliceSplat,
+              ContiguousExtractStridedSliceToExtract>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..c09dab8232900 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -361,6 +361,23 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
 
 // -----
 
+// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
+  %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+      : vector<4x3xi1> to vector<2x2xi1>
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+  // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+  // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
+  return %1 : vector<2x2xi1>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_fold
 //  CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
 //  CHECK-NEXT:   return %[[ARG]] : vector<4x3xi1>

@dcaballe (Contributor) left a comment

LG, thanks! Just some minor comments.

I'm curious: what is the current lowering of this without the canonicalization? A vector.shuffle extracting the slice from the mask?

@Groverkss Groverkss requested review from joker-eph and dcaballe July 4, 2025 10:53
@Groverkss (Member, Author) replied:

> LG, thanks! Just some minor comments.
>
> I'm curious: what is the current lowering of this without the canonicalization? A vector.shuffle extracting the slice from the mask?

Yes. Not only that, in a lot of cases you could actually fold the mask with the maskedload into a load, but the extract_strided_slice was blocking it.

@Groverkss Groverkss requested a review from kuhar July 4, 2025 13:48
@kuhar (Member) left a comment

Seems like a good idea to me

@Groverkss Groverkss requested a review from kuhar July 9, 2025 13:08
@Groverkss Groverkss merged commit 0227aef into llvm:main Jul 10, 2025
9 checks passed
5 participants