@@ -767,23 +767,26 @@ class FlattenContiguousRowMajorTransferWritePattern
767
767
unsigned targetVectorBitwidth;
768
768
};
769
769
770
- /// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`
771
- /// to `memref.load` patterns. The `match` method is shared for both
772
- /// `vector.extract` and `vector.extract_element`.
773
- template <class VectorExtractOp >
774
- class RewriteScalarExtractOfTransferReadBase
775
- : public OpRewritePattern<VectorExtractOp> {
776
- using Base = OpRewritePattern<VectorExtractOp>;
777
-
770
+ /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
771
+ ///
772
+ /// All the users of the transfer op must be `vector.extract` ops. If
773
+ /// `allowMultipleUses` is set to true, rewrite transfer ops with any number of
774
+ /// users. Otherwise, rewrite only if the extract op is the single user of the
775
+ /// transfer op. Rewriting a single vector load with multiple scalar loads may
776
+ /// negatively affect performance.
777
+ class RewriteScalarExtractOfTransferRead
778
+ : public OpRewritePattern<vector::ExtractOp> {
778
779
public:
779
- RewriteScalarExtractOfTransferReadBase (MLIRContext *context,
780
- PatternBenefit benefit,
781
- bool allowMultipleUses)
782
- : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
783
-
784
- LogicalResult match (VectorExtractOp extractOp) const {
785
- auto xferOp =
786
- extractOp.getVector ().template getDefiningOp <vector::TransferReadOp>();
780
+ RewriteScalarExtractOfTransferRead (MLIRContext *context,
781
+ PatternBenefit benefit,
782
+ bool allowMultipleUses)
783
+ : OpRewritePattern(context, benefit),
784
+ allowMultipleUses (allowMultipleUses) {}
785
+
786
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
787
+ PatternRewriter &rewriter) const override {
788
+ // Match phase.
789
+ auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
787
790
if (!xferOp)
788
791
return failure ();
789
792
// Check that we are extracting a scalar and not a sub-vector.
@@ -795,8 +798,7 @@ class RewriteScalarExtractOfTransferReadBase
795
798
// If multiple uses are allowed, check if all the xfer uses are extract ops.
796
799
if (allowMultipleUses &&
797
800
!llvm::all_of (xferOp->getUses (), [](OpOperand &use) {
798
- return isa<vector::ExtractOp, vector::ExtractElementOp>(
799
- use.getOwner ());
801
+ return isa<vector::ExtractOp>(use.getOwner ());
800
802
}))
801
803
return failure ();
802
804
// Mask not supported.
@@ -808,81 +810,8 @@ class RewriteScalarExtractOfTransferReadBase
808
810
// Cannot rewrite if the indices may be out of bounds.
809
811
if (xferOp.hasOutOfBoundsDim ())
810
812
return failure ();
811
- return success ();
812
- }
813
-
814
- private:
815
- bool allowMultipleUses;
816
- };
817
-
818
- /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
819
- ///
820
- /// All the users of the transfer op must be either `vector.extractelement` or
821
- /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
822
- /// transfer ops with any number of users. Otherwise, rewrite only if the
823
- /// extract op is the single user of the transfer op. Rewriting a single
824
- /// vector load with multiple scalar loads may negatively affect performance.
825
- class RewriteScalarExtractElementOfTransferRead
826
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
827
- using RewriteScalarExtractOfTransferReadBase::
828
- RewriteScalarExtractOfTransferReadBase;
829
-
830
- LogicalResult matchAndRewrite (vector::ExtractElementOp extractOp,
831
- PatternRewriter &rewriter) const override {
832
- if (failed (match (extractOp)))
833
- return failure ();
834
-
835
- // Construct scalar load.
836
- auto loc = extractOp.getLoc ();
837
- auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
838
- SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
839
- xferOp.getIndices ().end ());
840
- if (extractOp.getPosition ()) {
841
- AffineExpr sym0, sym1;
842
- bindSymbols (extractOp.getContext (), sym0, sym1);
843
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
844
- rewriter, loc, sym0 + sym1,
845
- {newIndices[newIndices.size () - 1 ], extractOp.getPosition ()});
846
- if (auto value = dyn_cast<Value>(ofr)) {
847
- newIndices[newIndices.size () - 1 ] = value;
848
- } else {
849
- newIndices[newIndices.size () - 1 ] =
850
- rewriter.create <arith::ConstantIndexOp>(loc,
851
- *getConstantIntValue (ofr));
852
- }
853
- }
854
- if (isa<MemRefType>(xferOp.getBase ().getType ())) {
855
- rewriter.replaceOpWithNewOp <memref::LoadOp>(extractOp, xferOp.getBase (),
856
- newIndices);
857
- } else {
858
- rewriter.replaceOpWithNewOp <tensor::ExtractOp>(
859
- extractOp, xferOp.getBase (), newIndices);
860
- }
861
-
862
- return success ();
863
- }
864
- };
865
-
866
- /// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
867
- /// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
868
- ///
869
- /// All the users of the transfer op must be either `vector.extractelement` or
870
- /// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
871
- /// transfer ops with any number of users. Otherwise, rewrite only if the
872
- /// extract op is the single user of the transfer op. Rewriting a single
873
- /// vector load with multiple scalar loads may negatively affect performance.
874
- class RewriteScalarExtractOfTransferRead
875
- : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
876
- using RewriteScalarExtractOfTransferReadBase::
877
- RewriteScalarExtractOfTransferReadBase;
878
-
879
- LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
880
- PatternRewriter &rewriter) const override {
881
- if (failed (match (extractOp)))
882
- return failure ();
883
813
884
- // Construct scalar load.
885
- auto xferOp = extractOp.getVector ().getDefiningOp <vector::TransferReadOp>();
814
+ // Rewrite phase: construct scalar load.
886
815
SmallVector<Value> newIndices (xferOp.getIndices ().begin (),
887
816
xferOp.getIndices ().end ());
888
817
for (auto [i, pos] : llvm::enumerate (extractOp.getMixedPosition ())) {
@@ -923,6 +852,9 @@ class RewriteScalarExtractOfTransferRead
923
852
924
853
return success ();
925
854
}
855
+
856
+ private:
857
+ bool allowMultipleUses;
926
858
};
927
859
928
860
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
@@ -979,8 +911,7 @@ void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
979
911
/// Adds the scalar vector-transfer lowering patterns to `patterns`.
///
/// As shown by the `-`/`+` lines below, this change drops the separate
/// `RewriteScalarExtractElementOfTransferRead` registration: only
/// `RewriteScalarExtractOfTransferRead` (constructed with the context,
/// `benefit` and `allowMultipleUses`) and `RewriteScalarWrite` (constructed
/// with the context and `benefit`) are added.
void mlir::vector::populateScalarVectorTransferLoweringPatterns (
980
912
RewritePatternSet &patterns, PatternBenefit benefit,
981
913
bool allowMultipleUses) {
982
- patterns.add <RewriteScalarExtractElementOfTransferRead,
983
- RewriteScalarExtractOfTransferRead>(patterns.getContext (),
914
+ patterns.add <RewriteScalarExtractOfTransferRead>(patterns.getContext (),
984
915
benefit, allowMultipleUses);
985
916
patterns.add <RewriteScalarWrite>(patterns.getContext (), benefit);
986
917
}
0 commit comments