Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 92b874b

Browse files
Merge pull request #418 from facebookresearch/pr/reduction
insert reduction synchronization outside thread mapping
2 parents 6c8a77e + 62fa601 commit 92b874b

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
372372
bandSplit(scop_->scheduleRoot(), band, nCanMap - nMappedThreads);
373373
auto child = band->child({0});
374374
if (isReduction) {
375-
// Update reductionBandUpdates_ such that splitOutReductionAndInsertSyncs
375+
// Update reductionBandUpdates_ such that
376+
// splitOutReductionTileAndInsertSyncs
376377
// can find the information it needs.
377378
reductionBandUpdates_.emplace(child, reductionBandUpdates_.at(band));
378379
reductionBandUpdates_.erase(band);
@@ -387,12 +388,12 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
387388

388389
CHECK_GT(nMappedThreads, 0u) << "not mapping to threads";
389390

390-
mapThreadsBackward(band);
391-
392391
if (isReduction) {
393-
splitOutReductionAndInsertSyncs(band, nMappedThreads - 1);
392+
band = splitOutReductionTileAndInsertSyncs(band);
394393
}
395394

395+
mapThreadsBackward(band);
396+
396397
return numThreads.view.size();
397398
}
398399

@@ -946,17 +947,32 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
946947
mappedScopForCodegen->numThreads);
947948
}
948949

949-
// Split out reduction member at position "dim" in "band" and
950-
// insert reduction synchronizations outside this split off band.
951-
void MappedScop::splitOutReductionAndInsertSyncs(
952-
detail::ScheduleTree* band,
953-
int dim) {
950+
// Split out a single reduction tile (in the directions other than
951+
// the reduction) and insert reduction synchronizations outside this tile.
952+
// Return a pointer to the split off tile.
953+
detail::ScheduleTree* MappedScop::splitOutReductionTileAndInsertSyncs(
954+
detail::ScheduleTree* band) {
954955
using namespace polyhedral::detail;
956+
size_t n = numThreads.view.size();
957+
958+
// The current band contains only full blocks.
959+
// Split off a band that iterates over these blocks,
960+
// such that only a single block gets mapped to thread identifiers.
961+
// The mapping to thread identifier X is allowed to iterate
962+
// over multiple blocks, so this direction is not tiled.
963+
std::vector<size_t> sizes(n);
964+
for (size_t i = 1; i < n; ++i) {
965+
sizes[n - 1 - i] = numThreads.view[i];
966+
}
967+
sizes[n - 1] = 0;
968+
bandTile(band, sizes, TileOptions::ScaleTileLoops);
955969

956-
auto tree = bandSplitOut(scop_->scheduleRoot(), band, dim);
970+
// Insert synchronization outside the single block.
971+
auto child = band->child({0});
957972
for (auto updateId : reductionBandUpdates_.at(band).ids) {
958-
scop_->insertReductionSync1D(tree, updateId);
973+
scop_->insertReductionSync1D(child, updateId);
959974
}
975+
return child;
960976
}
961977

962978
std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ class MappedScop {
182182
private:
183183
// Insert the optimal combination of synchronizations in the sequence
184184
void insertBestSyncInSeq(detail::ScheduleTree* seq);
185-
// Split out reduction bands and insert reduction synchronizations.
186-
void splitOutReductionsAndInsertSyncs();
187-
// Split out reduction member at position "dim" in "band" and
188-
// insert reduction synchronizations.
189-
void splitOutReductionAndInsertSyncs(detail::ScheduleTree* band, int dim);
185+
// Split out a single reduction tile (in the directions other than
186+
// the reduction) and insert reduction synchronizations.
187+
// Return a pointer to the split off tile.
188+
detail::ScheduleTree* splitOutReductionTileAndInsertSyncs(
189+
detail::ScheduleTree* band);
190190
// Map "band" to thread identifiers using as many blockSizes values as outer
191191
// coincident dimensions (plus reduction dimension, if any),
192192
// insert synchronization in case of a reduction, and

0 commit comments

Comments
 (0)