@@ -61,6 +61,36 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
61
61
return res;
62
62
}
63
63
64
+ // / Check if the given value comes from a:
65
+ // /
66
+ // / arith.select %cond, TRUE/FALSE, TRUE/FALSE
67
+ // /
68
+ // / i.e the condition is either always true or it's always false.
69
+ // /
70
+ // / Returns the condition to use for scf.if (condition) { true } else { false }.
71
+ static FailureOr<Value> matchFullSelect (OpBuilder &b, Value val) {
72
+ auto selectOp = val.getDefiningOp <arith::SelectOp>();
73
+ if (!selectOp)
74
+ return failure ();
75
+ std::optional<int64_t > trueInt = getConstantIntValue (selectOp.getTrueValue ());
76
+ std::optional<int64_t > falseInt =
77
+ getConstantIntValue (selectOp.getFalseValue ());
78
+ if (!trueInt || !falseInt)
79
+ return failure ();
80
+ // getConstantIntValue returns -1 for "true" for bools.
81
+ if (trueInt.value () == -1 && falseInt.value () == 0 )
82
+ return selectOp.getCondition ();
83
+
84
+ if (trueInt.value () == 0 && falseInt.value () == -1 ) {
85
+ Value cond = selectOp.getCondition ();
86
+ Value one = b.create <arith::ConstantIntOp>(cond.getLoc (), /* value=*/ true ,
87
+ /* width=*/ 1 );
88
+ Value inverse = b.create <arith::XOrIOp>(cond.getLoc (), cond, one);
89
+ return inverse;
90
+ }
91
+ return failure ();
92
+ }
93
+
64
94
// NOTE(review): presumably a marker attribute attached to masked loads whose
// mask could not be eliminated, so patterns skip them on later passes —
// confirm against the attribute's uses elsewhere in this file.
static constexpr char kMaskedloadNeedsMask[] =
    "amdgpu.buffer_maskedload_needs_mask";
66
96
@@ -78,6 +108,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
78
108
return failure ();
79
109
}
80
110
111
+ // Check if this is either a full inbounds load or an empty, oob load. If
112
+ // so, take the fast path and don't generate a if condition, because we know
113
+ // doing the oob load is always safe.
114
+ if (succeeded (matchFullSelect (rewriter, maskedOp.getMask ()))) {
115
+ Value load =
116
+ createVectorLoadForMaskedLoad (rewriter, maskedOp.getLoc (), maskedOp);
117
+ rewriter.replaceOp (maskedOp, load);
118
+ return success ();
119
+ }
120
+
81
121
Location loc = maskedOp.getLoc ();
82
122
Value src = maskedOp.getBase ();
83
123
@@ -148,11 +188,64 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
148
188
}
149
189
};
150
190
191
+ struct FullMaskedLoadToConditionalLoad
192
+ : OpRewritePattern<vector::MaskedLoadOp> {
193
+ using OpRewritePattern::OpRewritePattern;
194
+
195
+ public:
196
+ LogicalResult matchAndRewrite (vector::MaskedLoadOp loadOp,
197
+ PatternRewriter &rewriter) const override {
198
+ FailureOr<Value> maybeCond = matchFullSelect (rewriter, loadOp.getMask ());
199
+ if (failed (maybeCond)) {
200
+ return failure ();
201
+ }
202
+
203
+ Value cond = maybeCond.value ();
204
+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205
+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp);
206
+ rewriter.create <scf::YieldOp>(loc, res);
207
+ };
208
+ auto falseBuilder = [&](OpBuilder &builder, Location loc) {
209
+ rewriter.create <scf::YieldOp>(loc, loadOp.getPassThru ());
210
+ };
211
+ auto ifOp = rewriter.create <scf::IfOp>(loadOp.getLoc (), cond, trueBuilder,
212
+ falseBuilder);
213
+ rewriter.replaceOp (loadOp, ifOp);
214
+ return success ();
215
+ }
216
+ };
217
+
218
+ struct FullMaskedStoreToConditionalStore
219
+ : OpRewritePattern<vector::MaskedStoreOp> {
220
+ using OpRewritePattern::OpRewritePattern;
221
+
222
+ public:
223
+ LogicalResult matchAndRewrite (vector::MaskedStoreOp storeOp,
224
+ PatternRewriter &rewriter) const override {
225
+ FailureOr<Value> maybeCond = matchFullSelect (rewriter, storeOp.getMask ());
226
+ if (failed (maybeCond)) {
227
+ return failure ();
228
+ }
229
+ Value cond = maybeCond.value ();
230
+
231
+ auto trueBuilder = [&](OpBuilder &builder, Location loc) {
232
+ rewriter.create <vector::StoreOp>(loc, storeOp.getValueToStore (),
233
+ storeOp.getBase (), storeOp.getIndices ());
234
+ rewriter.create <scf::YieldOp>(loc);
235
+ };
236
+ auto ifOp = rewriter.create <scf::IfOp>(storeOp.getLoc (), cond, trueBuilder);
237
+ rewriter.replaceOp (storeOp, ifOp);
238
+ return success ();
239
+ }
240
+ };
241
+
151
242
} // namespace
152
243
153
244
void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns (
154
245
RewritePatternSet &patterns, PatternBenefit benefit) {
155
- patterns.add <MaskedLoadLowering>(patterns.getContext (), benefit);
246
+ patterns.add <MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
247
+ FullMaskedStoreToConditionalStore>(patterns.getContext (),
248
+ benefit);
156
249
}
157
250
158
251
struct AmdgpuMaskedloadToLoadPass final
0 commit comments