
[MLIR][Vector] Fix bug in ExtractStrideSlicesOp canonicalization #147591


Open: wants to merge 1 commit into base: main

tlongeri
Contributor

@tlongeri tlongeri commented Jul 8, 2025

The pattern would produce an invalid slice when some dimensions were both sliced and broadcast.
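
To make the failure concrete, here is a minimal standalone sketch (plain C++, no MLIR dependency) of a bounds check on slice parameters. `isValidSlice` is a hypothetical helper written for illustration and is only a simplified stand-in for the op's verifier. As I read the removed lines of the diff, the old pattern forwarded the offsets and sizes of the original slice to the broadcast source unchanged, so for the `vector<2x1xf32>` source in the new test it would have requested offsets `[0, 4]` and sizes `[2, 4]`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: simplified stand-in for the bounds check on
// vector.extract_strided_slice. For every sliced dimension, the last
// element read (offset + (size - 1) * stride) must lie inside the source.
// This is an illustration only, not MLIR's actual verifier.
bool isValidSlice(const std::vector<int64_t> &srcShape,
                  const std::vector<int64_t> &offsets,
                  const std::vector<int64_t> &sizes,
                  const std::vector<int64_t> &strides) {
  assert(offsets.size() == sizes.size() && sizes.size() == strides.size());
  if (offsets.size() > srcShape.size())
    return false;
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    if (offsets[i] < 0 || sizes[i] < 1 || strides[i] < 1)
      return false;
    if (offsets[i] + (sizes[i] - 1) * strides[i] >= srcShape[i])
      return false;
  }
  return true;
}
```

Under this check, the slice is fine on the broadcast result (`2x8`) but out of bounds on the source (`2x1`), where the only in-bounds choice for the unit dimension is offset 0 and size 1.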

Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Tomás Longeri (tlongeri)

Changes

The pattern would produce an invalid slice when some dimensions were both sliced and broadcast.


Full diff: https://github.com/llvm/llvm-project/pull/147591.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+23-16)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..2f5b831b3c40b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
     auto dstVecType = llvm::cast<VectorType>(op.getType());
     unsigned dstRank = dstVecType.getRank();
     unsigned rankDiff = dstRank - srcRank;
-    // Check if the most inner dimensions of the source of the broadcast are the
-    // same as the destination of the extract. If this is the case we can just
-    // use a broadcast as the original dimensions are untouched.
-    bool lowerDimMatch = true;
+    // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
+    // (n -> m with n > m). If they are originally both broadcasted *and*
+    // sliced, this can be simplified to just broadcasting.
+    bool needsSlice = false;
     for (unsigned i = 0; i < srcRank; i++) {
-      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
-        lowerDimMatch = false;
+      if (srcVecType.getDimSize(i) != 1 &&
+          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
+        needsSlice = true;
         break;
       }
     }
     Value source = broadcast.getSource();
-    // If the inner dimensions don't match, it means we need to extract from the
-    // source of the orignal broadcast and then broadcast the extracted value.
-    // We also need to handle degenerated cases where the source is effectively
-    // just a single scalar.
-    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
-    if (!lowerDimMatch && !isScalarSrc) {
+    if (needsSlice) {
+      SmallVector<int64_t> offsets =
+          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
+      SmallVector<int64_t> sizes =
+          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
+      for (unsigned i = 0; i < srcRank; i++) {
+        if (srcVecType.getDimSize(i) == 1) {
+          // In case this dimension was broadcasted *and* sliced, the offset
+          // and size need to be updated now that there is no broadcast before
+          // the slice.
+          offsets[i] = 0;
+          sizes[i] = 1;
+        }
+      }
       source = rewriter.create<ExtractStridedSliceOp>(
-          op->getLoc(), source,
-          getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+          op->getLoc(), source, offsets, sizes,
+          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
     }
     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..dfa2e1c2a5a24 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1344,6 +1344,20 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_broadcast5
+//  CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
+//       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
+//       CHECK: return %[[V]]
+func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
+ %1 = vector.extract_strided_slice %0
+      {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
+      : vector<2x8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: consecutive_shape_cast
 //       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
 //  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>
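
The offset and size adjustment in the new code can be sketched outside MLIR with plain shape vectors. This is a rough model under stated assumptions, not the pattern itself: `SliceParams` and `computeSourceSlice` are hypothetical names, `dstShape` stands for the result type of the extract op, and offsets and sizes are assumed to be given for every dimension of the broadcast result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical plain-data result; the real pattern creates ops instead.
struct SliceParams {
  bool needsSlice;
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Mirrors the logic of the diff: a slice of the broadcast source is needed
// only if some non-unit source dimension differs from the result dimension.
// Unit source dimensions are reset to offset 0 / size 1, because without the
// broadcast in front of the slice there is only one element to take.
SliceParams computeSourceSlice(const std::vector<int64_t> &srcShape,
                               const std::vector<int64_t> &dstShape,
                               const std::vector<int64_t> &offsets,
                               const std::vector<int64_t> &sizes) {
  assert(srcShape.size() <= dstShape.size());
  assert(offsets.size() == dstShape.size() && sizes.size() == dstShape.size());
  const std::size_t rankDiff = dstShape.size() - srcShape.size();

  SliceParams result{false, {}, {}};
  for (std::size_t i = 0; i < srcShape.size(); ++i) {
    if (srcShape[i] != 1 && srcShape[i] != dstShape[i + rankDiff]) {
      result.needsSlice = true;
      break;
    }
  }
  if (!result.needsSlice)
    return result;

  // Drop the leading dimensions that exist only because of the broadcast.
  const auto drop = static_cast<std::ptrdiff_t>(rankDiff);
  result.offsets.assign(offsets.begin() + drop, offsets.end());
  result.sizes.assign(sizes.begin() + drop, sizes.end());
  for (std::size_t i = 0; i < srcShape.size(); ++i) {
    if (srcShape[i] == 1) {
      result.offsets[i] = 0;
      result.sizes[i] = 1;
    }
  }
  return result;
}
```

For the `extract_strided_broadcast5` shapes (source `2x1`, result `2x4`) this model reports that no slice is needed, which matches the single `vector.broadcast` in the CHECK lines. For a source of `4x1` sliced down to `2x4` with offsets `[2, 4]`, it yields offsets `[2, 0]` and sizes `[2, 1]` on the source.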

Contributor

@banach-space banach-space left a comment


LGTM, great job improving the comments!

I've left one minor request for more comments - could you address that before merging? Thanks!

// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
// CHECK: return %[[V]]
func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {

I see that you are simply following the pre-existing naming convention here, which is the most natural thing to do. Sadly, this function name does not reveal what makes this case unique. Could you add a comment to explain? IIUC, it's the fact that the trailing unit dim is broadcasted?
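
One way to check that reading is to classify each source dimension from the three shapes involved. The sketch below is illustrative only; `dimsBroadcastAndSliced` is a hypothetical helper, not something in the patch. It treats a dimension as broadcast when the source extent is 1 and the broadcast extent is larger, and as sliced when the extract result is smaller than the broadcast extent.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns the indices of source dimensions that are
// both broadcast (1 -> n, n > 1) and sliced (n -> m, n > m).
std::vector<std::size_t>
dimsBroadcastAndSliced(const std::vector<int64_t> &srcShape,
                       const std::vector<int64_t> &bcastShape,
                       const std::vector<int64_t> &dstShape) {
  assert(bcastShape.size() == dstShape.size());
  assert(srcShape.size() <= bcastShape.size());
  const std::size_t rankDiff = bcastShape.size() - srcShape.size();

  std::vector<std::size_t> dims;
  for (std::size_t i = 0; i < srcShape.size(); ++i) {
    const std::size_t j = i + rankDiff;
    const bool isBroadcast = srcShape[i] == 1 && bcastShape[j] > 1;
    const bool isSliced = dstShape[j] < bcastShape[j];
    if (isBroadcast && isSliced)
      dims.push_back(i);
  }
  return dims;
}
```

For the shapes in this test (`2x1` broadcast to `2x8`, then sliced to `2x4`), only the trailing dimension is reported, so it is both broadcast and sliced.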
