Skip to content

Commit 82ab0f7

Browse files
authored
[mlir][linalg] Fix rank-reduced cases for extract/insert slice in DropUnitDims (#74723)
Inferring the reshape reassociation indices for extract/insert slice ops from the read sizes of the original slicing op generates an invalid expand/collapse shape op in already rank-reduced cases. Instead, infer the reassociation from the shape of the slice itself. Ported from Differential Revision: https://reviews.llvm.org/D147488
1 parent c398fa0 commit 82ab0f7

File tree

2 files changed

+37
-8
lines changed

2 files changed

+37
-8
lines changed

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -572,13 +572,17 @@ struct RankReducedExtractSliceOp
572572
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
573573
PatternRewriter &rewriter) const override {
574574
RankedTensorType resultType = sliceOp.getType();
575-
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
576-
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
577-
SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
578-
auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
575+
SmallVector<OpFoldResult> targetShape;
576+
for (auto size : resultType.getShape())
577+
targetShape.push_back(rewriter.getIndexAttr(size));
578+
auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
579579
if (!reassociation ||
580580
reassociation->size() == static_cast<size_t>(resultType.getRank()))
581581
return failure();
582+
583+
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
584+
SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
585+
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
582586
auto rankReducedType = cast<RankedTensorType>(
583587
tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
584588
reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
@@ -602,13 +606,14 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
602606
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
603607
PatternRewriter &rewriter) const override {
604608
RankedTensorType sourceType = insertSliceOp.getSourceType();
605-
SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
606-
SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
607-
SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
608-
auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
609+
SmallVector<OpFoldResult> targetShape;
610+
for (auto size : sourceType.getShape())
611+
targetShape.push_back(rewriter.getIndexAttr(size));
612+
auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
609613
if (!reassociation ||
610614
reassociation->size() == static_cast<size_t>(sourceType.getRank()))
611615
return failure();
616+
612617
Location loc = insertSliceOp.getLoc();
613618
tensor::CollapseShapeOp reshapedSource;
614619
{

mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,18 @@ func.func @slice_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> {
489489

490490
// -----
491491

492+
func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
493+
%0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
494+
return %0 : tensor<1x3x3xf32>
495+
}
496+
// CHECK-LABEL: func @rank_reduced_extract_slice
497+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice
498+
// CHECK-SAME: tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
499+
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
500+
// CHECK: return %[[RESULT]]
501+
502+
// -----
503+
492504
func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
493505
%0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
494506
return %0 : tensor<1x3xf32>
@@ -501,6 +513,18 @@ func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>
501513

502514
// -----
503515

516+
func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
517+
%0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
518+
return %0 : tensor<1x1x3x1x3xf32>
519+
}
520+
// CHECK-LABEL: func @rank_reduced_insert_slice
521+
// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
522+
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
523+
// CHECK-SAME: tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
524+
// CHECK: return %[[RESULT]]
525+
526+
// -----
527+
504528
#accesses = [
505529
affine_map<(i, j, k, l, m) -> (i, k, m)>,
506530
affine_map<(i, j, k, l, m) -> ()>,

0 commit comments

Comments (0)