-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][Vector] Fix bug in ExtractStrideSlicesOp canonicalization #147591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The pattern would produce an invalid slice when some dimensions were both sliced and broadcast.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Tomás Longeri (tlongeri) ChangesThe pattern would produce an invalid slice when some dimensions were both sliced and broadcast. Full diff: https://github.com/llvm/llvm-project/pull/147591.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 214d2ba7e1b8e..2f5b831b3c40b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
auto dstVecType = llvm::cast<VectorType>(op.getType());
unsigned dstRank = dstVecType.getRank();
unsigned rankDiff = dstRank - srcRank;
- // Check if the most inner dimensions of the source of the broadcast are the
- // same as the destination of the extract. If this is the case we can just
- // use a broadcast as the original dimensions are untouched.
- bool lowerDimMatch = true;
+ // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
+ // (n -> m with n > m). If they are originally both broadcasted *and*
+ // sliced, this can be simplified to just broadcasting.
+ bool needsSlice = false;
for (unsigned i = 0; i < srcRank; i++) {
- if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
- lowerDimMatch = false;
+ if (srcVecType.getDimSize(i) != 1 &&
+ srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
+ needsSlice = true;
break;
}
}
Value source = broadcast.getSource();
- // If the inner dimensions don't match, it means we need to extract from the
- // source of the orignal broadcast and then broadcast the extracted value.
- // We also need to handle degenerated cases where the source is effectively
- // just a single scalar.
- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
- if (!lowerDimMatch && !isScalarSrc) {
+ if (needsSlice) {
+ SmallVector<int64_t> offsets =
+ getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
+ SmallVector<int64_t> sizes =
+ getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
+ for (unsigned i = 0; i < srcRank; i++) {
+ if (srcVecType.getDimSize(i) == 1) {
+ // In case this dimension was broadcasted *and* sliced, the offset
+ // and size need to be updated now that there is no broadcast before
+ // the slice.
+ offsets[i] = 0;
+ sizes[i] = 1;
+ }
+ }
source = rewriter.create<ExtractStridedSliceOp>(
- op->getLoc(), source,
- getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
- getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
- getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+ op->getLoc(), source, offsets, sizes,
+ getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
}
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8a9e27378df61..dfa2e1c2a5a24 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1344,6 +1344,20 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
// -----
+// CHECK-LABEL: func @extract_strided_broadcast5
+// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
+// CHECK: return %[[V]]
+func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
+ : vector<2x8xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: consecutive_shape_cast
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, great job improving the comments!
I've left one minor request for more comments - could you address that before merging? Thanks!
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>) | ||
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32> | ||
// CHECK: return %[[V]] | ||
func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see that you are simply following the pre-existing naming convention here, which is the most natural thing to do. Sadly, this function name does not reveal what makes this case unique. Could you add a comment to explain? IIUC, it's the fact that the trailing unit dim is broadcasted?
The pattern would produce an invalid slice when some dimensions were both sliced and broadcast.