@@ -338,11 +338,12 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
338
338
// then the innermost of those will be used.
339
339
auto nCanMap = bandNode->nOuterCoincident ();
340
340
341
+ auto isReduction = reductionBandUpdates_.count (band) == 1 ;
341
342
// If the band has a detected reduction, then the first member
342
343
// after the coincident members is the reduction member and
343
344
// this member has to be mapped as well.
344
345
// In particular, it will get mapped to threadIdx.x
345
- if (reductionBandUpdates_. count (band) == 1 ) {
346
+ if (isReduction ) {
346
347
CHECK (reductionBandUpdates_.at (band).separated );
347
348
nCanMap++;
348
349
}
@@ -370,6 +371,10 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
370
371
band = map (band, dim, id);
371
372
}
372
373
374
+ if (isReduction) {
375
+ splitOutReductionAndInsertSyncs (band, nCanMap - 1 );
376
+ }
377
+
373
378
return nMappedThreads;
374
379
}
375
380
@@ -585,19 +590,16 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
585
590
mappedScopForCodegen->numThreads );
586
591
}
587
592
588
- // Split out reduction loops into separate bands and insert reduction
589
- // synchronizations outside those bands.
590
- void MappedScop::splitOutReductionsAndInsertSyncs () {
593
+ // Split out reduction member at position "dim" in "band" and
594
+ // insert reduction synchronizations outside this split off band.
595
+ void MappedScop::splitOutReductionAndInsertSyncs (
596
+ detail::ScheduleTree* band,
597
+ int dim) {
591
598
using namespace polyhedral ::detail;
592
599
593
- for (auto bandUpdate : reductionBandUpdates_) {
594
- auto tree = bandSplitOut (
595
- scop_->scheduleRoot (),
596
- const_cast <ScheduleTree*>(bandUpdate.first ),
597
- bandUpdate.second .reductionDim );
598
- for (auto updateId : bandUpdate.second .ids ) {
599
- scop_->insertReductionSync1D (tree, updateId);
600
- }
600
+ auto tree = bandSplitOut (scop_->scheduleRoot (), band, dim);
601
+ for (auto updateId : reductionBandUpdates_.at (band).ids ) {
602
+ scop_->insertReductionSync1D (tree, updateId);
601
603
}
602
604
}
603
605
@@ -664,13 +666,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
664
666
LOG_IF (INFO, FLAGS_debug_tc_mapper) << " After mapping to blocks:" << std::endl
665
667
<< *mappedScop->schedule ();
666
668
667
- // 7. Insert reduction synchronizations if necessary.
668
- mappedScop->splitOutReductionsAndInsertSyncs ();
669
- LOG_IF (INFO, FLAGS_debug_tc_mapper)
670
- << " After inserting reduction synchronization:" << std::endl
671
- << *mappedScop->schedule ();
672
-
673
- // 8. Promote to shared memory below the loops mapped to blocks.
669
+ // 7. Promote to shared memory below the loops mapped to blocks.
674
670
// This may split the outer band, so find the new outer band after promotion.
675
671
if (cudaOptions.proto ().use_shared_memory ()) {
676
672
size_t sharedMemorySize = cudaOptions.proto ().has_max_shared_memory ()
@@ -717,13 +713,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
717
713
}
718
714
}
719
715
720
- // 9 . Promote to registers below the loops mapped to threads.
716
+ // 8 . Promote to registers below the loops mapped to threads.
721
717
if (cudaOptions.proto ().use_private_memory ()) {
722
718
promoteToRegistersBelowThreads (
723
719
mappedScop->scop (), mappedScop->threadIdxXScheduleDepthState , -1ull );
724
720
}
725
721
726
- // 10 . Insert mapping context
722
+ // 9 . Insert mapping context
727
723
mappedScop->insertMappingContext ();
728
724
LOG_IF (INFO, FLAGS_debug_tc_mapper)
729
725
<< " After outerBlockInnerThread strategy:" << std::endl
0 commit comments