Commit 191008b
[mlir] Fix consumer fusion for producer with multiple results
In the case of consumer fusion where the producer produces multiple results, all of which are used by a single consumer, e.g.:

    // Produces 3 results.
    %results:3 = scf.forall ... -> (tensor<...>, tensor<...>, tensor<...>) {
      scf.yield %a, %b, %c : tensor<...>, tensor<...>, tensor<...>
    }
    // Consumer uses all 3 results.
    %final = consumer %results#0, %results#1, %results#2

all other operands of the tiled consumer need to be updated, not just the one tied to the candidate slice.
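For concreteness, here is a condensed sketch of the fused IR this change enables when a consumer uses two results of an scf.forall producer (distilled from the new test added below; `consumer` stands in for any tileable op, and all value names are hypothetical). The fused loop gains one extra shared_out carrying the consumer's init, and the tiled consumer reads the tiled values of both producer results:

    %r:3 = scf.forall (%i, %j) in (2, 2)
        shared_outs(%o0 = %dest, %o1 = %dest, %o2 = %init) -> (...) {
      // Tiled producer computes %tile0 and %tile1.
      %acc = tensor.extract_slice %o2[%i, %j] [32, 32] [1, 1]
      // Fused consumer uses both tiled producer values directly.
      %fused = consumer %tile0, %tile1 outs(%acc)
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %tile0 into %o0[%i, %j] [32, 32] [1, 1]
        tensor.parallel_insert_slice %tile1 into %o1[%i, %j] [32, 32] [1, 1]
        tensor.parallel_insert_slice %fused into %o2[%i, %j] [32, 32] [1, 1]
      }
    }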
Parent: 82c6b8f

2 files changed: +238 additions, -19 deletions

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 110 additions & 15 deletions
@@ -1949,6 +1949,60 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
   }
 }
 
+// If the producer of the operand is a loopLikeOp, then finds the last
+// insertSlice/parallelInsertSlice in the producer op that uses the block
+// argument corresponding to the operand.
+static FailureOr<Operation *>
+getSliceOpFromConsumerOperand(OpOperand &operand) {
+
+  OpResult producerResult = dyn_cast<OpResult>(operand.get());
+  if (!producerResult)
+    return failure();
+
+  LoopLikeOpInterface loopLikeOp =
+      dyn_cast<LoopLikeOpInterface>(producerResult.getOwner());
+  if (!loopLikeOp)
+    return failure();
+
+  // Obtain the BlockArgument corresponding to the result.
+  BlockArgument bbArg =
+      loopLikeOp.getRegionIterArgs()[producerResult.getResultNumber()];
+
+  // Finally return the operation corresponding to the yielded value.
+  // Also check whether it's an InsertSliceOp.
+  if (dyn_cast<scf::ForOp>(producerResult.getOwner())) {
+    OpOperand *yieldVal = loopLikeOp.getTiedLoopYieldedValue(bbArg);
+    Operation *lastOp = dyn_cast<OpResult>(yieldVal->get()).getOwner();
+    auto isInsertSliceOp = isa<tensor::InsertSliceOp>(lastOp);
+    if (!isInsertSliceOp) {
+      return failure();
+    }
+    return lastOp;
+  }
+
+  auto forallOp = dyn_cast<scf::ForallOp>(producerResult.getOwner());
+  if (!forallOp)
+    return failure();
+
+  // Iterate over the terminator operation of the forallOp to find the last
+  // parallelInsertSliceOp that uses the blockArgument.
+  Operation *lastOp = nullptr;
+  forallOp.getTerminator()->walk([&](tensor::ParallelInsertSliceOp op) {
+    for (mlir::Value operand : op->getOperands()) {
+      if (auto maybeBlockArg = dyn_cast<BlockArgument>(operand)) {
+        if (maybeBlockArg == bbArg) {
+          lastOp = op;
+        }
+      }
+    }
+  });
+
+  if (!lastOp)
+    return failure();
+
+  return lastOp;
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
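For reference, a minimal sketch (hypothetical shapes and value names) of the two producer patterns the helper above accepts. For an scf.for producer, the value yielded for the queried result must be produced by a tensor.insert_slice; for an scf.forall producer, the terminator must contain a tensor.parallel_insert_slice that uses the corresponding shared_outs block argument:

    // scf.for: %r is tied to iter_arg %arg; the tied yielded value is the
    // result of a tensor.insert_slice, which the helper returns.
    %r = scf.for %i = %c0 to %c64 step %c4 iter_args(%arg = %init) -> (tensor<64xf32>) {
      %s = tensor.insert_slice %tile into %arg[%i] [32] [1] : tensor<32xf32> into tensor<64xf32>
      scf.yield %s : tensor<64xf32>
    }

    // scf.forall: the last parallel_insert_slice in the terminator whose
    // operands use %arg is returned.
    %r2 = scf.forall (%i) in (2) shared_outs(%arg = %init2) -> (tensor<64x64xf32>) {
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %tile2 into %arg[%i, 0] [32, 64] [1, 1] : tensor<32x64xf32> into tensor<64x64xf32>
      }
    }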
@@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
+  SmallVector<OpOperand *> potentialOperands = {*maybeConsumerOpOperand};
+  SmallVector<unsigned> potentialOperandResultNos = {
+      consumerOpOperand->getOperandNumber()};
+  SmallVector<Operation *> potentialSliceOps = {candidateSliceOp};
+
+  // 1b. Get all the other operands of the consumer op and their corresponding
+  // slice ops. In the case of the consumer using multiple results
+  // from the producer, we need to update every operand.
+  for (OpOperand &otherOperand : consumerOp->getOpOperands()) {
+    if (&otherOperand == *maybeConsumerOpOperand)
+      continue;
+    auto maybePotentialSlice = getSliceOpFromConsumerOperand(otherOperand);
+    if (failed(maybePotentialSlice)) {
+      continue;
+    }
+    potentialSliceOps.push_back(*maybePotentialSlice);
+    potentialOperands.push_back(&otherOperand);
+    potentialOperandResultNos.push_back(otherOperand.getOperandNumber());
+  }
+
   // There are two possible cases regarding `oldLoopOp` here:
   // 1. single `scf.forall` or `scf.for`.
   // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
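Applied to the motivating example, this collection step behaves as follows (a sketch; operand numbers index the consumer's full operand list):

    %1:2 = scf.forall ... shared_outs(%o0 = %d, %o1 = %d) -> (tensor<64x64xf32>, tensor<64x64xf32>) { ... }
    %2 = linalg.elemwise_binary ins(%1#0, %1#1 : ...) outs(%init : ...) -> tensor<64x64xf32>

If the candidate slice feeds %1#0, the vectors are seeded with consumer operand #0 and candidateSliceOp; the loop then resolves operand #1 through getSliceOpFromConsumerOperand to the parallel_insert_slice producing %1#1, recording both operands and their operand numbers for the rewiring in steps 5 and 6 below.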
@@ -2037,43 +2111,64 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
-  tensor::InsertSliceOp clonedInsertSliceOp;
+
+  SmallVector<tensor::InsertSliceOp> allClonedInsertSliceOps;
+
+  scf::ForallOp newForallOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
-    rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+    rewriter.setInsertionPoint(potentialSliceOps.back());
+  }
+
+  for (auto *candidateSliceOp : potentialSliceOps) {
+    if (auto sliceOp =
+            dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+      allClonedInsertSliceOps.push_back(rewriter.create<tensor::InsertSliceOp>(
+          loc, sliceOp.getSource(), sliceOp.getDest(),
+          sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+          sliceOp.getMixedStrides()));
+    } else {
+      allClonedInsertSliceOps.push_back(
+          cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)));
+    }
   }
 
   // 5.a. Clone consumer op.
   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
 
   // 5.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
-  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(clonedInsertSliceOp.getResult());
-  });
+  for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+    OpOperand &operandToReplace =
+        clonedConsumerOp->getOpOperand(potentialOperandResultNos[it.index()]);
+    rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+      operandToReplace.set(it.value().getResult());
+    });
+  }
 
   // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(
+      allClonedInsertSliceOps.front().getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceInsertSliceWithTiledConsumer(
           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+
   if (failed(tileAndFuseResult)) {
     return failure();
   }
+
   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
-  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
-                              clonedInsertSliceOp.getSource());
+
+  // 6b. Update the tiled consumer op with the new operands.
+  for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+    rewriter.replaceAllUsesWith(
+        tiledConsumerOp->getOperand(potentialOperandResultNos[it.index()]),
+        it.value().getSource());
+  }
 
   // 7. Reconstruct [nested] loop with new inits.
   YieldTiledValuesFn newYieldValuesFn =
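The net effect of steps 5 and 6 on the clones, sketched for a two-result producer (hypothetical names; compare the CHECK lines of the new tests below): each recorded slice op is first materialized as a tensor.insert_slice feeding the cloned consumer, and after tiling each rewired operand is replaced by the corresponding insert_slice source:

    // Before tiling: the cloned consumer reads the cloned insert_slices.
    %ins0 = tensor.insert_slice %tile0 into %o0[%i, %j] [32, 32] [1, 1]
    %ins1 = tensor.insert_slice %tile1 into %o1[%i, %j] [32, 32] [1, 1]
    %cloned = linalg.elemwise_binary ins(%ins0, %ins1 : ...) outs(%init : ...)

    // After step 6b: the tiled consumer reads the tiled values directly.
    %tiled = linalg.elemwise_binary ins(%tile0, %tile1 : ...) outs(%acc_slice : ...)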

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 128 additions & 4 deletions
@@ -282,7 +282,7 @@ module {
     return %unpack : tensor<2048xf32>
   }
 }
-
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -343,7 +343,7 @@ module {
    return %unpack : tensor<2047xf32>
   }
 }
-
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -404,7 +404,7 @@ module {
    return %pack : tensor<4x32x16xf32>
   }
 }
-
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
     %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
 // CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME:   iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME:   iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
 // CHECK-SAME:   {
 // CHECK:   %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
 // CHECK:   %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
 // CHECK: }
 // CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
 // CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+
+module {
+  func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %outs = tensor.empty() : tensor<32x32xf32>
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %final_out = tensor.empty() : tensor<64x64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>
+    return %2 : tensor<64x64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>
+
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])
+
+// CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)
+
+// CHECK: scf.forall.in_parallel {
+// CHECK:   tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:   tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: }
+
+// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>
+
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %out_operand = tensor.empty() : tensor<64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>
+
+// CHECK: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[C64:.+]] = arith.constant 64 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64xf32>
+
+// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
+// CHECK-SAME:   iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
+// CHECK-SAME:   -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:   ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME:   outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
+// CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:   %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
+// CHECK:   %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK:   linalg.yield %[[ADD]] : f32
+
+// CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
+// CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
+// CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:   ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME:   outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
+// CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]
+
+// CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]
+
+// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>
