15
15
#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16
16
#include " mlir/Dialect/SCF/IR/SCF.h"
17
17
#include " mlir/Dialect/Vector/IR/VectorOps.h"
18
+ #include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18
19
#include " mlir/IR/BuiltinTypes.h"
19
20
#include " mlir/IR/OpDefinition.h"
20
21
#include " mlir/IR/PatternMatch.h"
@@ -52,13 +53,25 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
52
53
}
53
54
54
55
static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
55
- vector::MaskedLoadOp maskedOp) {
56
+ vector::MaskedLoadOp maskedOp,
57
+ bool passthru) {
56
58
VectorType vectorType = maskedOp.getVectorType ();
57
59
Value load = builder.create <vector::LoadOp>(
58
60
loc, vectorType, maskedOp.getBase (), maskedOp.getIndices ());
59
- Value res = builder.create <arith::SelectOp>(
60
- loc, vectorType, maskedOp.getMask (), load, maskedOp.getPassThru ());
61
- return res;
61
+ if (passthru)
62
+ load = builder.create <arith::SelectOp>(loc, vectorType, maskedOp.getMask (),
63
+ load, maskedOp.getPassThru ());
64
+ return load;
65
+ }
66
+
67
+ // / Check if the given value comes from a broadcasted i1 condition.
68
+ static FailureOr<Value> matchFullMask (OpBuilder &b, Value val) {
69
+ auto broadcastOp = val.getDefiningOp <vector::BroadcastOp>();
70
+ if (!broadcastOp)
71
+ return failure ();
72
+ if (isa<VectorType>(broadcastOp.getSourceType ()))
73
+ return failure ();
74
+ return broadcastOp.getSource ();
62
75
}
63
76
64
77
static constexpr char kMaskedloadNeedsMask [] =
@@ -78,6 +91,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
78
91
return failure ();
79
92
}
80
93
94
+ // Check if this is either a full inbounds load or an empty, oob load. If
95
+ // so, take the fast path and don't generate an if condition, because we
96
+ // know doing the oob load is always safe.
97
+ if (succeeded (matchFullMask (rewriter, maskedOp.getMask ()))) {
98
+ Value load = createVectorLoadForMaskedLoad (rewriter, maskedOp.getLoc (),
99
+ maskedOp, /* passthru=*/ true );
100
+ rewriter.replaceOp (maskedOp, load);
101
+ return success ();
102
+ }
103
+
81
104
Location loc = maskedOp.getLoc ();
82
105
Value src = maskedOp.getBase ();
83
106
@@ -135,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
135
158
};
136
159
137
160
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
138
- Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp);
161
+ Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp,
162
+ /* passthru=*/ true );
139
163
rewriter.create <scf::YieldOp>(loc, res);
140
164
};
141
165
@@ -148,11 +172,63 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
148
172
}
149
173
};
150
174
175
+ struct FullMaskedLoadToConditionalLoad
176
+ : OpRewritePattern<vector::MaskedLoadOp> {
177
+ using OpRewritePattern::OpRewritePattern;
178
+
179
+ LogicalResult matchAndRewrite (vector::MaskedLoadOp loadOp,
180
+ PatternRewriter &rewriter) const override {
181
+ FailureOr<Value> maybeCond = matchFullMask (rewriter, loadOp.getMask ());
182
+ if (failed (maybeCond)) {
183
+ return failure ();
184
+ }
185
+
186
+ Value cond = maybeCond.value ();
187
+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
188
+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp,
189
+ /* passthru=*/ false );
190
+ rewriter.create <scf::YieldOp>(loc, res);
191
+ };
192
+ auto falseBuilder = [&](OpBuilder &builder, Location loc) {
193
+ rewriter.create <scf::YieldOp>(loc, loadOp.getPassThru ());
194
+ };
195
+ auto ifOp = rewriter.create <scf::IfOp>(loadOp.getLoc (), cond, trueBuilder,
196
+ falseBuilder);
197
+ rewriter.replaceOp (loadOp, ifOp);
198
+ return success ();
199
+ }
200
+ };
201
+
202
+ struct FullMaskedStoreToConditionalStore
203
+ : OpRewritePattern<vector::MaskedStoreOp> {
204
+ using OpRewritePattern::OpRewritePattern;
205
+
206
+ LogicalResult matchAndRewrite (vector::MaskedStoreOp storeOp,
207
+ PatternRewriter &rewriter) const override {
208
+ FailureOr<Value> maybeCond = matchFullMask (rewriter, storeOp.getMask ());
209
+ if (failed (maybeCond)) {
210
+ return failure ();
211
+ }
212
+ Value cond = maybeCond.value ();
213
+
214
+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
215
+ rewriter.create <vector::StoreOp>(loc, storeOp.getValueToStore (),
216
+ storeOp.getBase (), storeOp.getIndices ());
217
+ rewriter.create <scf::YieldOp>(loc);
218
+ };
219
+ auto ifOp = rewriter.create <scf::IfOp>(storeOp.getLoc (), cond, trueBuilder);
220
+ rewriter.replaceOp (storeOp, ifOp);
221
+ return success ();
222
+ }
223
+ };
224
+
151
225
} // namespace
152
226
153
227
void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns (
154
228
RewritePatternSet &patterns, PatternBenefit benefit) {
155
- patterns.add <MaskedLoadLowering>(patterns.getContext (), benefit);
229
+ patterns.add <MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
230
+ FullMaskedStoreToConditionalStore>(patterns.getContext (),
231
+ benefit);
156
232
}
157
233
158
234
struct AmdgpuMaskedloadToLoadPass final
0 commit comments