Skip to content

Commit da495c7

Browse files
committed
no passthru for if case
1 parent 8b0c534 commit da495c7

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
5252
}
5353

5454
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
55-
vector::MaskedLoadOp maskedOp) {
55+
vector::MaskedLoadOp maskedOp,
56+
bool passthru) {
5657
VectorType vectorType = maskedOp.getVectorType();
5758
Value load = builder.create<vector::LoadOp>(
5859
loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
59-
Value res = builder.create<arith::SelectOp>(
60-
loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
61-
return res;
60+
if (passthru)
61+
load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
62+
load, maskedOp.getPassThru());
63+
return load;
6264
}
6365

6466
/// Check if the given value comes from a:
@@ -112,8 +114,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
112114
// so, take the fast path and don't generate a if condition, because we know
113115
// doing the oob load is always safe.
114116
if (succeeded(matchFullSelect(rewriter, maskedOp.getMask()))) {
115-
Value load =
116-
createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(), maskedOp);
117+
Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
118+
maskedOp, /*passthru=*/true);
117119
rewriter.replaceOp(maskedOp, load);
118120
return success();
119121
}
@@ -175,7 +177,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
175177
};
176178

177179
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
178-
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
180+
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
181+
/*passthru=*/true);
179182
rewriter.create<scf::YieldOp>(loc, res);
180183
};
181184

@@ -202,7 +205,8 @@ struct FullMaskedLoadToConditionalLoad
202205

203206
Value cond = maybeCond.value();
204207
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205-
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp);
208+
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
209+
/*passthru=*/false);
206210
rewriter.create<scf::YieldOp>(loc, res);
207211
};
208212
auto falseBuilder = [&](OpBuilder &builder, Location loc) {

0 commit comments

Comments
 (0)