Skip to content

Commit ea62de5

Browse files
[mlir] NFC - refactor id builder and avoid leaking impl details (#146922)
1 parent 3277f62 commit ea62de5

File tree

3 files changed

+127
-113
lines changed

3 files changed

+127
-113
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h

Lines changed: 14 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -28,27 +28,24 @@ namespace transform {
2828
namespace gpu {
2929

3030
/// Helper type for functions that generate ids for the mapping of a scf.forall.
31-
/// Operates on both 1) an "original" basis that represents the individual
32-
/// thread and block ids and 2) a "scaled" basis that represents grouped ids
33-
/// (e.g. block clusters, warpgroups and warps).
34-
/// The mapping of ids is done in the "scaled" basis (i.e. when mapping to warps
35-
/// a division by 32 occurs).
36-
/// The predication is in the "original" basis using the "active" quantities
37-
/// (`activeMappingSizes`, `availableMappingSizes` and `activeIdOps`).
3831
struct IdBuilderResult {
39-
// Ops used to replace the forall induction variables.
32+
/// Error message, if not empty then building the ids failed.
33+
std::string errorMsg;
34+
/// Values used to replace the forall induction variables.
4035
SmallVector<Value> mappingIdOps;
41-
// Available mapping sizes used to predicate the forall body when they are
42-
// larger than the predicate mapping sizes.
43-
SmallVector<int64_t> availableMappingSizes;
44-
// Actual mapping sizes used to predicate the forall body when they are
45-
// smaller than the available mapping sizes.
46-
SmallVector<int64_t> activeMappingSizes;
47-
// Ops used to predicate the forall body when activeMappingSizes is smaller
48-
// than the available mapping sizes.
49-
SmallVector<Value> activeIdOps;
36+
/// Values used to predicate the forall body when activeMappingSizes is
37+
/// smaller than the available mapping sizes.
38+
SmallVector<Value> predicateOps;
5039
};
5140

41+
inline raw_ostream &operator<<(raw_ostream &os, const IdBuilderResult &res) {
42+
llvm::interleaveComma(res.mappingIdOps, os << "----mappingIdOps: ");
43+
os << "\n";
44+
llvm::interleaveComma(res.predicateOps, os << "----predicateOps: ");
45+
os << "\n";
46+
return os;
47+
}
48+
5249
/// Common gpu id builder type, allows the configuration of lowering for various
5350
/// mapping schemes. Takes:
5451
/// - A rewriter with insertion point set before the forall op to rewrite.

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 6 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -491,6 +491,10 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
491491

492492
IdBuilderResult builderResult =
493493
gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
494+
if (!builderResult.errorMsg.empty())
495+
return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg);
496+
497+
LLVM_DEBUG(DBGS() << builderResult);
494498

495499
// Step 4. Map the induction variables to the mappingIdOps, this may involve
496500
// a permutation.
@@ -501,7 +505,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
501505
forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
502506
auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
503507
Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
504-
LDBG("----map: " << iv << " to" << peIdOp);
508+
LDBG("----map: " << iv << " to " << peIdOp);
505509
bvm.map(iv, peIdOp);
506510
}
507511

@@ -510,32 +514,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
510514
// originalBasis and no predication occurs.
511515
Value predicate;
512516
if (originalBasisWasProvided) {
513-
SmallVector<int64_t> activeMappingSizes = builderResult.activeMappingSizes;
514-
SmallVector<int64_t> availableMappingSizes =
515-
builderResult.availableMappingSizes;
516-
SmallVector<Value> activeIdOps = builderResult.activeIdOps;
517-
LDBG("----activeMappingSizes: " << llvm::interleaved(activeMappingSizes));
518-
LDBG("----availableMappingSizes: "
519-
<< llvm::interleaved(availableMappingSizes));
520-
LDBG("----activeIdOps: " << llvm::interleaved(activeIdOps));
521-
for (auto [activeId, activeMappingSize, availableMappingSize] :
522-
llvm::zip_equal(activeIdOps, activeMappingSizes,
523-
availableMappingSizes)) {
524-
if (activeMappingSize > availableMappingSize) {
525-
return definiteFailureHelper(
526-
transformOp, forallOp,
527-
"Trying to map to fewer GPU threads than loop iterations but "
528-
"overprovisioning is not yet supported. "
529-
"Try additional tiling of the before mapping or map to more "
530-
"threads.");
531-
}
532-
if (activeMappingSize == availableMappingSize)
533-
continue;
534-
Value idx =
535-
rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
536-
Value tmpPredicate = rewriter.create<arith::CmpIOp>(
537-
loc, arith::CmpIPredicate::ult, activeId, idx);
538-
LDBG("----predicate: " << tmpPredicate);
517+
for (Value tmpPredicate : builderResult.predicateOps) {
539518
predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
540519
tmpPredicate)
541520
: tmpPredicate;

mlir/lib/Dialect/GPU/TransformOps/Utils.cpp

Lines changed: 107 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
4747
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
4848
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
4949

50+
/// Build predicates to filter execution by only the activeIds. Along each
51+
/// dimension, 3 cases appear:
52+
/// 1. activeMappingSize > availableMappingSize: this is an unsupported case
53+
/// as this requires additional looping. An error message is produced to
54+
/// advise the user to tile more or to use more threads.
55+
/// 2. activeMappingSize == availableMappingSize: no predication is needed.
56+
/// 3. activeMappingSize < availableMappingSize: only a subset of threads
57+
/// should be active and we produce the boolean `id < activeMappingSize`
58+
/// for further use in building predicated execution.
59+
static FailureOr<SmallVector<Value>>
60+
buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
61+
ArrayRef<int64_t> activeMappingSizes,
62+
ArrayRef<int64_t> availableMappingSizes,
63+
std::string &errorMsg) {
64+
// clang-format off
65+
LLVM_DEBUG(
66+
llvm::interleaveComma(
67+
activeMappingSizes, DBGS() << "----activeMappingSizes: ");
68+
DBGS() << "\n";
69+
llvm::interleaveComma(
70+
availableMappingSizes, DBGS() << "----availableMappingSizes: ");
71+
DBGS() << "\n";);
72+
// clang-format on
73+
74+
SmallVector<Value> predicateOps;
75+
for (auto [activeId, activeMappingSize, availableMappingSize] :
76+
llvm::zip_equal(activeIds, activeMappingSizes, availableMappingSizes)) {
77+
if (activeMappingSize > availableMappingSize) {
78+
errorMsg = "Trying to map to fewer GPU threads than loop iterations but "
79+
"overprovisioning is not yet supported. Try additional tiling "
80+
"before mapping or map to more threads.";
81+
return failure();
82+
}
83+
if (activeMappingSize == availableMappingSize)
84+
continue;
85+
Value idx = rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
86+
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
87+
activeId, idx);
88+
predicateOps.push_back(pred);
89+
}
90+
return predicateOps;
91+
}
92+
5093
/// Return a flattened thread id for the workgroup with given sizes.
5194
template <typename ThreadOrBlockIdOp>
5295
static Value buildLinearId(RewriterBase &rewriter, Location loc,
5396
ArrayRef<OpFoldResult> originalBasisOfr) {
54-
LLVM_DEBUG(DBGS() << "----buildLinearId with originalBasisOfr: "
55-
<< llvm::interleaved(originalBasisOfr) << "\n");
97+
LLVM_DEBUG(llvm::interleaveComma(
98+
originalBasisOfr,
99+
DBGS() << "----buildLinearId with originalBasisOfr: ");
100+
llvm::dbgs() << "\n");
56101
assert(originalBasisOfr.size() == 3 && "expected 3 sizes");
57102
IndexType indexType = rewriter.getIndexType();
58103
AffineExpr tx, ty, tz, bdx, bdy;
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
79124
auto res = [multiplicity](RewriterBase &rewriter, Location loc,
80125
ArrayRef<int64_t> forallMappingSizes,
81126
ArrayRef<int64_t> originalBasis) {
127+
// 1. Compute linearId.
82128
SmallVector<OpFoldResult> originalBasisOfr =
83129
getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
84-
OpFoldResult linearId =
130+
Value physicalLinearId =
85131
buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
132+
133+
// 2. Compute scaledLinearId.
134+
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
135+
OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
136+
rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
137+
138+
// 3. Compute remapped indices.
139+
SmallVector<Value> ids;
86140
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
87141
// "row-major" order.
88142
SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
89143
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
90-
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
91-
OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
92-
rewriter, loc, d0.floorDiv(multiplicity), {linearId});
93144
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
94-
SmallVector<Value> ids;
95145
// Reverse back to be in [0 .. n] order.
96146
for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
97147
ids.push_back(
98148
affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
99149
}
100150

101-
LLVM_DEBUG(DBGS() << "--delinearization basis: "
102-
<< llvm::interleaved(reverseBasisSizes) << "\n";
103-
DBGS() << "--delinearization strides: "
104-
<< llvm::interleaved(strides) << "\n";
105-
DBGS() << "--delinearization exprs: "
106-
<< llvm::interleaved(delinearizingExprs) << "\n";
107-
DBGS() << "--ids: " << llvm::interleaved(ids) << "\n");
108-
109-
// Return n-D ids for indexing and 1-D size + id for predicate generation.
110-
return IdBuilderResult{
111-
/*mappingIdOps=*/ids,
112-
/*availableMappingSizes=*/
113-
SmallVector<int64_t>{computeProduct(originalBasis)},
114-
// `forallMappingSizes` iterate in the scaled basis, they need to be
115-
// scaled back into the original basis to provide tight
116-
// activeMappingSizes quantities for predication.
117-
/*activeMappingSizes=*/
118-
SmallVector<int64_t>{computeProduct(forallMappingSizes) * multiplicity},
119-
/*activeIdOps=*/SmallVector<Value>{cast<Value>(linearId)}};
151+
// 4. Handle predicates using physicalLinearId.
152+
std::string errorMsg;
153+
SmallVector<Value> predicateOps;
154+
FailureOr<SmallVector<Value>> maybePredicateOps =
155+
buildPredicates(rewriter, loc, physicalLinearId,
156+
computeProduct(forallMappingSizes) * multiplicity,
157+
computeProduct(originalBasis), errorMsg);
158+
if (succeeded(maybePredicateOps))
159+
predicateOps = *maybePredicateOps;
160+
161+
return IdBuilderResult{/*errorMsg=*/errorMsg,
162+
/*mappingIdOps=*/ids,
163+
/*predicateOps=*/predicateOps};
120164
};
121165

122166
return res;
@@ -143,71 +187,65 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
143187
// In the 3-D mapping case, unscale the first dimension by the multiplicity.
144188
SmallVector<int64_t> forallMappingSizeInOriginalBasis(forallMappingSizes);
145189
forallMappingSizeInOriginalBasis[0] *= multiplicity;
146-
return IdBuilderResult{
147-
/*mappingIdOps=*/scaledIds,
148-
/*availableMappingSizes=*/SmallVector<int64_t>{originalBasis},
149-
// `forallMappingSizes` iterate in the scaled basis, they need to be
150-
// scaled back into the original basis to provide tight
151-
// activeMappingSizes quantities for predication.
152-
/*activeMappingSizes=*/
153-
SmallVector<int64_t>{forallMappingSizeInOriginalBasis},
154-
/*activeIdOps=*/ids};
190+
191+
std::string errorMsg;
192+
SmallVector<Value> predicateOps;
193+
FailureOr<SmallVector<Value>> maybePredicateOps =
194+
buildPredicates(rewriter, loc, ids, forallMappingSizeInOriginalBasis,
195+
originalBasis, errorMsg);
196+
if (succeeded(maybePredicateOps))
197+
predicateOps = *maybePredicateOps;
198+
199+
return IdBuilderResult{/*errorMsg=*/errorMsg,
200+
/*mappingIdOps=*/scaledIds,
201+
/*predicateOps=*/predicateOps};
155202
};
156203
return res;
157204
}
158205

159206
/// Create a lane id builder that takes the `originalBasis` and decompose
160207
/// it in the basis of `forallMappingSizes`. The linear id builder returns an
161208
/// n-D vector of ids for indexing and 1-D size + id for predicate generation.
162-
static GpuIdBuilderFnType laneIdBuilderFn(int64_t periodicity) {
163-
auto res = [periodicity](RewriterBase &rewriter, Location loc,
164-
ArrayRef<int64_t> forallMappingSizes,
165-
ArrayRef<int64_t> originalBasis) {
209+
static GpuIdBuilderFnType laneIdBuilderFn(int64_t warpSize) {
210+
auto res = [warpSize](RewriterBase &rewriter, Location loc,
211+
ArrayRef<int64_t> forallMappingSizes,
212+
ArrayRef<int64_t> originalBasis) {
213+
// 1. Compute linearId.
166214
SmallVector<OpFoldResult> originalBasisOfr =
167215
getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
168-
OpFoldResult linearId =
216+
Value physicalLinearId =
169217
buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
218+
219+
// 2. Compute laneId.
170220
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
171-
linearId = affine::makeComposedFoldedAffineApply(
172-
rewriter, loc, d0 % periodicity, {linearId});
221+
OpFoldResult laneId = affine::makeComposedFoldedAffineApply(
222+
rewriter, loc, d0 % warpSize, {physicalLinearId});
173223

224+
// 3. Compute remapped indices.
225+
SmallVector<Value> ids;
174226
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
175227
// "row-major" order.
176228
SmallVector<int64_t> reverseBasisSizes(llvm::reverse(forallMappingSizes));
177229
SmallVector<int64_t> strides = computeStrides(reverseBasisSizes);
178230
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
179-
SmallVector<Value> ids;
180231
// Reverse back to be in [0 .. n] order.
181232
for (AffineExpr e : llvm::reverse(delinearizingExprs)) {
182233
ids.push_back(
183-
affine::makeComposedAffineApply(rewriter, loc, e, {linearId}));
234+
affine::makeComposedAffineApply(rewriter, loc, e, {laneId}));
184235
}
185236

186-
// clang-format off
187-
LLVM_DEBUG(llvm::interleaveComma(reverseBasisSizes,
188-
DBGS() << "--delinearization basis: ");
189-
llvm::dbgs() << "\n";
190-
llvm::interleaveComma(strides,
191-
DBGS() << "--delinearization strides: ");
192-
llvm::dbgs() << "\n";
193-
llvm::interleaveComma(delinearizingExprs,
194-
DBGS() << "--delinearization exprs: ");
195-
llvm::dbgs() << "\n";
196-
llvm::interleaveComma(ids, DBGS() << "--ids: ");
197-
llvm::dbgs() << "\n";);
198-
// clang-format on
199-
200-
// Return n-D ids for indexing and 1-D size + id for predicate generation.
201-
return IdBuilderResult{
202-
/*mappingIdOps=*/ids,
203-
/*availableMappingSizes=*/
204-
SmallVector<int64_t>{computeProduct(originalBasis)},
205-
// `forallMappingSizes` iterate in the scaled basis, they need to be
206-
// scaled back into the original basis to provide tight
207-
// activeMappingSizes quantities for predication.
208-
/*activeMappingSizes=*/
209-
SmallVector<int64_t>{computeProduct(forallMappingSizes)},
210-
/*activeIdOps=*/SmallVector<Value>{linearId.get<Value>()}};
237+
// 4. Handle predicates using laneId.
238+
std::string errorMsg;
239+
SmallVector<Value> predicateOps;
240+
FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates(
241+
rewriter, loc, cast<Value>(laneId), computeProduct(forallMappingSizes),
242+
computeProduct(originalBasis), errorMsg);
243+
if (succeeded(maybePredicateOps))
244+
predicateOps = *maybePredicateOps;
245+
246+
return IdBuilderResult{/*errorMsg=*/errorMsg,
247+
/*mappingIdOps=*/ids,
248+
/*predicateOps=*/predicateOps};
211249
};
212250

213251
return res;

0 commit comments

Comments (0)