Skip to content

Commit 86115e7

Browse files
committed
Fix impl and address comments
1 parent 8b0c534 commit 86115e7

File tree

2 files changed

+64
-53
lines changed

2 files changed

+64
-53
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1616
#include "mlir/Dialect/SCF/IR/SCF.h"
1717
#include "mlir/Dialect/Vector/IR/VectorOps.h"
18+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/IR/PatternMatch.h"
@@ -52,42 +53,24 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
5253
}
5354

5455
static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
55-
vector::MaskedLoadOp maskedOp) {
56+
vector::MaskedLoadOp maskedOp,
57+
bool passthru) {
5658
VectorType vectorType = maskedOp.getVectorType();
5759
Value load = builder.create<vector::LoadOp>(
5860
loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
59-
Value res = builder.create<arith::SelectOp>(
60-
loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru());
61-
return res;
61+
if (passthru)
62+
load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
63+
load, maskedOp.getPassThru());
64+
return load;
6265
}
6366

64-
/// Check if the given value comes from a:
65-
///
66-
/// arith.select %cond, TRUE/FALSE, TRUE/FALSE
67-
///
68-
/// i.e the condition is either always true or it's always false.
69-
///
70-
/// Returns the condition to use for scf.if (condition) { true } else { false }.
71-
static FailureOr<Value> matchFullSelect(OpBuilder &b, Value val) {
72-
auto selectOp = val.getDefiningOp<arith::SelectOp>();
73-
if (!selectOp)
74-
return failure();
75-
std::optional<int64_t> trueInt = getConstantIntValue(selectOp.getTrueValue());
76-
std::optional<int64_t> falseInt =
77-
getConstantIntValue(selectOp.getFalseValue());
78-
if (!trueInt || !falseInt)
67+
/// Check if the given value comes from a broadcasted i1 condition.
68+
static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
69+
auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
70+
if (!broadcastOp)
7971
return failure();
80-
// getConstantIntValue returns -1 for "true" for bools.
81-
if (trueInt.value() == -1 && falseInt.value() == 0)
82-
return selectOp.getCondition();
83-
84-
if (trueInt.value() == 0 && falseInt.value() == -1) {
85-
Value cond = selectOp.getCondition();
86-
Value one = b.create<arith::ConstantIntOp>(cond.getLoc(), /*value=*/true,
87-
/*width=*/1);
88-
Value inverse = b.create<arith::XOrIOp>(cond.getLoc(), cond, one);
89-
return inverse;
90-
}
72+
if (!isa<VectorType>(broadcastOp.getSourceType()))
73+
return broadcastOp.getSource();
9174
return failure();
9275
}
9376

@@ -109,11 +92,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
10992
}
11093

11194
// Check if this is either a full inbounds load or an empty, oob load. If
112-
// so, take the fast path and don't generate a if condition, because we know
113-
// doing the oob load is always safe.
114-
if (succeeded(matchFullSelect(rewriter, maskedOp.getMask()))) {
115-
Value load =
116-
createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(), maskedOp);
95+
// so, take the fast path and don't generate an if condition, because we
96+
// know doing the oob load is always safe.
97+
if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
98+
Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
99+
maskedOp, /*passthru=*/true);
117100
rewriter.replaceOp(maskedOp, load);
118101
return success();
119102
}
@@ -175,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
175158
};
176159

177160
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
178-
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp);
161+
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
162+
/*passthru=*/true);
179163
rewriter.create<scf::YieldOp>(loc, res);
180164
};
181165

@@ -192,17 +176,17 @@ struct FullMaskedLoadToConditionalLoad
192176
: OpRewritePattern<vector::MaskedLoadOp> {
193177
using OpRewritePattern::OpRewritePattern;
194178

195-
public:
196179
LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
197180
PatternRewriter &rewriter) const override {
198-
FailureOr<Value> maybeCond = matchFullSelect(rewriter, loadOp.getMask());
181+
FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
199182
if (failed(maybeCond)) {
200183
return failure();
201184
}
202185

203186
Value cond = maybeCond.value();
204187
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205-
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp);
188+
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
189+
/*passthru=*/false);
206190
rewriter.create<scf::YieldOp>(loc, res);
207191
};
208192
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
@@ -219,10 +203,9 @@ struct FullMaskedStoreToConditionalStore
219203
: OpRewritePattern<vector::MaskedStoreOp> {
220204
using OpRewritePattern::OpRewritePattern;
221205

222-
public:
223206
LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
224207
PatternRewriter &rewriter) const override {
225-
FailureOr<Value> maybeCond = matchFullSelect(rewriter, storeOp.getMask());
208+
FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
226209
if (failed(maybeCond)) {
227210
return failure();
228211
}

mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,46 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
124124

125125
// -----
126126

127-
func.func @full_select_maskedload_fatrawbuffer_to_load(%mem : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
128-
%true = arith.constant dense<true> : vector<4xi1>
129-
%false = arith.constant dense<false> : vector<4xi1>
130-
%mask = arith.select %cond, %true, %false : vector<4xi1>
131-
%res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
132-
return %res : vector<4xf16>
127+
// CHECK-LABEL: func.func @full_select_maskedload_fatrawbuffer_to_load
128+
func.func @full_select_maskedload_fatrawbuffer_to_load(%arg0: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
129+
%0 = vector.broadcast %arg2 : i1 to vector<4xi1>
130+
%1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
131+
return %1 : vector<4xf16>
133132
}
133+
// CHECK-NOT: vector.maskedload
134+
// CHECK: vector.load
135+
// CHECK: arith.select
134136

135-
func.func @full_select_maskedload_to_load(%mem : memref<8x8xf16>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> {
136-
%true = arith.constant dense<true> : vector<4xi1>
137-
%false = arith.constant dense<false> : vector<4xi1>
138-
%mask = arith.select %cond, %true, %false : vector<4xi1>
139-
%res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
140-
return %res : vector<4xf16>
137+
// -----
138+
139+
// CHECK-LABEL: func.func @full_select_maskedload_to_load
140+
// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
141+
// CHECK-SAME: %[[IDX:.+]]: index,
142+
// CHECK-SAME: %[[PRED:.+]]: i1,
143+
// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
144+
func.func @full_select_maskedload_to_load(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
145+
%0 = vector.broadcast %arg2 : i1 to vector<4xi1>
146+
%1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
147+
return %1 : vector<4xf16>
148+
}
149+
// CHECK-NOT: vector.maskedload
150+
// CHECK: scf.if %[[PRED]]
151+
// CHECK: %[[LOAD:.+]] = vector.load
152+
// CHECK: scf.yield %[[LOAD]]
153+
// CHECK: else
154+
// CHECK: scf.yield %[[PASSTHRU]]
155+
156+
// -----
157+
158+
// CHECK-LABEL: func.func @full_mask_maskedstore_to_store
159+
// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
160+
// CHECK-SAME: %[[IDX:.+]]: index,
161+
// CHECK-SAME: %[[PRED:.+]]: i1,
162+
func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) {
163+
%0 = vector.broadcast %arg2 : i1 to vector<4xi1>
164+
vector.maskedstore %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16>
165+
return
141166
}
167+
// CHECK-NOT: vector.maskedstore
168+
// CHECK: scf.if %[[PRED]]
169+
// CHECK: vector.store

0 commit comments

Comments (0)