@@ -570,3 +570,46 @@ module attributes {transform.with_named_sequence} {
570
570
// CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
571
571
// CHECK: }
572
572
// CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
573
+
574
+ // -----
575
+
576
// Test input: a tiled `linalg.add` loop whose result %1 has two users after
// the loop: a `tensor.insert_slice` into a larger 258x258 destination and a
// destination-passing-style `linalg.mul`. The CHECK lines for this case expect
// `linalg.mul` to end up inside the `scf.for` (which then carries two results)
// while a `tensor.insert_slice` still appears after the loop.
module {
  func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c256 = arith.constant 256 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    // Producer loop: iterates rows 0..256 in steps of 64, adding one 64x256
    // tile of %arg0 and %arg1 per iteration into the loop-carried tensor.
    %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
      %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
      %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
      // This in-loop insert_slice is one of the two `tensor.insert_slice` ops
      // in the function that the transform script's match can select.
      %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
      scf.yield %insert_slice : tensor<256x256xf32>
    }
    %dest1 = tensor.empty() : tensor<258x258xf32>
    // First user of the loop result: a plain insert_slice into a larger tensor.
    %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
    // Second user of the loop result: a DPS op taking %1 as an input operand.
    %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
  }
}
597
+
598
// Transform script for the test case: match the `tensor.insert_slice` ops in
// the payload, split the handle in two, and request consumer fusion starting
// from the first one, limited to a single consumer.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
    // The op name must be spelled exactly; a padded name such as
    // " tensor.insert_slice" does not name any op and would match nothing.
    %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // Two results are requested, so the match is expected to yield exactly two
    // payload ops. NOTE(review): presumably %slice_op is the in-loop
    // insert_slice and %other_slice the one after the loop - confirm the
    // match ordering.
    %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // Fuse at most one consumer; %a and %b are not used further in this script.
    %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}
608
// CHECK: func.func @no_fuse_only_dps_consumer(
// CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
// CHECK: linalg.add
// CHECK: linalg.mul
// CHECK: scf.yield
// CHECK: }
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
0 commit comments