@@ -47,12 +47,57 @@ using namespace mlir::transform::gpu;
47
47
#define LDBG (X ) LLVM_DEBUG(DBGS() << (X) << " \n " )
48
48
#define DBGS_ALIAS () (llvm::dbgs() << ' [' << DEBUG_TYPE_ALIAS << " ] " )
49
49
50
+ // / Build predicates to filter execution by only the activeIds. Along each
51
+ // / dimension, 3 cases appear:
52
+ // / 1. activeMappingSize > availableMappingSize: this is an unsupported case
53
+ // / as this requires additional looping. An error message is produced to
54
+ // / advise the user to tile more or to use more threads.
55
+ // / 2. activeMappingSize == availableMappingSize: no predication is needed.
56
+ // / 3. activeMappingSize < availableMappingSize: only a subset of threads
57
+ // / should be active and we produce the boolean `id < activeMappingSize`
58
+ // / for further use in building predicated execution.
// /
// / `activeIds`, `activeMappingSizes` and `availableMappingSizes` are walked in
// / lockstep, one entry per dimension (presumably all three must have the same
// / length, as `llvm::zip_equal` is used -- confirm against callers).
// / On case 1, `errorMsg` is overwritten with a diagnostic and `failure()` is
// / returned immediately. Otherwise the returned vector holds one
// / `arith.cmpi ult` result per dimension that falls under case 3, in dimension
// / order; it is empty when no dimension needs predication. `errorMsg` is left
// / untouched on success.
59
+ static FailureOr<SmallVector<Value>>
60
+ buildPredicates (RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds,
61
+ ArrayRef<int64_t > activeMappingSizes,
62
+ ArrayRef<int64_t > availableMappingSizes,
63
+ std::string &errorMsg) {
// Dump both size lists; this only happens inside LLVM_DEBUG.
64
+ // clang-format off
65
+ LLVM_DEBUG (
66
+ llvm::interleaveComma (
67
+ activeMappingSizes, DBGS () << " ----activeMappingSizes: " );
68
+ DBGS () << " \n " ;
69
+ llvm::interleaveComma (
70
+ availableMappingSizes, DBGS () << " ----availableMappingSizes: " );
71
+ DBGS () << " \n " ;);
72
+ // clang-format on
73
+
// Collects one comparison result per dimension that needs guarding.
74
+ SmallVector<Value> predicateOps;
75
+ for (auto [activeId, activeMappingSize, availableMappingSize] :
76
+ llvm::zip_equal (activeIds, activeMappingSizes, availableMappingSizes)) {
// Case 1: more active ids requested than available; report and bail out.
77
+ if (activeMappingSize > availableMappingSize) {
78
+ errorMsg = " Trying to map to fewer GPU threads than loop iterations but "
79
+ " overprovisioning is not yet supported. Try additional tiling "
80
+ " before mapping or map to more threads." ;
81
+ return failure ();
82
+ }
// Case 2: sizes match exactly, so this dimension contributes no predicate.
83
+ if (activeMappingSize == availableMappingSize)
84
+ continue ;
// Case 3: materialize `activeId < activeMappingSize` as an unsigned compare
// against an index constant.
85
+ Value idx = rewriter.create <arith::ConstantIndexOp>(loc, activeMappingSize);
86
+ Value pred = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
87
+ activeId, idx);
88
+ predicateOps.push_back (pred);
89
+ }
90
+ return predicateOps;
91
+ }
92
+
50
93
// / Return a flattened thread id for the workgroup with given sizes.
51
94
template <typename ThreadOrBlockIdOp>
52
95
static Value buildLinearId (RewriterBase &rewriter, Location loc,
53
96
ArrayRef<OpFoldResult> originalBasisOfr) {
54
- LLVM_DEBUG (DBGS () << " ----buildLinearId with originalBasisOfr: "
55
- << llvm::interleaved (originalBasisOfr) << " \n " );
97
+ LLVM_DEBUG (llvm::interleaveComma (
98
+ originalBasisOfr,
99
+ DBGS () << " ----buildLinearId with originalBasisOfr: " );
100
+ llvm::dbgs () << " \n " );
56
101
assert (originalBasisOfr.size () == 3 && " expected 3 sizes" );
57
102
IndexType indexType = rewriter.getIndexType ();
58
103
AffineExpr tx, ty, tz, bdx, bdy;
@@ -79,44 +124,43 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
79
124
auto res = [multiplicity](RewriterBase &rewriter, Location loc,
80
125
ArrayRef<int64_t > forallMappingSizes,
81
126
ArrayRef<int64_t > originalBasis) {
127
+ // 1. Compute linearId.
82
128
SmallVector<OpFoldResult> originalBasisOfr =
83
129
getAsIndexOpFoldResult (rewriter.getContext (), originalBasis);
84
- OpFoldResult linearId =
130
+ Value physicalLinearId =
85
131
buildLinearId<ThreadOrBlockIdOp>(rewriter, loc, originalBasisOfr);
132
+
133
+ // 2. Compute scaledLinearId.
134
+ AffineExpr d0 = getAffineDimExpr (0 , rewriter.getContext ());
135
+ OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply (
136
+ rewriter, loc, d0.floorDiv (multiplicity), {physicalLinearId});
137
+
138
+ // 3. Compute remapped indices.
139
+ SmallVector<Value> ids;
86
140
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
87
141
// "row-major" order.
88
142
SmallVector<int64_t > reverseBasisSizes (llvm::reverse (forallMappingSizes));
89
143
SmallVector<int64_t > strides = computeStrides (reverseBasisSizes);
90
- AffineExpr d0 = getAffineDimExpr (0 , rewriter.getContext ());
91
- OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply (
92
- rewriter, loc, d0.floorDiv (multiplicity), {linearId});
93
144
SmallVector<AffineExpr> delinearizingExprs = delinearize (d0, strides);
94
- SmallVector<Value> ids;
95
145
// Reverse back to be in [0 .. n] order.
96
146
for (AffineExpr e : llvm::reverse (delinearizingExprs)) {
97
147
ids.push_back (
98
148
affine::makeComposedAffineApply (rewriter, loc, e, {scaledLinearId}));
99
149
}
100
150
101
- LLVM_DEBUG (DBGS () << " --delinearization basis: "
102
- << llvm::interleaved (reverseBasisSizes) << " \n " ;
103
- DBGS () << " --delinearization strides: "
104
- << llvm::interleaved (strides) << " \n " ;
105
- DBGS () << " --delinearization exprs: "
106
- << llvm::interleaved (delinearizingExprs) << " \n " ;
107
- DBGS () << " --ids: " << llvm::interleaved (ids) << " \n " );
108
-
109
- // Return n-D ids for indexing and 1-D size + id for predicate generation.
110
- return IdBuilderResult{
111
- /* mappingIdOps=*/ ids,
112
- /* availableMappingSizes=*/
113
- SmallVector<int64_t >{computeProduct (originalBasis)},
114
- // `forallMappingSizes` iterate in the scaled basis, they need to be
115
- // scaled back into the original basis to provide tight
116
- // activeMappingSizes quantities for predication.
117
- /* activeMappingSizes=*/
118
- SmallVector<int64_t >{computeProduct (forallMappingSizes) * multiplicity},
119
- /* activeIdOps=*/ SmallVector<Value>{cast<Value>(linearId)}};
151
+ // 4. Handle predicates using physicalLinearId.
152
+ std::string errorMsg;
153
+ SmallVector<Value> predicateOps;
154
+ FailureOr<SmallVector<Value>> maybePredicateOps =
155
+ buildPredicates (rewriter, loc, physicalLinearId,
156
+ computeProduct (forallMappingSizes) * multiplicity,
157
+ computeProduct (originalBasis), errorMsg);
158
+ if (succeeded (maybePredicateOps))
159
+ predicateOps = *maybePredicateOps;
160
+
161
+ return IdBuilderResult{/* errorMsg=*/ errorMsg,
162
+ /* mappingIdOps=*/ ids,
163
+ /* predicateOps=*/ predicateOps};
120
164
};
121
165
122
166
return res;
@@ -143,71 +187,65 @@ static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
143
187
// In the 3-D mapping case, unscale the first dimension by the multiplicity.
144
188
SmallVector<int64_t > forallMappingSizeInOriginalBasis (forallMappingSizes);
145
189
forallMappingSizeInOriginalBasis[0 ] *= multiplicity;
146
- return IdBuilderResult{
147
- /* mappingIdOps=*/ scaledIds,
148
- /* availableMappingSizes=*/ SmallVector<int64_t >{originalBasis},
149
- // `forallMappingSizes` iterate in the scaled basis, they need to be
150
- // scaled back into the original basis to provide tight
151
- // activeMappingSizes quantities for predication.
152
- /* activeMappingSizes=*/
153
- SmallVector<int64_t >{forallMappingSizeInOriginalBasis},
154
- /* activeIdOps=*/ ids};
190
+
191
+ std::string errorMsg;
192
+ SmallVector<Value> predicateOps;
193
+ FailureOr<SmallVector<Value>> maybePredicateOps =
194
+ buildPredicates (rewriter, loc, ids, forallMappingSizeInOriginalBasis,
195
+ originalBasis, errorMsg);
196
+ if (succeeded (maybePredicateOps))
197
+ predicateOps = *maybePredicateOps;
198
+
199
+ return IdBuilderResult{/* errorMsg=*/ errorMsg,
200
+ /* mappingIdOps=*/ scaledIds,
201
+ /* predicateOps=*/ predicateOps};
155
202
};
156
203
return res;
157
204
}
158
205
159
206
// / Create a lane id builder that takes the `originalBasis` and decomposes
160
207
// / it in the basis of `forallMappingSizes`. The linear id builder returns an
161
208
// / n-D vector of ids for indexing and 1-D size + id for predicate generation.
162
- static GpuIdBuilderFnType laneIdBuilderFn (int64_t periodicity) {
163
- auto res = [periodicity](RewriterBase &rewriter, Location loc,
164
- ArrayRef<int64_t > forallMappingSizes,
165
- ArrayRef<int64_t > originalBasis) {
209
+ static GpuIdBuilderFnType laneIdBuilderFn (int64_t warpSize) {
210
+ auto res = [warpSize](RewriterBase &rewriter, Location loc,
211
+ ArrayRef<int64_t > forallMappingSizes,
212
+ ArrayRef<int64_t > originalBasis) {
213
+ // 1. Compute linearId.
166
214
SmallVector<OpFoldResult> originalBasisOfr =
167
215
getAsIndexOpFoldResult (rewriter.getContext (), originalBasis);
168
- OpFoldResult linearId =
216
+ Value physicalLinearId =
169
217
buildLinearId<ThreadIdOp>(rewriter, loc, originalBasisOfr);
218
+
219
+ // 2. Compute laneId.
170
220
AffineExpr d0 = getAffineDimExpr (0 , rewriter.getContext ());
171
- linearId = affine::makeComposedFoldedAffineApply (
172
- rewriter, loc, d0 % periodicity , {linearId });
221
+ OpFoldResult laneId = affine::makeComposedFoldedAffineApply (
222
+ rewriter, loc, d0 % warpSize , {physicalLinearId });
173
223
224
+ // 3. Compute remapped indices.
225
+ SmallVector<Value> ids;
174
226
// Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
175
227
// "row-major" order.
176
228
SmallVector<int64_t > reverseBasisSizes (llvm::reverse (forallMappingSizes));
177
229
SmallVector<int64_t > strides = computeStrides (reverseBasisSizes);
178
230
SmallVector<AffineExpr> delinearizingExprs = delinearize (d0, strides);
179
- SmallVector<Value> ids;
180
231
// Reverse back to be in [0 .. n] order.
181
232
for (AffineExpr e : llvm::reverse (delinearizingExprs)) {
182
233
ids.push_back (
183
- affine::makeComposedAffineApply (rewriter, loc, e, {linearId }));
234
+ affine::makeComposedAffineApply (rewriter, loc, e, {laneId }));
184
235
}
185
236
186
- // clang-format off
187
- LLVM_DEBUG (llvm::interleaveComma (reverseBasisSizes,
188
- DBGS () << " --delinearization basis: " );
189
- llvm::dbgs () << " \n " ;
190
- llvm::interleaveComma (strides,
191
- DBGS () << " --delinearization strides: " );
192
- llvm::dbgs () << " \n " ;
193
- llvm::interleaveComma (delinearizingExprs,
194
- DBGS () << " --delinearization exprs: " );
195
- llvm::dbgs () << " \n " ;
196
- llvm::interleaveComma (ids, DBGS () << " --ids: " );
197
- llvm::dbgs () << " \n " ;);
198
- // clang-format on
199
-
200
- // Return n-D ids for indexing and 1-D size + id for predicate generation.
201
- return IdBuilderResult{
202
- /* mappingIdOps=*/ ids,
203
- /* availableMappingSizes=*/
204
- SmallVector<int64_t >{computeProduct (originalBasis)},
205
- // `forallMappingSizes` iterate in the scaled basis, they need to be
206
- // scaled back into the original basis to provide tight
207
- // activeMappingSizes quantities for predication.
208
- /* activeMappingSizes=*/
209
- SmallVector<int64_t >{computeProduct (forallMappingSizes)},
210
- /* activeIdOps=*/ SmallVector<Value>{linearId.get <Value>()}};
237
+ // 4. Handle predicates using laneId.
238
+ std::string errorMsg;
239
+ SmallVector<Value> predicateOps;
240
+ FailureOr<SmallVector<Value>> maybePredicateOps = buildPredicates (
241
+ rewriter, loc, cast<Value>(laneId), computeProduct (forallMappingSizes),
242
+ computeProduct (originalBasis), errorMsg);
243
+ if (succeeded (maybePredicateOps))
244
+ predicateOps = *maybePredicateOps;
245
+
246
+ return IdBuilderResult{/* errorMsg=*/ errorMsg,
247
+ /* mappingIdOps=*/ ids,
248
+ /* predicateOps=*/ predicateOps};
211
249
};
212
250
213
251
return res;
0 commit comments