
Commit ea23712

Merge pull request #344 from facebookresearch/pr/marker
replace ThreadIdxXScheduleDepthState by thread specific markers
2 parents: 76ae999 + 594be20

12 files changed: +244 -184 lines
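The gist of the change: instead of keeping a side table (threadIdxXScheduleDepthState) that records, per part of the domain, the schedule depth at which threadIdx.x was mapped, the mapper now inserts a dedicated "thread specific" marker node into the schedule tree right below the point where thread identifiers are mapped. Downstream passes recover the same information by locating the marker. A rough sketch of how a consumer could do that, assuming the ancestor accessor and a ScheduleTreeElemThreadSpecificMarker element type created by makeThreadSpecificMarker (names outside this diff are assumptions):

// Sketch only, not part of this commit: find the schedule depth of the
// thread-specific subtree containing "node" by walking towards the root,
// replacing a lookup in the removed threadIdxXScheduleDepthState table.
size_t threadSpecificDepth(
    detail::ScheduleTree* root,
    detail::ScheduleTree* node) {
  while (node != root) {
    if (node->elemAs<detail::ScheduleTreeElemThreadSpecificMarker>()) {
      return node->scheduleDepth(root);  // all thread ids are mapped above
    }
    node = node->ancestor(root, 1);  // direct parent (accessor assumed)
  }
  return 0;  // "node" is not below any thread-specific marker
}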

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 34 additions & 24 deletions
@@ -124,14 +124,6 @@ void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
   auto filter = makeFixRemainingZeroFilter(domain, ids);
   auto mapping = detail::ScheduleTree::makeMappingFilter(filter, ids);
   insertNodeAbove(root, tree, std::move(mapping));
-
-  for (size_t i = nMapped; i < nToMap; ++i) {
-    if (MappingTypeId::makeId(i) == mapping::ThreadId::x()) {
-      threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-          activeDomainPoints(schedule(), tree),
-          tree->scheduleDepth(schedule())));
-    }
-  }
 }
 
 // Uses as many blockSizes elements as outer coincident dimensions in the
@@ -161,6 +153,7 @@ void MappedScop::mapToBlocksAndScaleBand(
  * Given a node in the schedule tree of a mapped scop,
  * insert a mapping filter underneath (if needed) that fixes
  * the remaining thread identifiers starting at "begin" to zero.
+ * Add a marker underneath that marks the subtree that is thread specific.
  */
 void fixThreadsBelow(
     MappedScop& mscop,
@@ -173,6 +166,9 @@ void fixThreadsBelow(
 
   auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
   auto bandTree = insertNodeBelow(tree, std::move(band));
+  auto ctx = tree->ctx_;
+  insertNodeBelow(
+      bandTree, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
   mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
 }
 
@@ -338,8 +334,29 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
     return 0;
   }
 
-  auto nMappedThreads =
-      std::min(numThreads.view.size(), static_cast<size_t>(nCanMap));
+  auto nMappedThreads = nCanMap;
+  if (nMappedThreads > numThreads.view.size()) {
+    // Split band such that mapping filters get inserted
+    // right above the first member mapped to a thread identifier.
+    nMappedThreads = numThreads.view.size();
+    bandSplit(scop_->scheduleRoot(), band, nCanMap - nMappedThreads);
+    auto child = band->child({0});
+    if (isReduction) {
+      // Update reductionBandUpdates_ such that splitOutReductionAndInsertSyncs
+      // can find the information it needs.
+      reductionBandUpdates_.emplace(child, reductionBandUpdates_.at(band));
+      reductionBandUpdates_.erase(band);
+    }
+    band = child;
+    bandNode = band->elemAs<ScheduleTreeElemBand>();
+  }
+
+  if (nMappedThreads < bandNode->nMember()) {
+    bandSplit(scop_->scheduleRoot(), band, nMappedThreads);
+  }
+
+  auto ctx = band->ctx_;
+  insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
 
   CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
   CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";
@@ -348,20 +365,16 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
   // from thread x.
   for (size_t i = 0; i < nMappedThreads; ++i) {
     auto id = mapping::ThreadId::makeId(i);
-    auto dim = nCanMap - 1 - i;
-    if (id == mapping::ThreadId::x()) {
-      threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-          activeDomainPoints(schedule(), band),
-          band->scheduleDepth(schedule()) + dim));
-    }
+    auto dim = nMappedThreads - 1 - i;
     band = map(band, dim, id);
   }
+  mapRemaining<mapping::ThreadId>(band, nMappedThreads);
 
   if (isReduction) {
-    splitOutReductionAndInsertSyncs(band, nCanMap - 1);
+    splitOutReductionAndInsertSyncs(band, nMappedThreads - 1);
   }
 
-  return nMappedThreads;
+  return numThreads.view.size();
 }
 
 namespace {
@@ -450,9 +463,8 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
     // because we cannot map parent bands anyway.
     auto nMapped = mapToThreads(st);
     if (nMapped > 0) {
-      mapRemaining<mapping::ThreadId>(st, nMapped);
       markUnroll(scop_->scheduleRoot(), st, unroll);
-      return numThreads.view.size();
+      return nMapped;
     }
   } else if (anyNonCoincidentMember(band)) {
     // If children were mapped to threads, and this band has a non-coincident
@@ -633,7 +645,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
   auto child = outerBand->child({0});
   size_t numMappedInnerThreads =
       mappedScop->mapInnermostBandsToThreads(child);
-  mappedScop->mapRemaining<mapping::ThreadId>(child, numMappedInnerThreads);
+  fixThreadsBelow(*mappedScop, outerBand, numMappedInnerThreads);
   LOG_IF(INFO, FLAGS_debug_tc_mapper)
       << "After mapping to threads:" << std::endl
       << *mappedScop->schedule();
@@ -677,7 +689,6 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
 
   promoteGreedilyAtDepth(
       *mappedScop,
-      mappedScop->threadIdxXScheduleDepthState,
      std::min(band->nOuterCoincident(), mappedScop->numBlocks.view.size()),
       sharedMemorySize,
       cudaOptions.proto().unroll_copy_shared() &&
@@ -694,8 +705,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
 
   // 8. Promote to registers below the loops mapped to threads.
   if (cudaOptions.proto().use_private_memory()) {
-    promoteToRegistersBelowThreads(
-        mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
+    promoteToRegistersBelowThreads(mappedScop->scop(), -1ull);
   }
 
   // 9. Insert mapping context
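With the side table gone, promotion has to find the thread-mapped subtrees in the schedule tree itself. A minimal sketch of one way to do that, assuming a ScheduleTree::collect overload that filters by node type and a ThreadSpecificMarker type tag (both assumptions, not shown in this diff):

// Sketch only: gather the root of every thread-specific subtree so that
// register promotion can operate below the loops mapped to threads.
std::vector<detail::ScheduleTree*> threadSpecificSubtrees(
    detail::ScheduleTree* root) {
  return detail::ScheduleTree::collect(
      root, detail::ScheduleTreeType::ThreadSpecificMarker);
}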

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 2 additions & 7 deletions
@@ -162,6 +162,8 @@ class MappedScop {
   // coincident dimensions (plus reduction dimension, if any),
   // insert synchronization in case of a reduction, and
   // return the number of mapped thread identifiers.
+  // A marker is added to mark the part of the tree that is thread specific
+  // (right underneath the innermost band member mapped to a thread identifier).
   size_t mapToThreads(detail::ScheduleTree* band);
   // Map innermost bands to thread identifiers,
   // inserting synchronization in case of a reduction, and
@@ -176,13 +178,6 @@ class MappedScop {
   const ::tc::Block numThreads;
   const uint64_t unroll;
 
-  // The schedule depth that was mapped to Thread::x for specific parts of the
-  // domain.
-  // XXX: this is a partially redundant state as this information can
-  // potentially be extracted from the schedule tree; however, until we get a
-  // first-class MappingNode, it requires some dirty hacks.
-  ThreadIdxXScheduleDepthState threadIdxXScheduleDepthState;
-
  private:
   // Information about a detected reduction that can potentially
   // be mapped to a library call.
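For reference, the uses in the removed code (an emplace_back of a std::make_pair of the active domain points and a schedule depth) suggest the deleted type was essentially a vector of (domain, depth) pairs, along these lines (reconstructed from its uses, not quoted from the tree):

// Presumed shape of the removed side table: one entry per domain part,
// recording the depth at which threadIdx.x was mapped for that part.
using ThreadIdxXScheduleDepthState =
    std::vector<std::pair<isl::union_set, size_t>>;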
