@@ -372,7 +372,8 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
372
372
bandSplit (scop_->scheduleRoot (), band, nCanMap - nMappedThreads);
373
373
auto child = band->child ({0 });
374
374
if (isReduction) {
375
- // Update reductionBandUpdates_ such that splitOutReductionAndInsertSyncs
375
+ // Update reductionBandUpdates_ such that
376
+ // splitOutReductionTileAndInsertSyncs
376
377
// can find the information it needs.
377
378
reductionBandUpdates_.emplace (child, reductionBandUpdates_.at (band));
378
379
reductionBandUpdates_.erase (band);
@@ -387,12 +388,12 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
387
388
388
389
CHECK_GT (nMappedThreads, 0u ) << " not mapping to threads" ;
389
390
390
- mapThreadsBackward (band);
391
-
392
391
if (isReduction) {
393
- splitOutReductionAndInsertSyncs ( band, nMappedThreads - 1 );
392
+ band = splitOutReductionTileAndInsertSyncs (band );
394
393
}
395
394
395
+ mapThreadsBackward (band);
396
+
396
397
return numThreads.view .size ();
397
398
}
398
399
@@ -946,17 +947,32 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
946
947
mappedScopForCodegen->numThreads );
947
948
}
948
949
949
- // Split out reduction member at position "dim" in "band" and
950
- // insert reduction synchronizations outside this split off band .
951
- void MappedScop::splitOutReductionAndInsertSyncs (
952
- detail::ScheduleTree* band,
953
- int dim ) {
950
+ // Split out a single reduction tile ( in the directions other than
951
+ // the reduction) and insert reduction synchronizations outside this tile .
952
+ // Return a pointer to the split off tile.
953
+ detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs (
954
+ detail::ScheduleTree* band ) {
954
955
using namespace polyhedral ::detail;
956
+ size_t n = numThreads.view .size ();
957
+
958
+ // The current band contains only full blocks.
959
+ // Split off a band that iterates over these blocks,
960
+ // such that only a single block gets mapped to thread identifiers.
961
+ // The mapping to thread identifier X is allowed to iterate
962
+ // over multiple blocks, so this direction is not tiled.
963
+ std::vector<size_t > sizes (n);
964
+ for (size_t i = 1 ; i < n; ++i) {
965
+ sizes[n - 1 - i] = numThreads.view [i];
966
+ }
967
+ sizes[n - 1 ] = 0 ;
968
+ bandTile (band, sizes, TileOptions::ScaleTileLoops);
955
969
956
- auto tree = bandSplitOut (scop_->scheduleRoot (), band, dim);
970
+ // Insert synchronization outside the single block.
971
+ auto child = band->child ({0 });
957
972
for (auto updateId : reductionBandUpdates_.at (band).ids ) {
958
- scop_->insertReductionSync1D (tree , updateId);
973
+ scop_->insertReductionSync1D (child , updateId);
959
974
}
975
+ return child;
960
976
}
961
977
962
978
std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy (
0 commit comments