From 6aa29a5f778eb074928d297392d0aedfd471e1cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Tue, 8 Jul 2025 19:55:23 +0000 Subject: [PATCH] [MLIR][Vector] Fix bug in ExtractStridedSliceOp canonicalization The pattern would produce an invalid slice when some dimensions were both sliced and broadcast. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 39 +++++++++++++--------- mlir/test/Dialect/Vector/canonicalize.mlir | 15 +++++++++ 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 214d2ba7e1b8e..2f5b831b3c40b 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4169,28 +4169,35 @@ class StridedSliceBroadcast final auto dstVecType = llvm::cast(op.getType()); unsigned dstRank = dstVecType.getRank(); unsigned rankDiff = dstRank - srcRank; - // Check if the most inner dimensions of the source of the broadcast are the - // same as the destination of the extract. If this is the case we can just - // use a broadcast as the original dimensions are untouched. - bool lowerDimMatch = true; + // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced + // (n -> m with n > m). If they are originally both broadcasted *and* + // sliced, this can be simplified to just broadcasting. + bool needsSlice = false; for (unsigned i = 0; i < srcRank; i++) { - if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) { - lowerDimMatch = false; + if (srcVecType.getDimSize(i) != 1 && + srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) { + needsSlice = true; break; } } Value source = broadcast.getSource(); - // If the inner dimensions don't match, it means we need to extract from the - // source of the orignal broadcast and then broadcast the extracted value. - // We also need to handle degenerated cases where the source is effectively - // just a single scalar. 
- bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1); - if (!lowerDimMatch && !isScalarSrc) { + if (needsSlice) { + SmallVector offsets = + getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff); + SmallVector sizes = + getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff); + for (unsigned i = 0; i < srcRank; i++) { + if (srcVecType.getDimSize(i) == 1) { + // In case this dimension was broadcasted *and* sliced, the offset + // and size need to be updated now that there is no broadcast before + // the slice. + offsets[i] = 0; + sizes[i] = 1; + } + } source = rewriter.create( - op->getLoc(), source, - getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff), - getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff), - getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff)); + op->getLoc(), source, offsets, sizes, + getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff)); } rewriter.replaceOpWithNewOp(op, op.getType(), source); return success(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 8a9e27378df61..93c3de619084d 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1344,6 +1344,21 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> { // ----- +// Check the case where the same dimension is both broadcasted and sliced +// CHECK-LABEL: func @extract_strided_broadcast5 +// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>) +// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32> +// CHECK: return %[[V]] +func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> { + %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32> + %1 = vector.extract_strided_slice %0 + {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} + : vector<2x8xf32> to vector<2x4xf32> + return %1 : vector<2x4xf32> +} + +// ----- + // CHECK-LABEL: consecutive_shape_cast // CHECK: %[[C:.*]] = 
vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16> // CHECK-NEXT: return %[[C]] : vector<4x4xf16>