
Commit 70b95d1

[mlir][linalg] Retain Op Type of linalg ops in fuseWithReshapeByExpansion pattern (llvm#129128)
This PR preserves linalg op types for certain named ops, such as Fill, Copy, and Transpose, instead of fusion always resulting in a generic op. --------- Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
1 parent 6c9a9d9 commit 70b95d1

File tree

2 files changed: +144 −21 lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
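In effect, a reshape fused with one of the retained named ops now keeps its name instead of being rewritten into a linalg.generic. A minimal before/after sketch of the copy case, condensed from the test added below (%sz0 and %sz1 stand in for the tensor.dim/arith.divsi size computation spelled out in the CHECK lines):

// Before fusion: a collapse_shape feeding the named op.
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
    : tensor<?x7x?x8xf32> into tensor<?x?xf32>
%1 = linalg.copy ins(%0 : tensor<?x?xf32>)
                 outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>

// After fusion: still a linalg.copy (previously a linalg.generic),
// now on the expanded shapes, with a collapse back at the end.
%init = tensor.expand_shape %arg1 [[0, 1], [2, 3]] output_shape [%sz0, 7, %sz1, 8]
    : tensor<?x?xf32> into tensor<?x7x?x8xf32>
%2 = linalg.copy ins(%arg0 : tensor<?x7x?x8xf32>)
                 outs(%init : tensor<?x7x?x8xf32>) -> tensor<?x7x?x8xf32>
%3 = tensor.collapse_shape %2 [[0, 1], [2, 3]]
    : tensor<?x7x?x8xf32> into tensor<?x?xf32>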

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 77 additions & 18 deletions
@@ -815,6 +815,77 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
   return success();
 }
 
+// Create an expanded transpose op.
+static Operation *
+createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
+                          SmallVector<ReassociationIndices> reassociation,
+                          Value expandedInput, Value output) {
+  applyPermutationToVector(reassociation, transposeOp.getPermutation());
+  SmallVector<int64_t> newPerm;
+  for (auto reassoc : reassociation) {
+    for (auto dim : reassoc) {
+      newPerm.push_back(dim);
+    }
+  }
+  return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
+                                      output, newPerm);
+}
+
+// Create an expanded generic op.
+static Operation *createExpandedGenericOp(
+    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
+    ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
+    ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
+  // The iterator types of the expanded op are all parallel.
+  SmallVector<utils::IteratorType> iteratorTypes(
+      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
+
+  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
+    for (auto j : expansionInfo.getExpandedDims(i))
+      iteratorTypes[j] = type;
+
+  Operation *fused = rewriter.create<GenericOp>(
+      linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+      expandedOpIndexingMaps, iteratorTypes);
+
+  Region &fusedRegion = fused->getRegion(0);
+  Region &originalRegion = linalgOp->getRegion(0);
+  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
+
+  // Update the index accesses after the expansion.
+  updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
+                                expansionInfo);
+
+  return fused;
+}
+
+// Create an expanded fused op that retains the name for certain ops
+// such as fill, copy, and transpose, and produces a generic op for the
+// rest of the linalg ops.
+static Operation *createExpandedOp(
+    PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
+    ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
+    ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
+    SmallVector<ReassociationIndices> reassociation) {
+
+  return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
+      .Case<TransposeOp>([&](TransposeOp transposeOp) {
+        return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
+                                         expandedOpOperands[0], outputs[0]);
+      })
+      .Case<FillOp, CopyOp>([&](Operation *op) {
+        return clone(rewriter, linalgOp, resultTypes,
+                     llvm::to_vector(llvm::concat<Value>(
+                         llvm::to_vector(expandedOpOperands),
+                         llvm::to_vector(outputs))));
+      })
+      .Default([&](Operation *op) {
+        return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
+                                       expandedOpOperands, outputs,
+                                       expansionInfo, expandedOpIndexingMaps);
+      });
+}
+
 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
 /// that those conditions have been satisfied.
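A note on the new createExpandedTransposeOp (an editorial walkthrough using the values from the transpose test added below, not text from the commit): with reassociation [[0, 1], [2, 3]] and permutation [1, 0], applyPermutationToVector reorders the groups to [[2, 3], [0, 1]], and flattening those groups yields newPerm = [2, 3, 0, 1]. Each collapsed dimension's expanded group moves as a unit; %init below is an illustrative name for the expanded init tensor:

// Collapsed transpose before fusion: a 2-D permutation.
%1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
                      outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]

// Expanded transpose after fusion: group [2, 3] now leads, then [0, 1].
%2 = linalg.transpose ins(%arg0 : tensor<?x7x?x8xf32>)
                      outs(%init : tensor<?x8x?x7xf32>) permutation = [2, 3, 0, 1]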
@@ -919,25 +990,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
     }
   }
 
-  // The iterator types of the expanded op are all parallel.
-  SmallVector<utils::IteratorType> iteratorTypes(
-      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
-  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
-    for (auto j : expansionInfo.getExpandedDims(i))
-      iteratorTypes[j] = type;
-
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  auto fusedOp =
-      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
-                                 /*inputs=*/expandedOpOperands, outputs,
-                                 expandedOpIndexingMaps, iteratorTypes);
-  Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
-  // Update the index accesses after the expansion.
-  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+  SmallVector<ReassociationIndices> reassociationBeforeExpansion =
+      isExpanding ? expandingReshapeOp.getReassociationIndices()
+                  : collapsingReshapeOp.getReassociationIndices();
+  Operation *fusedOp = createExpandedOp(
+      rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
+      expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
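For FillOp and CopyOp, createExpandedOp simply clones the named op onto the expanded operands: both are elementwise with identity-style indexing, so cloning with the expanded operand and result types is enough. The tests below exercise copy and transpose; by analogy, a fill would survive expansion along these lines (a hypothetical sketch, not a test from this commit; %cst and %init are illustrative names):

// The scalar input is untouched by the reshape; only the init operand
// and the result type change to the expanded shape.
%fill = linalg.fill ins(%cst : f32)
                    outs(%init : tensor<?x7x?x8xf32>) -> tensor<?x7x?x8xf32>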

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 67 additions & 3 deletions
@@ -783,9 +783,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                               %arg1 : tensor<?x?xf32>,
                                               %arg2 : tensor<?x?xf32>) ->
@@ -829,6 +826,73 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 
 // -----
 
+func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                               %arg1 : tensor<?x?xf32>) ->
+  tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.copy ins(%0 : tensor<?x?xf32>)
+                   outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_copy_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[C8:.+]] = arith.constant 8 : index
+// CHECK: %[[C7:.+]] = arith.constant 7 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
+// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
+// CHECK: %[[T2:.+]] = linalg.copy
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                                    %arg1 : tensor<?x?xf32>) ->
+  tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+                        outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[T2:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME: permutation = [2, 3, 0, 1]
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+// -----
+
 func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
   %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
   %cst = arith.constant 0 : i32
