Commit 8b0c534

[mlir][AMDGPU] Add better load/store lowering for full select mask
1 parent e9be528 · commit 8b0c534

File tree

2 files changed: +119 -1 lines changed

- mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
- mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir

mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 94 additions & 1 deletion
@@ -61,6 +61,36 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return res;
 }
 
+/// Check if the given value comes from a:
+///
+/// arith.select %cond, TRUE/FALSE, TRUE/FALSE
+///
+/// i.e. the resulting mask is either all-true or all-false depending on %cond.
+///
+/// Returns the condition to use for scf.if (condition) { true } else { false }.
+static FailureOr<Value> matchFullSelect(OpBuilder &b, Value val) {
+  auto selectOp = val.getDefiningOp<arith::SelectOp>();
+  if (!selectOp)
+    return failure();
+  std::optional<int64_t> trueInt = getConstantIntValue(selectOp.getTrueValue());
+  std::optional<int64_t> falseInt =
+      getConstantIntValue(selectOp.getFalseValue());
+  if (!trueInt || !falseInt)
+    return failure();
+  // getConstantIntValue returns -1 for a boolean "true".
+  if (trueInt.value() == -1 && falseInt.value() == 0)
+    return selectOp.getCondition();
+
+  if (trueInt.value() == 0 && falseInt.value() == -1) {
+    Value cond = selectOp.getCondition();
+    Value one = b.create<arith::ConstantIntOp>(cond.getLoc(), /*value=*/true,
+                                               /*width=*/1);
+    Value inverse = b.create<arith::XOrIOp>(cond.getLoc(), cond, one);
+    return inverse;
+  }
+  return failure();
+}
+
 static constexpr char kMaskedloadNeedsMask[] =
     "amdgpu.buffer_maskedload_needs_mask";
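For intuition, the masks this helper accepts look like the following (a minimal sketch based on the tests below; the value names are hypothetical):

    // Matched directly: matchFullSelect returns %cond.
    %true = arith.constant dense<true> : vector<4xi1>
    %false = arith.constant dense<false> : vector<4xi1>
    %mask = arith.select %cond, %true, %false : vector<4xi1>

    // Matched with the operands swapped: the helper materializes the
    // inverse, %cond xor true, and returns that instead.
    %mask_inv = arith.select %cond, %false, %true : vector<4xi1>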

@@ -78,6 +108,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
       return failure();
     }
 
+    // Check if this is either a full inbounds load or an empty, out-of-bounds
+    // load. If so, take the fast path and don't generate an if condition,
+    // because we know doing the out-of-bounds load is always safe.
+    if (succeeded(matchFullSelect(rewriter, maskedOp.getMask()))) {
+      Value load =
+          createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(), maskedOp);
+      rewriter.replaceOp(maskedOp, load);
+      return success();
+    }
+
     Location loc = maskedOp.getLoc();
     Value src = maskedOp.getBase();
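Concretely, when the mask is such a full select and the base is a buffer fat pointer, the masked load can lower to an unconditional load. A hedged sketch of the intended effect (hypothetical names; assuming createVectorLoadForMaskedLoad emits a vector.load plus the pass-through select, as the existing CHECK lines suggest):

    // Before:
    %mask = arith.select %cond, %all_true, %all_false : vector<4xi1>
    %res = vector.maskedload %mem[%idx], %mask, %pass
        : memref<64xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>

    // After: no scf.if guard; out-of-bounds lanes of a buffer load
    // return zero instead of faulting, so the load is always safe.
    %load = vector.load %mem[%idx]
        : memref<64xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
    %res = arith.select %mask, %load, %pass : vector<4xi1>, vector<4xf16>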

@@ -148,11 +188,64 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   }
 };
 
+struct FullMaskedLoadToConditionalLoad
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullSelect(rewriter, loadOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+
+    Value cond = maybeCond.value();
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
+                                           falseBuilder);
+    rewriter.replaceOp(loadOp, ifOp);
+    return success();
+  }
+};
+
+struct FullMaskedStoreToConditionalStore
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullSelect(rewriter, storeOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+    Value cond = maybeCond.value();
+
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
+                                       storeOp.getBase(), storeOp.getIndices());
+      rewriter.create<scf::YieldOp>(loc);
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+    rewriter.replaceOp(storeOp, ifOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
+               FullMaskedStoreToConditionalStore>(patterns.getContext(),
+                                                  benefit);
 }
 
 struct AmdgpuMaskedloadToLoadPass final
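Together, the two new patterns replace per-lane masking with a single scalar branch. A hedged before/after sketch of the load rewrite (hypothetical names; any address space, since nothing AMDGPU-specific is emitted):

    // Before:
    %mask = arith.select %cond, %all_true, %all_false : vector<4xi1>
    %res = vector.maskedload %mem[%idx], %mask, %pass
        : memref<64xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>

    // After:
    %res = scf.if %cond -> (vector<4xf16>) {
      %load = vector.load %mem[%idx] : memref<64xf16>, vector<4xf16>
      scf.yield %load : vector<4xf16>
    } else {
      scf.yield %pass : vector<4xf16>
    }

The store rewrite is analogous, with an empty else region:

    // Before:
    vector.maskedstore %mem[%idx], %mask, %val : memref<64xf16>, vector<4xi1>, vector<4xf16>

    // After:
    scf.if %cond {
      vector.store %val, %mem[%idx] : memref<64xf16>, vector<4xf16>
    }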

mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir

Lines changed: 25 additions & 0 deletions
@@ -114,3 +114,28 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+func.func @full_select_maskedload_fatrawbuffer_to_load(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
+  %true = arith.constant dense<true> : vector<4xi1>
+  %false = arith.constant dense<false> : vector<4xi1>
+  %mask = arith.select %cond, %true, %false : vector<4xi1>
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %res : vector<4xf16>
+}
+
+func.func @full_select_maskedload_to_load(%mem : memref<8x8xf16>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
+  %true = arith.constant dense<true> : vector<4xi1>
+  %false = arith.constant dense<false> : vector<4xi1>
+  %mask = arith.select %cond, %true, %false : vector<4xi1>
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %res : vector<4xf16>
+}
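The new test functions appear here without FileCheck annotations; for @full_select_maskedload_to_load one would expect checks roughly like the following (a hypothetical sketch, not the committed CHECK lines):

    // CHECK-LABEL: func.func @full_select_maskedload_to_load
    // CHECK-NOT: vector.maskedload
    // CHECK: scf.if %{{.*}} -> (vector<4xf16>)
    // CHECK: vector.load
    // CHECK: scf.yield
    // CHECK: } else {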
