@@ -775,23 +775,26 @@ class FlattenContiguousRowMajorTransferWritePattern
775
775
unsigned targetVectorBitwidth;
776
776
};
777
777
778
- // / Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
779
- // / to `memref.load` patterns. The `match` method is shared for both
780
- // / `vector.extract` and `vector.extract_element`.
781
- template <class VectorExtractOp >
782
- class RewriteScalarExtractOfTransferReadBase
783
- : public OpRewritePattern<VectorExtractOp> {
784
- using Base = OpRewritePattern<VectorExtractOp>;
785
-
778
+ // / Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
779
+ // /
780
+ // / All the users of the transfer op must be `vector.extract` ops. If
781
+ // / `allowMultipleUses` is set to true, rewrite transfer ops with any number of
782
+ // / users. Otherwise, rewrite only if the extract op is the single user of the
783
+ // / transfer op. Rewriting a single vector load with multiple scalar loads may
784
+ // / negatively affect performance.
785
+ class RewriteScalarExtractOfTransferRead
786
+ : public OpRewritePattern<vector::ExtractOp> {
786
787
public:
787
- RewriteScalarExtractOfTransferReadBase (MLIRContext *context,
788
- PatternBenefit benefit,
789
- bool allowMultipleUses)
790
- : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
791
-
792
- LogicalResult match (VectorExtractOp extractOp) const {
793
- auto xferOp =
794
- extractOp.getVector ().template getDefiningOp <vector::TransferReadOp>();
788
+ RewriteScalarExtractOfTransferRead (MLIRContext *context,
789
+ PatternBenefit benefit,
790
+ bool allowMultipleUses)
791
+ : OpRewritePattern(context, benefit),
792
+ allowMultipleUses (allowMultipleUses) {}
793
+
794
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
795
+ PatternRewriter &rewriter) const override {
796
+ // Match phase.
797
+ auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
795
798
if (!xferOp)
796
799
return failure ();
797
800
// Check that we are extracting a scalar and not a sub-vector.
@@ -803,8 +806,7 @@ class RewriteScalarExtractOfTransferReadBase
803
806
// If multiple uses are allowed, check if all the xfer uses are extract ops.
804
807
if (allowMultipleUses &&
805
808
!llvm::all_of (xferOp->getUses (), [](OpOperand &use) {
806
- return isa<vector::ExtractOp, vector::ExtractElementOp>(
807
- use.getOwner ());
809
+ return isa<vector::ExtractOp>(use.getOwner ());
808
810
}))
809
811
return failure ();
810
812
// Mask not supported.
@@ -816,81 +818,8 @@ class RewriteScalarExtractOfTransferReadBase
816
818
// Cannot rewrite if the indices may be out of bounds.
817
819
if (xferOp.hasOutOfBoundsDim ())
818
820
return failure ();
819
- return success ();
820
- }
821
-
822
- private:
823
- bool allowMultipleUses;
824
- };
825
-
826
- // / Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
827
- // /
828
- // / All the users of the transfer op must be either `vector.extractelement` or
829
- // / `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
830
- // / transfer ops with any number of users. Otherwise, rewrite only if the
831
- // / extract op is the single user of the transfer op. Rewriting a single
832
- // / vector load with multiple scalar loads may negatively affect performance.
833
- class RewriteScalarExtractElementOfTransferRead
834
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
835
- using RewriteScalarExtractOfTransferReadBase::
836
- RewriteScalarExtractOfTransferReadBase;
837
-
838
- LogicalResult matchAndRewrite (vector::ExtractElementOp extractOp,
839
- PatternRewriter &rewriter) const override {
840
- if (failed (match (extractOp)))
841
- return failure ();
842
-
843
- // Construct scalar load.
844
- auto loc = extractOp.getLoc ();
845
- auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
846
- SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
847
- xferOp.getIndices ().end ());
848
- if (extractOp.getPosition ()) {
849
- AffineExpr sym0, sym1;
850
- bindSymbols (extractOp.getContext (), sym0, sym1);
851
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
852
- rewriter, loc, sym0 + sym1,
853
- {newIndices[newIndices.size () - 1 ], extractOp.getPosition ()});
854
- if (auto value = dyn_cast<Value>(ofr)) {
855
- newIndices[newIndices.size () - 1 ] = value;
856
- } else {
857
- newIndices[newIndices.size () - 1 ] =
858
- rewriter.create <arith::ConstantIndexOp>(loc,
859
- *getConstantIntValue (ofr));
860
- }
861
- }
862
- if (isa<MemRefType>(xferOp.getBase ().getType ())) {
863
- rewriter.replaceOpWithNewOp <memref::LoadOp>(extractOp, xferOp.getBase (),
864
- newIndices);
865
- } else {
866
- rewriter.replaceOpWithNewOp <tensor::ExtractOp>(
867
- extractOp, xferOp.getBase (), newIndices);
868
- }
869
-
870
- return success ();
871
- }
872
- };
873
-
874
- // / Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
875
- // / Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
876
- // /
877
- // / All the users of the transfer op must be either `vector.extractelement` or
878
- // / `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
879
- // / transfer ops with any number of users. Otherwise, rewrite only if the
880
- // / extract op is the single user of the transfer op. Rewriting a single
881
- // / vector load with multiple scalar loads may negatively affect performance.
882
- class RewriteScalarExtractOfTransferRead
883
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
884
- using RewriteScalarExtractOfTransferReadBase::
885
- RewriteScalarExtractOfTransferReadBase;
886
-
887
- LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
888
- PatternRewriter &rewriter) const override {
889
- if (failed (match (extractOp)))
890
- return failure ();
891
821
892
- // Construct scalar load.
893
- auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
822
+ // Rewrite phase: construct scalar load.
894
823
SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
895
824
xferOp.getIndices ().end ());
896
825
for (auto [i, pos] : llvm::enumerate (extractOp.getMixedPosition ())) {
@@ -931,6 +860,9 @@ class RewriteScalarExtractOfTransferRead
931
860
932
861
return success ();
933
862
}
863
+
864
+ private:
865
+ bool allowMultipleUses;
934
866
};
935
867
936
868
// / Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -987,8 +919,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
987
919
void mlir::vector::populateScalarVectorTransferLoweringPatterns (
988
920
RewritePatternSet &patterns, PatternBenefit benefit,
989
921
bool allowMultipleUses) {
990
- patterns.add <RewriteScalarExtractElementOfTransferRead,
991
- RewriteScalarExtractOfTransferRead>(patterns.getContext (),
922
+ patterns.add <RewriteScalarExtractOfTransferRead>(patterns.getContext (),
992
923
benefit, allowMultipleUses);
993
924
patterns.add <RewriteScalarWrite>(patterns.getContext (), benefit);
994
925
}
0 commit comments