
Commit f964922

[mlir][AMDGPU] Add better load/store lowering for full mask (#146748)
This patch adds a better maskedload/maskedstore lowering in the AMDGPU backend for accesses whose mask is either fully set or fully unset (i.e. a broadcast of a scalar i1 condition). For these cases, we can either generate an out-of-bounds buffer load with no if condition (when the base is in the fat_raw_buffer address space), or generate a normal load/store guarded by an if condition on the scalar predicate (when it is not).
1 parent 86320e0 commit f964922
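
To illustrate the non-buffer path, here is a rough sketch (names such as @example, %mem, %idx, and %pred are illustrative, and the output IR is paraphrased from the new test expectations below, not verbatim compiler output): a vector.maskedload whose mask is a broadcast of a scalar i1 becomes a plain vector.load guarded by scf.if, with the passthru value yielded on the false branch.

// Input: the mask is a broadcast of a single i1 predicate.
func.func @example(%mem: memref<8x8xf16>, %idx: index,
                   %pred: i1, %passthru: vector<4xf16>) -> vector<4xf16> {
  %mask = vector.broadcast %pred : i1 to vector<4xi1>
  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru
      : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
  return %res : vector<4xf16>
}

// After the rewrite (roughly): an unmasked load guarded by the predicate.
//   %res = scf.if %pred -> (vector<4xf16>) {
//     %load = vector.load %mem[%idx, %idx] : memref<8x8xf16>, vector<4xf16>
//     scf.yield %load : vector<4xf16>
//   } else {
//     scf.yield %passthru : vector<4xf16>
//   }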

File tree

2 files changed (+135, -6 lines)


mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 82 additions & 6 deletions

@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
@@ -52,13 +53,25 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
 }
 
 static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
-                                           vector::MaskedLoadOp maskedOp) {
+                                           vector::MaskedLoadOp maskedOp,
+                                           bool passthru) {
   VectorType vectorType = maskedOp.getVectorType();
   Value load = builder.create<vector::LoadOp>(
       loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
-  Value res = builder.create<arith::SelectOp>(
-      loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
-  return res;
+  if (passthru)
+    load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
+                                           load, maskedOp.getPassThru());
+  return load;
+}
+
+/// Check if the given value comes from a broadcasted i1 condition.
+static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+  auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return failure();
+  if (isa<VectorType>(broadcastOp.getSourceType()))
+    return failure();
+  return broadcastOp.getSource();
 }
 
 static constexpr char kMaskedloadNeedsMask[] =
@@ -78,6 +91,16 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
       return failure();
     }
 
+    // Check if this is either a full inbounds load or an empty, oob load. If
+    // so, take the fast path and don't generate an if condition, because we
+    // know doing the oob load is always safe.
+    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+      Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
+                                                 maskedOp, /*passthru=*/true);
+      rewriter.replaceOp(maskedOp, load);
+      return success();
+    }
+
     Location loc = maskedOp.getLoc();
     Value src = maskedOp.getBase();
@@ -135,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
-      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
+      Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
+                                                /*passthru=*/true);
       rewriter.create<scf::YieldOp>(loc, res);
     };
 
@@ -148,11 +172,63 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   }
 };
 
+struct FullMaskedLoadToConditionalLoad
+    : OpRewritePattern<vector::MaskedLoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+
+    Value cond = maybeCond.value();
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
+                                                /*passthru=*/false);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+    auto falseBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
+                                           falseBuilder);
+    rewriter.replaceOp(loadOp, ifOp);
+    return success();
+  }
+};
+
+struct FullMaskedStoreToConditionalStore
+    : OpRewritePattern<vector::MaskedStoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    if (failed(maybeCond)) {
+      return failure();
+    }
+    Value cond = maybeCond.value();
+
+    auto trueBuilder = [&](OpBuilder &builder, Location loc) {
+      rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
+                                       storeOp.getBase(), storeOp.getIndices());
+      rewriter.create<scf::YieldOp>(loc);
+    };
+    auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+    rewriter.replaceOp(storeOp, ifOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<MaskedLoadLowering>(patterns.getContext(), benefit);
+  patterns.add<MaskedLoadLowering, FullMaskedLoadToConditionalLoad,
+               FullMaskedStoreToConditionalStore>(patterns.getContext(),
+                                                  benefit);
 }
 
 struct AmdgpuMaskedloadToLoadPass final
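
For the fat_raw_buffer case, the fast path added to MaskedLoadLowering above skips the scf.if entirely, since an out-of-bounds buffer load is safe; a rough sketch of the rewrite (paraphrased from the test expectations below, names illustrative):

// Input (base in the fat_raw_buffer address space):
//   %mask = vector.broadcast %pred : i1 to vector<4xi1>
//   %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru
//       : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>,
//         vector<4xi1>, vector<4xf16> into vector<4xf16>
//
// After the rewrite (roughly): no branch, just an unconditional load + select.
//   %load = vector.load %mem[%idx, %idx]
//       : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
//   %res = arith.select %mask, %load, %passthru : vector<4xi1>, vector<4xf16>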

mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir

Lines changed: 53 additions & 0 deletions

@@ -114,3 +114,56 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
 // CHECK: %[[IF:.*]] = scf.if
 // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[ARG1]]]
 // CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> {
+  %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_fatrawbuffer_to_load
+func.func @full_select_maskedload_fatrawbuffer_to_load(%arg0: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: vector.load
+// CHECK: arith.select
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_mask_maskedstore_to_store
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) {
+  %0 = vector.broadcast %arg2 : i1 to vector<4xi1>
+  vector.maskedstore %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
+  return
+}
+// CHECK-NOT: vector.maskedstore
+// CHECK: scf.if %[[PRED]]
+// CHECK: vector.store
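
The masked-store side mirrors the load case: FullMaskedStoreToConditionalStore wraps a plain vector.store in an scf.if on the broadcast predicate. A rough sketch of the expected rewrite (names illustrative, paraphrased from @full_mask_maskedstore_to_store above):

// Input:
//   %mask = vector.broadcast %pred : i1 to vector<4xi1>
//   vector.maskedstore %mem[%idx, %idx], %mask, %value
//       : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
//
// becomes, roughly:
//   scf.if %pred {
//     vector.store %value, %mem[%idx, %idx] : memref<8x8xf16>, vector<4xf16>
//   }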
