@@ -282,7 +282,7 @@ module {
282
282
return %unpack : tensor <2048 xf32 >
283
283
}
284
284
}
285
-
285
+
286
286
module attributes {transform.with_named_sequence } {
287
287
transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
288
288
%slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
@@ -343,7 +343,7 @@ module {
343
343
return %unpack : tensor <2047 xf32 >
344
344
}
345
345
}
346
-
346
+
347
347
module attributes {transform.with_named_sequence } {
348
348
transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
349
349
%slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
@@ -404,7 +404,7 @@ module {
404
404
return %pack : tensor <4 x32 x16 xf32 >
405
405
}
406
406
}
407
-
407
+
408
408
module attributes {transform.with_named_sequence } {
409
409
transform.named_sequence @__transform_main (%arg1 : !transform.any_op {transform.readonly }) {
410
410
%slice_op = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
610
610
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
611
611
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
612
612
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
613
- // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
613
+ // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
614
614
// CHECK-SAME: {
615
615
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
616
616
// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
676
676
// CHECK: }
677
677
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
678
678
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
679
+
680
+ // -----
681
+
682
+ module {  // Test input: an scf.forall with two results, both read by one linalg.elemwise_binary.
683
+   func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
684
+     %c4 = arith.constant 4 : index
685
+     %c64 = arith.constant 64 : index
686
+     %c0 = arith.constant 0 : index
687
+     %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {  // both shared outputs start from %arg2
688
+       %outs = tensor.empty() : tensor<32x32xf32>
689
+       %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
690
+       %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
691
+       scf.forall.in_parallel {
692
+         tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>  // matmul tile -> second result
693
+         tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>  // input tile -> first result
694
+       }
695
+     }
696
+     %final_out = tensor.empty() : tensor<64x64xf32>
697
+     %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>  // single consumer of both loop results
698
+     return %2 : tensor<64x64xf32>
699
+   }
700
+ }
701
+
702
+ module attributes {transform.with_named_sequence} {  // Transform script for the test case above.
703
+   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
704
+     %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op  // op name must have no padding inside the quotes
705
+     %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // one handle per matched op
706
+     %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // fusion is requested on the first handle only
707
+     transform.yield
708
+   }
709
+ }
710
+
711
+ // CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
712
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>
713
+
714
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
715
+ // CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])
716
+
717
+ // CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
718
+ // CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
719
+ // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
720
+ // CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
721
+ // CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
722
+ // CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
723
+ // CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)
724
+
725
+ // CHECK: scf.forall.in_parallel {
726
+ // CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
727
+ // CHECK: tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
728
+ // CHECK: tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
729
+ // CHECK: }
730
+
731
+ // CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>
732
+
733
+
734
+ // -----
735
+
736
+ #map = affine_map<(d0) -> (d0)>
737
+ module {  // Test input: an scf.for with two results, both read by one linalg.elemwise_binary.
738
+   func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
739
+     %c4 = arith.constant 4 : index
740
+     %c64 = arith.constant 64 : index
741
+     %c0 = arith.constant 0 : index
742
+     %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {  // both iter_args start from %arg2
743
+       %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
744
+       %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
745
+       ^bb0(%in: f32, %in_16: f32, %out: f32):
746
+         %13 = arith.mulf %in, %in_16 : f32
747
+         %14 = arith.addf %out, %13 : f32
748
+         linalg.yield %14 : f32
749
+       } -> tensor<32xf32>
750
+       %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
751
+       %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
752
+       scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>  // yielded in swapped order relative to the iter_args they were inserted into
753
+     }
754
+     %out_operand = tensor.empty() : tensor<64xf32>
755
+     %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>  // single consumer of both loop results
756
+     return %2 : tensor<64xf32>
757
+   }
758
+ }
759
+
760
+ module attributes {transform.with_named_sequence} {  // Transform script for the test case above.
761
+   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
762
+     %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op  // op name must have no padding inside the quotes
763
+     %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // one handle per matched op
764
+     %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)  // fusion is requested on the first handle only
765
+     transform.yield
766
+   }
767
+ }
768
+
769
+ // CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
770
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
771
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
772
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>
773
+
774
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
775
+ // CHECK: %[[C64:.+]] = arith.constant 64 : index
776
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
777
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64xf32>
778
+
779
+ // CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
780
+ // CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
781
+ // CHECK-SAME: -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
782
+
783
+ // CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
784
+ // CHECK: %[[GENERIC:.+]] = linalg.generic
785
+ // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
786
+ // CHECK-SAME: outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
787
+ // CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
788
+ // CHECK: %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
789
+ // CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
790
+ // CHECK: linalg.yield %[[ADD]] : f32
791
+
792
+ // CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
793
+ // CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
794
+ // CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
795
+ // CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
796
+ // CHECK-SAME: ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
797
+ // CHECK-SAME: outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
798
+ // CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]
799
+
800
+ // CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]
801
+
802
+ // CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>
0 commit comments