 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -556,6 +558,101 @@ class FlattenContiguousRowMajorTransferWritePattern
   }
 };
 
+/// Rewrite extractelement(transfer_read) to memref.load.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
+/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class FoldScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    if (!xferOp)
+      return failure();
+    // xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting point is
+    // always inbounds, so we don't care in case of 0d transfers.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    if (extractOp.getPosition()) {
+      AffineExpr sym0, sym1;
+      bindSymbols(extractOp.getContext(), sym0, sym1);
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), sym0 + sym1,
+          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
+      if (ofr.is<Value>()) {
+        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      } else {
+        newIndices[newIndices.size() - 1] =
+            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+                                                    *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+                                                  newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite scalar transfer_write(broadcast) to memref.store.
+class FoldScalarTransferWriteOfBroadcast
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Must be a scalar write.
+    auto vecType = xferOp.getVectorType();
+    if (vecType.getRank() != 0 &&
+        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Must be a broadcast of a scalar.
+    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
+      return failure();
+    // Construct a scalar store.
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    }
+    return success();
+  }
+};
 } // namespace
 
 void mlir::vector::transferOpflowOpt(Operation *rootOp) {
@@ -574,6 +671,13 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
   opt.removeDeadOp();
 }
 
+void mlir::vector::populateScalarVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns
+      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
+          patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
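
Usage note (not part of the commit): the new populateScalarVectorTransferLoweringPatterns entry point is driven like any other pattern set. Below is a minimal sketch assuming the standard greedy rewrite driver; the wrapper function runScalarXferLowering and the choice of benefit are illustrative, not taken from this patch.

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: collect the two new patterns and apply them greedily.
// After this runs, extractelement(transfer_read) becomes memref.load (or
// tensor.extract), and a scalar transfer_write(broadcast) becomes memref.store
// (or tensor.insert), subject to the single-use, mask, permutation-map, and
// in-bounds checks in the patterns above.
static void runScalarXferLowering(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  mlir::vector::populateScalarVectorTransferLoweringPatterns(patterns,
                                                             /*benefit=*/1);
  (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}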