@@ -259,102 +259,6 @@ void emitReductionOpName(const Halide::Expr& e, const CodegenContext& context) {
259
259
}
260
260
}
261
261
262
- namespace {
263
- // Compute the range of parameter values in a given set. Both sides of the
264
- // range are inclusive.
265
- std::pair<isl::val, isl::val> computeParamRange (isl::set domain, int pos) {
266
- // Coerce the set to the shape [N] -> {[i]: only N here }
267
- domain = domain.params ().from_params ();
268
- domain = domain.project_out (isl::dim_type::param, 0 , pos);
269
- domain = domain.project_out (
270
- isl::dim_type::param, 1 , domain.dim (isl::dim_type::param) - 1 );
271
- domain = domain.insert_dims (isl::dim_type::set, 0 , 1 );
272
-
273
- // Connect parameter to a set dimension [N] -> {[i]: i = N and ...}
274
- auto lspace = isl::local_space (domain.get_space ());
275
- auto paramAff = isl::aff (lspace, isl::dim_type::param, 0 );
276
- auto varAff = isl::aff (lspace, isl::dim_type::set, 0 );
277
- domain = domain & (isl::aff_set (paramAff) == varAff);
278
-
279
- // Remove the remaining parameter to move its constraints to the set dimension
280
- domain = domain.project_out (isl::dim_type::param, 0 , 1 );
281
-
282
- // Get min and max.
283
- auto lower = domain.dim_min (0 );
284
- auto upper = domain.dim_max (0 );
285
-
286
- // Compute the range
287
- CHECK (lower.is_cst () && upper.is_cst ())
288
- << " expected constant lower and upper bounds" ;
289
-
290
- // Without parameters at all, we must have a single piece in the bound PA.
291
- auto lowerPA = isl::PA (lower);
292
- auto upperPA = isl::PA (upper);
293
- CHECK (lowerPA.size () == 1 && upperPA.size () == 1 );
294
-
295
- return std::make_pair (
296
- lowerPA[0 ].second .get_constant_val (),
297
- upperPA[0 ].second .get_constant_val ());
298
- }
299
-
300
- // Given the iteratorMaps, whose domain was affected by the mapping filters, in
301
- // the provided context, compute the range of thread mapping parameters. If
302
- // the statement is not mapped to some threads, they will not appear in the
303
- // result.
304
- std::unordered_map<isl::id, long , isl::IslIdIslHash> activeThreadsInBlock (
305
- const CodegenStatementContext& context) {
306
- auto iterMap = context.iteratorMap ();
307
- auto dom =
308
- iterMap.domain ()
309
- .intersect_params (context.mappedScop .scop ().globalParameterContext )
310
- .params ()
311
- .from_params ();
312
-
313
- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
314
- std::vector<isl::id> threadIds{TX, TY, TZ};
315
- std::unordered_map<isl::id, long , isl::IslIdIslHash> activeThreads;
316
-
317
- for (auto id : threadIds) {
318
- int pos = dom.find_dim_by_id (isl::dim_type::param, id);
319
- if (pos < 0 ) {
320
- continue ;
321
- }
322
- auto range = computeParamRange (dom, pos);
323
- CHECK_EQ (range.first .get_den_si (), 1 ) << " fractional parameters?" ;
324
- CHECK_EQ (range.second .get_den_si (), 1 ) << " fractional parameters?" ;
325
- CHECK_EQ (range.first .get_num_si (), 0 )
326
- << " NYI: active threads starting not from 0" ;
327
-
328
- activeThreads[id] =
329
- range.second .get_num_si () - range.first .get_num_si () + 1 ;
330
- }
331
- return activeThreads;
332
- }
333
-
334
- // Given the iteratorMaps, whose domain was affected by the mapping filters, in
335
- // the provided context, compute the range of thread mapping parameters. If
336
- // the statement is not mapped to some threads, they will _still appear_ in the
337
- // result with the range 1.
338
- std::array<long , 3 > activeThreadsInBlockWithDefaults (
339
- const CodegenStatementContext& context) {
340
- auto active = activeThreadsInBlock (context);
341
- std::array<long , 3 > result;
342
-
343
- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
344
- std::vector<isl::id> threadIds{TX, TY, TZ};
345
- int i = 0 ;
346
- for (auto id : threadIds) {
347
- if (active.count (id) != 1 ) {
348
- result[i] = MappingId::unmapped;
349
- } else {
350
- result[i] = active[id];
351
- }
352
- ++i;
353
- }
354
- return result;
355
- }
356
- } // namespace
357
-
358
262
// Emit a cross-thread tree reduce.
359
263
// For now this is only expected to work with threadIdx.x.
360
264
void emitTreeSyncCall (
@@ -373,25 +277,16 @@ void emitTreeSyncCall(
373
277
std::array<size_t , 3 > dims = {TX.mappingSize (context.mappedScop .numThreads ),
374
278
TY.mappingSize (context.mappedScop .numThreads ),
375
279
TZ.mappingSize (context.mappedScop .numThreads )};
376
- std::array<long , 3 > active = activeThreadsInBlockWithDefaults (context);
377
-
378
- for (int i = 0 ; i < 3 ; ++i) {
379
- if (active[i] < dims[i]) {
380
- LOG (INFO) << " Reduction statement " << updateId << " mapped to "
381
- << dims[i] << " threads along dim: " << i << " but only "
382
- << active[i] << " are non-empty" ;
383
- }
384
- }
385
280
386
281
context.ss << tc::code::cuda::kCUBReductionName ;
387
282
388
283
// Template mapping dimension
389
284
context.ss << " <" ;
390
- context.ss << active [0 ];
285
+ context.ss << dims [0 ];
391
286
context.ss << " ," ;
392
- context.ss << active [1 ];
287
+ context.ss << dims [1 ];
393
288
context.ss << " ," ;
394
- context.ss << active [2 ];
289
+ context.ss << dims [2 ];
395
290
context.ss << " ," ;
396
291
emitReductionOpName (provide->values [0 ], context);
397
292
context.ss << " >(" ;
0 commit comments