@@ -230,65 +230,50 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
  if (!isReductionMember(member, updates, scop())) {
    return false;
  }
-  auto reductionTree = bandSplitOut(scop_->scheduleRoot(), tree, reductionDim);
  // Order the init statements (if any) before the update statements
  // to ensure the band from which the reduction band has been split off
  // only contains update statements.
  // Note that this relies on the outer members being coincident.
  if (!inits.is_empty()) {
    orderBefore(scop_->scheduleRoot(), tree, inits);
  }
-  reductionFromParent_.emplace(tree, reductionTree);
-  reductionBandUpdates_.emplace(reductionTree, updateIds);
+  reductionBandUpdates_.emplace(tree, Reduction(updateIds, reductionDim));
  return true;
}

bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
-  // It is the parent band of the reduction band that needs to be separated.
-  if (reductionFromParent_.count(st) != 1) {
+  if (reductionBandUpdates_.count(st) != 1) {
    return false;
  }
-  st = reductionFromParent_.at(st);
-  CHECK(reductionBandUpdates_.count(st) == 1);
  // No need to separate if already separated.
  return !reductionBandUpdates_.at(st).separated;
}

isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
    const detail::ScheduleTree* st) {
-  CHECK(reductionFromParent_.count(st) == 1);
-  auto parent = st;
-  st = reductionFromParent_.at(st);
  CHECK(reductionBandUpdates_.count(st) == 1);
-
  auto reductionBand = st->elemAs<detail::ScheduleTreeElemBand>();
  CHECK(reductionBand);
-  // Start with the schedule of the reduction band (in last position).
-  auto reductionSchedule = reductionBand->mupa_;

-  // Total size of returned schedule needs to be equal
-  // to the number of thread identifiers.
-  if (numThreads.view.size() > 1) {
-    CHECK(parent != st);
-  }
-  // Prepend last members of parent band (if any).
-  if (parent != st) {
-    auto parentBand = parent->elemAs<detail::ScheduleTreeElemBand>();
-    CHECK(parentBand);
-    auto parentSchedule = parentBand->mupa_;
-    auto nMember = parentBand->nMember();
-    CHECK_GE(nMember, numThreads.view.size() - 1);
-    parentSchedule = parentSchedule.drop_dims(
-        isl::dim_type::set, 0, nMember - (numThreads.view.size() - 1));
-    reductionSchedule = parentSchedule.flat_range_product(reductionSchedule);
-  }
+  // Drop band members following the reduction dimension and preceding those
+  // mapped to threads.
+  auto reductionSchedule = reductionBand->mupa_;
+  auto nMember = reductionBand->nMember();
+  auto reductionDim = reductionBandUpdates_.at(st).reductionDim;
+  auto nMappedThreads =
+      std::min(numThreads.view.size(), reductionBand->nOuterCoincident() + 1);
+  CHECK_GE(nMember, reductionDim);
+  CHECK_GE(reductionDim + 1, nMappedThreads);
+  reductionSchedule = reductionSchedule.drop_dims(
+      isl::dim_type::set, reductionDim + 1, nMember - (reductionDim + 1));
+  reductionSchedule = reductionSchedule.drop_dims(
+      isl::dim_type::set, 0, reductionDim - nMappedThreads + 1);

  return reductionSchedule;
}

detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
-  CHECK(reductionFromParent_.count(st) == 1);
-  auto reduction = reductionFromParent_.at(st);
+  auto reduction = st;
  // This function either separates full blocks (if needed) or
  // disables the reduction handling.
  reductionBandUpdates_.at(reduction).separated = true;
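
Note on the new bookkeeping: the hunk above replaces the pair of maps routed through the parent band (reductionFromParent_ plus reductionBandUpdates_) with a single reductionBandUpdates_ map keyed on the update band itself, whose values are Reduction records. The record's definition lives in the header and is not part of this diff; a minimal sketch consistent with what the diff uses, namely the Reduction(updateIds, reductionDim) constructor and the ids, reductionDim, and separated members, with field types assumed:

#include <cstddef>
#include <vector>
// isl::id is the isl C++ identifier type already used throughout this file.

struct Reduction {
  Reduction(std::vector<isl::id> ids, size_t reductionDim)
      : ids(ids), reductionDim(reductionDim), separated(false) {}
  // Identifiers of the reduction update statements in this band.
  std::vector<isl::id> ids;
  // Position of the reduction member within the band.
  size_t reductionDim;
  // Whether the reduction band has already been separated into full blocks.
  bool separated;
};

With the reduction dimension recorded per band, reductionMapSchedule can extract the thread-mapped members directly from the reduction band. For instance, with nMember = 5, reductionDim = 3 and three thread dimensions, nMappedThreads = 3, so the first drop_dims removes member 4 and the second removes member 0, keeping members 1 through 3: exactly one schedule dimension per mapped thread, with the reduction member in last position.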
@@ -336,58 +321,53 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
  return st->ancestor(root, 2);
}

-size_t MappedScop::mapToThreads(detail::ScheduleTree* band, size_t nInner) {
+size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
  using namespace tc::polyhedral::detail;

-  if (nInner >= numThreads.view.size()) {
-    return nInner;
+  auto bandNode = band->elemAs<ScheduleTreeElemBand>();
+  // Cannot map non-permutable bands.
+  if (!bandNode->permutable_) {
+    return 0;
  }
+
+  int nMappedReductionThreads = 0;
  if (reductionBandUpdates_.count(band) == 1) {
    // A reduction is assumed to get mapped to threadIdx.x
-    if (nInner != 0) {
-      reductionBandUpdates_.erase(band);
-      return nInner;
-    }
    CHECK(reductionBandUpdates_.at(band).separated);
    threadIdxXScheduleDepthState.emplace_back(std::make_pair(
        activeDomainPoints(schedule(), band),
        band->scheduleDepth(schedule()) + 0));
-    band = map(band, 0, mapping::ThreadId::x());
-    markUnroll(scop_->scheduleRoot(), band, unroll);
-    return 1;
-  }
-  auto bandNode = band->elemAs<ScheduleTreeElemBand>();
-  // If any inner node was mapped to threads and
-  // the current node has a non-coincident member,
-  // then synchronization needs to be introduced.
-  // This also implies that the mapping needs to be completed first.
-  if (anyNonCoincidentMember(bandNode) && nInner > 0) {
-    // Since some thread identifiers were mapped already (nInner > 0),
-    // the band should have descendants. Double check.
-    CHECK_EQ(band->numChildren(), 1);
-    mapRemaining<mapping::ThreadId>(
-        band->child({0}), nInner, numThreads.view.size());
-    scop_->insertSyncAfter(band->child({0}));
-    return numThreads.view.size();
+    auto reductionDim = reductionBandUpdates_.at(band).reductionDim;
+    band = map(band, reductionDim, mapping::ThreadId::x());
+    nMappedReductionThreads = 1;
  }
+
  // With current isl scheduler, if coincident dimensions exist in a band,
  // they are outermost.
  // If a band has more than 3 coincident dimensions,
  // then the innermost of those will be used.
  auto nOuterCoincident = bandNode->nOuterCoincident();
-  if (!bandNode->permutable_ || nOuterCoincident < 1) {
-    return nInner;
+  if (nOuterCoincident < 1) {
+    return nMappedReductionThreads;
  }

  auto nMappedThreads = std::min(
-      numThreads.view.size() - nInner, static_cast<size_t>(nOuterCoincident));
-  CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
-  CHECK_LE(nMappedThreads, 3 - nInner) << "mapping to too many threads";
+      numThreads.view.size() - nMappedReductionThreads,
+      static_cast<size_t>(nOuterCoincident));
+
+  // Immediately return if mapping to one thread dimension only was requested
+  // and a reduction was already mapped. (Note that reduction is detected only
+  // if there are not enough outer coincident members, 0 in this case).
+  if (nMappedThreads == 0) {
+    return nMappedReductionThreads;
+  }
+  CHECK_LE(nMappedThreads, 3 - nMappedReductionThreads)
+      << "mapping to too many threads";

  // Map the coincident dimensions to threads starting from the innermost and
-  // from thread x.
+  // from thread x unless it was already mapped to a reduction.
  for (int i = 0; i < nMappedThreads; ++i) {
-    auto id = mapping::ThreadId::makeId(nInner + i);
+    auto id = mapping::ThreadId::makeId(nMappedReductionThreads + i);
    auto dim = nOuterCoincident - 1 - i;
    if (id == mapping::ThreadId::x()) {
      threadIdxXScheduleDepthState.emplace_back(std::make_pair(
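
To make the new mapping order concrete, here is a small self-contained sketch of just the index arithmetic in the rewritten mapToThreads (the real code operates on schedule trees; the sizes below are an assumed example, not taken from this diff):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Assumed example: a permutable band whose members 0..2 are coincident,
  // with a detected reduction at member 3, and a blockDim of rank 3 (x, y, z).
  std::size_t numThreadDims = 3;
  std::size_t nOuterCoincident = 3;
  std::size_t nMappedReductionThreads = 1; // reduction member -> threadIdx.x
  std::printf("band member 3 -> thread dim 0 (threadIdx.x, reduction)\n");

  std::size_t nMappedThreads =
      std::min(numThreadDims - nMappedReductionThreads, nOuterCoincident);
  // Coincident members are mapped innermost-first, starting from the first
  // thread dimension still free (threadIdx.y here, since x holds the reduction).
  for (std::size_t i = 0; i < nMappedThreads; ++i) {
    std::size_t dim = nOuterCoincident - 1 - i;
    std::printf(
        "band member %zu -> thread dim %zu\n", dim, nMappedReductionThreads + i);
  }
  return 0;
}

Running this prints member 3 on thread dim 0, member 2 on thread dim 1 and member 1 on thread dim 2, matching the CHECK_LE bound of 3 - nMappedReductionThreads coincident mappings.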
@@ -397,11 +377,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band, size_t nInner) {
    band = map(band, dim, id);
  }

-  if (nInner == 0) {
-    markUnroll(scop_->scheduleRoot(), band, unroll);
-  }
-
-  return nInner + nMappedThreads;
+  return nMappedReductionThreads + nMappedThreads;
}

namespace {
@@ -431,8 +407,12 @@ bool hasOuterSequentialMember(
}
} // namespace

-// Maps bands to threads in DFS postorder, keeping track of
-// the (maximal) number of threads already mapped by descendants.
+// Maps bands to threads in DFS postorder.
+// Mapping is only allowed if descendants are not already mapped to threads.
+// Mapping nested bands to threads is invalid because members of those bands
+// are not necessarily permutable, and there is no guaranteed nesting between
+// thread dimensions (e.g., there is no guarantee that all threads with
+// threadIdx.y=0 will be executed before any thread with threadIdx.y=1).
//
// If any separation is needed for mapping reductions to full blocks,
// then do so first.
@@ -484,8 +464,24 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
    }
  }

-  if (st->elemAs<detail::ScheduleTreeElemBand>()) {
-    n = mapToThreads(st, n);
+  if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
+    if (n == 0) {
+      // If children were not mapped to threads, the current band can be mapped.
+      auto nMapped = mapToThreads(st);
+      markUnroll(scop_->scheduleRoot(), st, unroll);
+      return nMapped;
+    } else if (anyNonCoincidentMember(band)) {
+      // If children were mapped to threads, and this band has a non-coincident
+      // member, insert a synchronization after its last child.
+      // This also implies the mapping must be completed first.
+      // The node must have children if some of them were mapped to threads,
+      // double-check. Note that a band node has at most one child.
+      CHECK_EQ(st->numChildren(), 1);
+      mapRemaining<mapping::ThreadId>(
+          st->child({0}), n, numThreads.view.size());
+      scop_->insertSyncAfter(st->child({0}));
+      return numThreads.view.size();
+    }
  }

  return n;
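
The restructured traversal draws a hard line between the two cases. A simplified standalone sketch of the recursion's control flow, with stand-in types (the real function walks detail::ScheduleTree, takes its sizes from numThreads, and also performs the reduction separation mentioned above):

#include <algorithm>
#include <cstddef>
#include <vector>

// Stand-in node type for the sketch; not TC's real schedule-tree API.
struct Node {
  bool isBand = false;
  bool hasNonCoincidentMember = false;
  std::vector<Node*> children;
};

constexpr std::size_t kNumThreadDims = 3;         // assumed blockDim rank
std::size_t mapBandToThreads(Node*) { return 1; } // placeholder for mapToThreads
void mapRemainingToThreads(Node*, std::size_t) {} // placeholder for mapRemaining
void insertSyncAfter(Node*) {}                    // placeholder for insertSyncAfter

// Mirrors the post-change decision: returns the number of thread
// dimensions mapped within the subtree rooted at st.
std::size_t mapInnermostBands(Node* st) {
  std::size_t n = 0;
  for (Node* child : st->children) {
    n = std::max(n, mapInnermostBands(child));
  }
  if (st->isBand) {
    if (n == 0) {
      // No descendant is mapped: this band is innermost, so map and unroll it.
      return mapBandToThreads(st);
    } else if (st->hasNonCoincidentMember) {
      // Descendants are mapped: complete their mapping, then synchronize.
      mapRemainingToThreads(st->children.at(0), n);
      insertSyncAfter(st->children.at(0));
      return kNumThreadDims;
    }
  }
  return n;
}

int main() {
  Node inner{true, false, {}};
  Node outer{true, true, {&inner}};
  // Postorder: the inner band is mapped first; the outer band sees n == 1 and
  // a non-coincident member, so it synchronizes instead of mapping itself.
  return static_cast<int>(mapInnermostBands(&outer));
}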
@@ -592,6 +588,22 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
      mappedScopForCodegen->numThreads);
}

+// Split out reduction loops into separate bands and insert reduction
+// synchronizations outside those bands.
+void MappedScop::splitOutReductionsAndInsertSyncs() {
+  using namespace polyhedral::detail;
+
+  for (auto bandUpdate : reductionBandUpdates_) {
+    auto tree = bandSplitOut(
+        scop_->scheduleRoot(),
+        const_cast<ScheduleTree*>(bandUpdate.first),
+        bandUpdate.second.reductionDim);
+    for (auto updateId : bandUpdate.second.ids) {
+      scop_->insertReductionSync1D(tree, updateId);
+    }
+  }
+}
+
std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
    std::unique_ptr<Scop>&& scopUPtr,
    const CudaMappingOptions& cudaOptions) {
@@ -655,7 +667,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
  LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl
                                      << *mappedScop->schedule();

-  // 7. Promote to shared memory below the loops mapped to blocks.
+  // 7. Insert reduction synchronizations if necessary.
+  mappedScop->splitOutReductionsAndInsertSyncs();
+  LOG_IF(INFO, FLAGS_debug_tc_mapper)
+      << "After inserting reduction synchronization:" << std::endl
+      << *mappedScop->schedule();
+
+  // 8. Promote to shared memory below the loops mapped to blocks.
  // This may split the outer band, so find the new outer band after promotion.
  if (cudaOptions.proto().use_shared_memory()) {
    size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory()
@@ -702,24 +720,16 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
    }
  }

-  // 8. Promote to registers below the loops mapped to threads.
+  // 9. Promote to registers below the loops mapped to threads.
  if (cudaOptions.proto().use_private_memory()) {
    promoteToRegistersBelowThreads(
        mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
  }

-  // 9. Insert mapping context
+  // 10. Insert mapping context
  mappedScop->insertMappingContext();
-
-  // 10. Optionally insert reduction synchronizations
-  for (auto bandUpdate : mappedScop->reductionBandUpdates_) {
-    for (auto updateId : bandUpdate.second.ids) {
-      scop->insertReductionSync1D(
-          const_cast<ScheduleTree*>(bandUpdate.first), updateId);
-    }
-  }
  LOG_IF(INFO, FLAGS_debug_tc_mapper)
-      << "After inserting reduction synchronization:" << std::endl
+      << "After outerBlockInnerThread strategy:" << std::endl
      << *mappedScop->schedule();

  return mappedScop;