15
15
#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
16
16
#include " mlir/Dialect/SCF/IR/SCF.h"
17
17
#include " mlir/Dialect/Vector/IR/VectorOps.h"
18
+ #include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
18
19
#include " mlir/IR/BuiltinTypes.h"
19
20
#include " mlir/IR/OpDefinition.h"
20
21
#include " mlir/IR/PatternMatch.h"
@@ -52,42 +53,24 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
52
53
}
53
54
54
55
static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
55
- vector::MaskedLoadOp maskedOp) {
56
+ vector::MaskedLoadOp maskedOp,
57
+ bool passthru) {
56
58
VectorType vectorType = maskedOp.getVectorType ();
57
59
Value load = builder.create <vector::LoadOp>(
58
60
loc, vectorType, maskedOp.getBase (), maskedOp.getIndices ());
59
- Value res = builder.create <arith::SelectOp>(
60
- loc, vectorType, maskedOp.getMask (), load, maskedOp.getPassThru ());
61
- return res;
61
+ if (passthru)
62
+ load = builder.create <arith::SelectOp>(loc, vectorType, maskedOp.getMask (),
63
+ load, maskedOp.getPassThru ());
64
+ return load;
62
65
}
63
66
64
- // / Check if the given value comes from a:
65
- // /
66
- // / arith.select %cond, TRUE/FALSE, TRUE/FALSE
67
- // /
68
- // / i.e the condition is either always true or it's always false.
69
- // /
70
- // / Returns the condition to use for scf.if (condition) { true } else { false }.
71
- static FailureOr<Value> matchFullSelect (OpBuilder &b, Value val) {
72
- auto selectOp = val.getDefiningOp <arith::SelectOp>();
73
- if (!selectOp)
74
- return failure ();
75
- std::optional<int64_t > trueInt = getConstantIntValue (selectOp.getTrueValue ());
76
- std::optional<int64_t > falseInt =
77
- getConstantIntValue (selectOp.getFalseValue ());
78
- if (!trueInt || !falseInt)
67
+ // / Check if the given value comes from a broadcasted i1 condition.
68
+ static FailureOr<Value> matchFullMask (OpBuilder &b, Value val) {
69
+ auto broadcastOp = val.getDefiningOp <vector::BroadcastOp>();
70
+ if (!broadcastOp)
79
71
return failure ();
80
- // getConstantIntValue returns -1 for "true" for bools.
81
- if (trueInt.value () == -1 && falseInt.value () == 0 )
82
- return selectOp.getCondition ();
83
-
84
- if (trueInt.value () == 0 && falseInt.value () == -1 ) {
85
- Value cond = selectOp.getCondition ();
86
- Value one = b.create <arith::ConstantIntOp>(cond.getLoc (), /* value=*/ true ,
87
- /* width=*/ 1 );
88
- Value inverse = b.create <arith::XOrIOp>(cond.getLoc (), cond, one);
89
- return inverse;
90
- }
72
+ if (!isa<VectorType>(broadcastOp.getSourceType ()))
73
+ return broadcastOp.getSource ();
91
74
return failure ();
92
75
}
93
76
@@ -109,11 +92,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
109
92
}
110
93
111
94
// Check if this is either a full inbounds load or an empty, oob load. If
112
- // so, take the fast path and don't generate a if condition, because we know
113
- // doing the oob load is always safe.
114
- if (succeeded (matchFullSelect (rewriter, maskedOp.getMask ()))) {
115
- Value load =
116
- createVectorLoadForMaskedLoad (rewriter, maskedOp. getLoc (), maskedOp );
95
+ // so, take the fast path and don't generate an if condition, because we
96
+ // know doing the oob load is always safe.
97
+ if (succeeded (matchFullMask (rewriter, maskedOp.getMask ()))) {
98
+ Value load = createVectorLoadForMaskedLoad (rewriter, maskedOp. getLoc (),
99
+ maskedOp, /* passthru= */ true );
117
100
rewriter.replaceOp (maskedOp, load);
118
101
return success ();
119
102
}
@@ -175,7 +158,8 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
175
158
};
176
159
177
160
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
178
- Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp);
161
+ Value res = createVectorLoadForMaskedLoad (builder, loc, maskedOp,
162
+ /* passthru=*/ true );
179
163
rewriter.create <scf::YieldOp>(loc, res);
180
164
};
181
165
@@ -192,17 +176,17 @@ struct FullMaskedLoadToConditionalLoad
192
176
: OpRewritePattern<vector::MaskedLoadOp> {
193
177
using OpRewritePattern::OpRewritePattern;
194
178
195
- public:
196
179
LogicalResult matchAndRewrite (vector::MaskedLoadOp loadOp,
197
180
PatternRewriter &rewriter) const override {
198
- FailureOr<Value> maybeCond = matchFullSelect (rewriter, loadOp.getMask ());
181
+ FailureOr<Value> maybeCond = matchFullMask (rewriter, loadOp.getMask ());
199
182
if (failed (maybeCond)) {
200
183
return failure ();
201
184
}
202
185
203
186
Value cond = maybeCond.value ();
204
187
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
205
- Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp);
188
+ Value res = createVectorLoadForMaskedLoad (builder, loc, loadOp,
189
+ /* passthru=*/ false );
206
190
rewriter.create <scf::YieldOp>(loc, res);
207
191
};
208
192
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
@@ -219,10 +203,9 @@ struct FullMaskedStoreToConditionalStore
219
203
: OpRewritePattern<vector::MaskedStoreOp> {
220
204
using OpRewritePattern::OpRewritePattern;
221
205
222
- public:
223
206
LogicalResult matchAndRewrite (vector::MaskedStoreOp storeOp,
224
207
PatternRewriter &rewriter) const override {
225
- FailureOr<Value> maybeCond = matchFullSelect (rewriter, storeOp.getMask ());
208
+ FailureOr<Value> maybeCond = matchFullMask (rewriter, storeOp.getMask ());
226
209
if (failed (maybeCond)) {
227
210
return failure ();
228
211
}
0 commit comments