Skip to content

Commit 54ae9e7

Browse files
authored
[mlir][SCF] Fix condition for fusability in consumer fusion API (#115768)
It was previously allowing either a tilable or dps op to be fused. Both are required for consumer fusion.
1 parent f1800df commit 54ae9e7

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,7 +1710,8 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
17101710
for (OpOperand &opOperand : val.getUses()) {
17111711
Operation *consumerOp = opOperand.getOwner();
17121712
// Step 1. Check if the user is tilable.
1713-
if (!isa<TilingInterface, DestinationStyleOpInterface>(consumerOp)) {
1713+
if (!isa<TilingInterface>(consumerOp) ||
1714+
!isa<DestinationStyleOpInterface>(consumerOp)) {
17141715
// TODO: We have to init result of consumer before scf.for, use
17151716
// DestinationStyleOpInterface to get result shape from init for now. Add
17161717
// support for other op such as op has InferTypeOpInterface.

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
570570
// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
571571
// CHECK: }
572572
// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
573+
574+
// -----
575+
576+
module {
577+
func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
578+
%c0 = arith.constant 0 : index
579+
%c64 = arith.constant 64 : index
580+
%c256 = arith.constant 256 : index
581+
%cst = arith.constant 0.000000e+00 : f32
582+
%dest0 = tensor.empty() : tensor<256x256xf32>
583+
%1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
584+
%extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
585+
%extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
586+
%extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
587+
%3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
588+
%insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
589+
scf.yield %insert_slice : tensor<256x256xf32>
590+
}
591+
%dest1 = tensor.empty() : tensor<258x258xf32>
592+
%4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
593+
%5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
594+
return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
595+
}
596+
}
597+
598+
module attributes {transform.with_named_sequence} {
599+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
600+
%slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
601+
: (!transform.any_op) -> !transform.any_op
602+
%slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
603+
%a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
604+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
605+
transform.yield
606+
}
607+
}
608+
// CHECK: func.func @no_fuse_only_dps_consumer(
609+
// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
610+
// CHECK: linalg.add
611+
// CHECK: linalg.mul
612+
// CHECK: scf.yield
613+
// CHECK: }
614+
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
615+
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]

0 commit comments

Comments
 (0)