@@ -52,13 +52,15 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
52
52
}
53
53
54
54
static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
55
- vector::MaskedLoadOp maskedOp) {
55
+ vector::MaskedLoadOp maskedOp,
56
+ bool passthru) {
56
57
VectorType vectorType = maskedOp.getVectorType ();
57
58
Value load = builder.create <vector::LoadOp>(
58
59
loc, vectorType, maskedOp.getBase (), maskedOp.getIndices ());
59
- Value res = builder.create <arith::SelectOp>(
60
- loc, vectorType, maskedOp.getMask (), load, maskedOp.getPassThru ());
61
- return res;
60
+ if (passthru)
61
+ load = builder.create <arith::SelectOp>(loc, vectorType, maskedOp.getMask (),
62
+ load, maskedOp.getPassThru ());
63
+ return load;
62
64
}
63
65
64
66
// / Check if the given value comes from a:
@@ -112,8 +114,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
112
114
// so, take the fast path and don't generate a if condition, because we know
113
115
// doing the oob load is always safe.
114
116
if (succeeded (matchFullSelect (rewriter, maskedOp.getMask ()))) {
115
- Value load =
116
- createVectorLoadForMaskedLoad (rewriter, maskedOp. getLoc (), maskedOp );
117
+ Value load = createVectorLoadForMaskedLoad (rewriter, maskedOp. getLoc (),
118
+ maskedOp, /* passthru= */ true );
117
119
rewriter.replaceOp (maskedOp, load);
118
120
return success ();
119
121
}
@@ -175,7 +177,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
175
177
};
176
178
177
179
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
178
- Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp);
180
+ Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp,
181
+ /* passthru=*/ true );
179
182
rewriter.create <scf::YieldOp>(loc, res);
180
183
};
181
184
@@ -202,7 +205,8 @@ struct FullMaskedLoadToConditionalLoad
202
205
203
206
Value cond = maybeCond.value ();
204
207
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205
- Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp);
208
+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp,
209
+ /* passthru=*/ false );
206
210
rewriter.create <scf::YieldOp>(loc, res);
207
211
};
208
212
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
0 commit comments