Skip to content

Commit a1b8813

Browse files
committed
[MLIR][Vector] Fix bug in ExtractStridedSliceOp canonicalization
The pattern would produce an invalid slice when some dimensions were both sliced and broadcast.
1 parent 0863979 commit a1b8813

File tree

2 files changed

+37
-16
lines changed

2 files changed

+37
-16
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final
41694169
auto dstVecType = llvm::cast<VectorType>(op.getType());
41704170
unsigned dstRank = dstVecType.getRank();
41714171
unsigned rankDiff = dstRank - srcRank;
4172-
// Check if the most inner dimensions of the source of the broadcast are the
4173-
// same as the destination of the extract. If this is the case we can just
4174-
// use a broadcast as the original dimensions are untouched.
4175-
bool lowerDimMatch = true;
4172+
// Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4173+
// (n -> m with n > m). If they are originally both broadcasted *and*
4174+
// sliced, this can be simplified to just broadcasting.
4175+
bool needsSlice = false;
41764176
for (unsigned i = 0; i < srcRank; i++) {
4177-
if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4178-
lowerDimMatch = false;
4177+
if (srcVecType.getDimSize(i) != 1 &&
4178+
srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4179+
needsSlice = true;
41794180
break;
41804181
}
41814182
}
41824183
Value source = broadcast.getSource();
4183-
// If the inner dimensions don't match, it means we need to extract from the
4184-
// source of the original broadcast and then broadcast the extracted value.
4185-
// We also need to handle degenerated cases where the source is effectively
4186-
// just a single scalar.
4187-
bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
4188-
if (!lowerDimMatch && !isScalarSrc) {
4184+
if (needsSlice) {
4185+
SmallVector<int64_t> offsets =
4186+
getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4187+
SmallVector<int64_t> sizes =
4188+
getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4189+
for (unsigned i = 0; i < srcRank; i++) {
4190+
if (srcVecType.getDimSize(i) == 1) {
4191+
// In case this dimension was broadcasted *and* sliced, the offset
4192+
// and size need to be updated now that there is no broadcast before
4193+
// the slice.
4194+
offsets[i] = 0;
4195+
sizes[i] = 1;
4196+
}
4197+
}
41894198
source = rewriter.create<ExtractStridedSliceOp>(
4190-
op->getLoc(), source,
4191-
getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
4192-
getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
4193-
getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
4199+
op->getLoc(), source, offsets, sizes,
4200+
getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
41944201
}
41954202
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
41964203
return success();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,20 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
13441344

13451345
// -----
13461346

1347+
// CHECK-LABEL: func @extract_strided_broadcast5
1348+
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
1349+
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
1350+
// CHECK: return %[[V]]
1351+
func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
1352+
%0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
1353+
%1 = vector.extract_strided_slice %0
1354+
{offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
1355+
: vector<2x8xf32> to vector<2x4xf32>
1356+
return %1 : vector<2x4xf32>
1357+
}
1358+
1359+
// -----
1360+
13471361
// CHECK-LABEL: consecutive_shape_cast
13481362
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
13491363
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>

0 commit comments

Comments
 (0)