@@ -815,6 +815,77 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
815
815
return success ();
816
816
}
817
817
818
+ // Create an expanded transpose op.
819
+ static Operation *
820
+ createExpandedTransposeOp (PatternRewriter &rewriter, TransposeOp transposeOp,
821
+ SmallVector<ReassociationIndices> reassociation,
822
+ Value expandedInput, Value output) {
823
+ applyPermutationToVector (reassociation, transposeOp.getPermutation ());
824
+ SmallVector<int64_t > newPerm;
825
+ for (auto reassoc : reassociation) {
826
+ for (auto dim : reassoc) {
827
+ newPerm.push_back (dim);
828
+ }
829
+ }
830
+ return rewriter.create <TransposeOp>(transposeOp.getLoc (), expandedInput,
831
+ output, newPerm);
832
+ }
833
+
834
+ // Create an expanded generic op.
835
+ static Operation *createExpandedGenericOp (
836
+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
837
+ ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
838
+ ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
839
+ // The iterator types of the expanded op are all parallel.
840
+ SmallVector<utils::IteratorType> iteratorTypes (
841
+ expansionInfo.getExpandedOpNumDims (), utils::IteratorType::parallel);
842
+
843
+ for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
844
+ for (auto j : expansionInfo.getExpandedDims (i))
845
+ iteratorTypes[j] = type;
846
+
847
+ Operation *fused = rewriter.create <GenericOp>(
848
+ linalgOp.getLoc (), resultTypes, expandedOpOperands, outputs,
849
+ expandedOpIndexingMaps, iteratorTypes);
850
+
851
+ Region &fusedRegion = fused->getRegion (0 );
852
+ Region &originalRegion = linalgOp->getRegion (0 );
853
+ rewriter.cloneRegionBefore (originalRegion, fusedRegion, fusedRegion.begin ());
854
+
855
+ // Update the index accesses after the expansion.
856
+ updateExpandedGenericOpRegion (rewriter, linalgOp.getLoc (), fusedRegion,
857
+ expansionInfo);
858
+
859
+ return fused;
860
+ }
861
+
862
+ // Create an expanded fused op that retains the name for certain ops
863
+ // such as fill, copy and transpose and produce a generic op for
864
+ // rest of linalg ops.
865
+ static Operation *createExpandedOp (
866
+ PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
867
+ ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
868
+ ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
869
+ SmallVector<ReassociationIndices> reassociation) {
870
+
871
+ return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation ())
872
+ .Case <TransposeOp>([&](TransposeOp transposeOp) {
873
+ return createExpandedTransposeOp (rewriter, transposeOp, reassociation,
874
+ expandedOpOperands[0 ], outputs[0 ]);
875
+ })
876
+ .Case <FillOp, CopyOp>([&](Operation *op) {
877
+ return clone (rewriter, linalgOp, resultTypes,
878
+ llvm::to_vector (llvm::concat<Value>(
879
+ llvm::to_vector (expandedOpOperands),
880
+ llvm::to_vector (outputs))));
881
+ })
882
+ .Default ([&](Operation *op) {
883
+ return createExpandedGenericOp (rewriter, linalgOp, resultTypes,
884
+ expandedOpOperands, outputs,
885
+ expansionInfo, expandedOpIndexingMaps);
886
+ });
887
+ }
888
+
818
889
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
819
890
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
820
891
/// that those conditions have been satisfied.
@@ -919,25 +990,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
919
990
}
920
991
}
921
992
922
- // The iterator types of the expanded op are all parallel.
923
- SmallVector<utils::IteratorType> iteratorTypes (
924
- expansionInfo.getExpandedOpNumDims (), utils::IteratorType::parallel);
925
- for (auto [i, type] : llvm::enumerate (linalgOp.getIteratorTypesArray ()))
926
- for (auto j : expansionInfo.getExpandedDims (i))
927
- iteratorTypes[j] = type;
928
-
929
993
TypeRange resultTypes = ValueRange (outputs).getTypes ();
930
- auto fusedOp =
931
- rewriter.create <GenericOp>(linalgOp.getLoc (), resultTypes,
932
- /* inputs=*/ expandedOpOperands, outputs,
933
- expandedOpIndexingMaps, iteratorTypes);
934
- Region &fusedRegion = fusedOp->getRegion (0 );
935
- Region &originalRegion = linalgOp->getRegion (0 );
936
- rewriter.cloneRegionBefore (originalRegion, fusedRegion, fusedRegion.begin ());
937
-
938
- // Update the index accesses after the expansion.
939
- updateExpandedGenericOpRegion (rewriter, loc, fusedRegion, expansionInfo);
940
-
994
+ SmallVector<ReassociationIndices> reassociationBeforeExpansion =
995
+ isExpanding ? expandingReshapeOp.getReassociationIndices ()
996
+ : collapsingReshapeOp.getReassociationIndices ();
997
+ Operation *fusedOp = createExpandedOp (
998
+ rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
999
+ expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
941
1000
// Reshape the result values to their original shape if this is a collapsing
942
1001
// reshape folded into its consumer.
943
1002
SmallVector<Value> resultVals;
0 commit comments