Commit 2ec98ff

[mlir][vector] Add scalar vector xfer to memref patterns
These patterns devectorize scalar transfers such as vector<f32> or vector<1xf32>.

Differential Revision: https://reviews.llvm.org/D140215
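As a rough sketch of what the read-side pattern does (it mirrors the transfer_read_0d test case added below; %m and %idx are just placeholder SSA names), a scalar vector.transfer_read whose only user is a vector.extractelement becomes a plain memref.load:

  // Input IR:
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
  %1 = vector.extractelement %0[] : vector<f32>

  // After applying the pattern (a tensor source would produce tensor.extract instead):
  %1 = memref.load %m[%idx, %idx, %idx] : memref<?x?x?xf32>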

5 files changed: 214 additions, 0 deletions

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 5 additions & 0 deletions
@@ -142,6 +142,11 @@ void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
 
+/// Collects patterns that lower scalar vector transfer ops to memref loads and
+/// stores when beneficial.
+void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit = 1);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRMemRefDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
+  MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
   MLIRVectorInterfaces

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 104 additions & 0 deletions
@@ -11,8 +11,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"

@@ -556,6 +558,101 @@ class FlattenContiguousRowMajorTransferWritePattern
   }
 };
 
+/// Rewrite extractelement(transfer_read) to memref.load.
+///
+/// Rewrite only if the extractelement op is the single user of the transfer op.
+/// E.g., do not rewrite IR such as:
+/// %0 = vector.transfer_read ... : vector<1024xf32>
+/// %1 = vector.extractelement %0[%a : index] : vector<1024xf32>
+/// %2 = vector.extractelement %0[%b : index] : vector<1024xf32>
+/// Rewriting such IR (replacing one vector load with multiple scalar loads) may
+/// negatively affect performance.
+class FoldScalarExtractOfTransferRead
+    : public OpRewritePattern<vector::ExtractElementOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
+    if (!xferOp)
+      return failure();
+    // xfer result must have a single use. Otherwise, it may be better to
+    // perform a vector load.
+    if (!extractOp.getVector().hasOneUse())
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Cannot rewrite if the indices may be out of bounds. The starting point is
+    // always inbounds, so we don't care in case of 0d transfers.
+    if (xferOp.hasOutOfBoundsDim() && xferOp.getType().getRank() > 0)
+      return failure();
+    // Construct scalar load.
+    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
+                                  xferOp.getIndices().end());
+    if (extractOp.getPosition()) {
+      AffineExpr sym0, sym1;
+      bindSymbols(extractOp.getContext(), sym0, sym1);
+      OpFoldResult ofr = makeComposedFoldedAffineApply(
+          rewriter, extractOp.getLoc(), sym0 + sym1,
+          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
+      if (ofr.is<Value>()) {
+        newIndices[newIndices.size() - 1] = ofr.get<Value>();
+      } else {
+        newIndices[newIndices.size() - 1] =
+            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+                                                    *getConstantIntValue(ofr));
+      }
+    }
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+                                                  newIndices);
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
+          extractOp, xferOp.getSource(), newIndices);
+    }
+    return success();
+  }
+};
+
+/// Rewrite scalar transfer_write(broadcast) to memref.store.
+class FoldScalarTransferWriteOfBroadcast
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override {
+    // Must be a scalar write.
+    auto vecType = xferOp.getVectorType();
+    if (vecType.getRank() != 0 &&
+        (vecType.getRank() != 1 || vecType.getShape()[0] != 1))
+      return failure();
+    // Mask not supported.
+    if (xferOp.getMask())
+      return failure();
+    // Map not supported.
+    if (!xferOp.getPermutationMap().isMinorIdentity())
+      return failure();
+    // Must be a broadcast of a scalar.
+    auto broadcastOp = xferOp.getVector().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp || broadcastOp.getSource().getType().isa<VectorType>())
+      return failure();
+    // Construct a scalar store.
+    if (xferOp.getSource().getType().isa<MemRefType>()) {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    } else {
+      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
+          xferOp, broadcastOp.getSource(), xferOp.getSource(),
+          xferOp.getIndices());
+    }
+    return success();
+  }
+};
 } // namespace
 
 void mlir::vector::transferOpflowOpt(Operation *rootOp) {

@@ -574,6 +671,13 @@ void mlir::vector::transferOpflowOpt(Operation *rootOp) {
   opt.removeDeadOp();
 }
 
+void mlir::vector::populateScalarVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns
+      .add<FoldScalarExtractOfTransferRead, FoldScalarTransferWriteOfBroadcast>(
+          patterns.getContext(), benefit);
+}
+
 void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns
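For the write side, the FoldScalarTransferWriteOfBroadcast pattern above performs roughly the following rewrite (a sketch mirroring the transfer_write_0d test case below; %f, %m, and %idx are placeholder SSA names):

  // Input IR:
  %0 = vector.broadcast %f : f32 to vector<f32>
  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>

  // After applying the pattern (a tensor destination would produce tensor.insert instead):
  memref.store %f, %m[%idx, %idx, %idx] : memref<?x?x?xf32>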
New test file

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s -test-scalar-vector-transfer-lowering -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_0d(
+// CHECK-SAME:  %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK:       %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK:       return %[[r]]
+func.func @transfer_read_0d(%m: memref<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst : memref<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @transfer_read_1d(
+// CHECK-SAME:  %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[idx2:.*]]: index
+// CHECK:       %[[added:.*]] = affine.apply #[[$map]]()[%[[idx]], %[[idx2]]]
+// CHECK:       %[[r:.*]] = memref.load %[[m]][%[[idx]], %[[idx]], %[[added]]]
+// CHECK:       return %[[r]]
+func.func @transfer_read_1d(%m: memref<?x?x?xf32>, %idx: index, %idx2: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = vector.transfer_read %m[%idx, %idx, %idx], %cst {in_bounds = [true]} : memref<?x?x?xf32>, vector<5xf32>
+  %1 = vector.extractelement %0[%idx2 : index] : vector<5xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_read_0d(
+// CHECK-SAME:  %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index
+// CHECK:       %[[r:.*]] = tensor.extract %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK:       return %[[r]]
+func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %t[%idx, %idx, %idx], %cst : tensor<?x?x?xf32>, vector<f32>
+  %1 = vector.extractelement %0[] : vector<f32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_0d(
+// CHECK-SAME:  %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK:       memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_1d(
+// CHECK-SAME:  %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK:       memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
+  %0 = vector.broadcast %f : f32 to vector<1xf32>
+  vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<1xf32>, memref<?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_transfer_write_0d(
+// CHECK-SAME:  %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
+// CHECK:       %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK:       return %[[r]]
+func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
+  %0 = vector.broadcast %f : f32 to vector<f32>
+  %1 = vector.transfer_write %0, %t[%idx, %idx, %idx] : vector<f32>, tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 29 additions & 0 deletions
@@ -462,6 +462,33 @@ struct TestVectorTransferFullPartialSplitPatterns
   }
 };
 
+struct TestScalarVectorTransferLoweringPatterns
+    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestScalarVectorTransferLoweringPatterns)
+
+  StringRef getArgument() const final {
+    return "test-scalar-vector-transfer-lowering";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering of scalar vector transfers to memref loads/stores.";
+  }
+  TestScalarVectorTransferLoweringPatterns() = default;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, memref::MemRefDialect, tensor::TensorDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    vector::populateScalarVectorTransferLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorTransferOpt
     : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)

@@ -869,6 +896,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
 
+  PassRegistration<TestScalarVectorTransferLoweringPatterns>();
+
   PassRegistration<TestVectorTransferOpt>();
 
   PassRegistration<TestVectorTransferLoweringPatterns>();
