Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2e9b79e

Browse files
authored
Merge pull request #489 from facebookresearch/pr/registers
register promotion: promote below any node in the tree
2 parents 49a2965 + f904ccd commit 2e9b79e

12 files changed

+568
-146
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 24 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -461,45 +461,13 @@ bool hasOuterSequentialMember(
461461
return false;
462462
}
463463

464+
// Name of the space of blocks inside the grid
465+
constexpr auto kGrid = "grid";
464466
// Name of the space of threads inside a block
465467
constexpr auto kBlock = "block";
466468
// Name of the space of warps
467469
constexpr auto kWarp = "warp";
468470

469-
/*
470-
* Extract a mapping from the domain elements active at "tree"
471-
* to the thread identifiers, where all branches in "tree"
472-
* are assumed to have been mapped to thread identifiers.
473-
* "nThread" is the number of thread identifiers.
474-
* The result lives in a space of the form block[x, ...].
475-
*/
476-
isl::multi_union_pw_aff extractDomainToThread(
477-
const detail::ScheduleTree* tree,
478-
size_t nThread) {
479-
using namespace polyhedral::detail;
480-
481-
auto space = isl::space(tree->ctx_, 0);
482-
auto empty = isl::union_set::empty(space);
483-
auto id = isl::id(tree->ctx_, kBlock);
484-
space = space.named_set_from_params_id(id, nThread);
485-
auto zero = isl::multi_val::zero(space);
486-
auto domainToThread = isl::multi_union_pw_aff(empty, zero);
487-
488-
for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) {
489-
auto mappingNode = mapping->elemAs<ScheduleTreeElemMappingFilter>();
490-
auto list = isl::union_pw_aff_list(tree->ctx_, nThread);
491-
for (size_t i = 0; i < nThread; ++i) {
492-
auto threadId = mapping::ThreadId::makeId(i);
493-
auto threadMap = mappingNode->mapping.at(threadId);
494-
list = list.add(threadMap);
495-
}
496-
auto nodeToThread = isl::multi_union_pw_aff(space, list);
497-
domainToThread = domainToThread.union_add(nodeToThread);
498-
}
499-
500-
return domainToThread;
501-
}
502-
503471
/*
504472
* Construct a mapping
505473
*
@@ -534,6 +502,26 @@ isl::multi_aff constructThreadToWarp(
534502
}
535503
} // namespace
536504

505+
isl::multi_union_pw_aff MappedScop::threadMappingSchedule(
506+
const detail::ScheduleTree* tree) const {
507+
std::vector<mapping::MappingId> ids;
508+
for (size_t i = 0; i < numThreads.view.size(); ++i) {
509+
ids.emplace_back(mapping::ThreadId::makeId(i));
510+
}
511+
auto tupleId = isl::id(tree->ctx_, kBlock);
512+
return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
513+
}
514+
515+
isl::multi_union_pw_aff MappedScop::blockMappingSchedule(
516+
const detail::ScheduleTree* tree) const {
517+
std::vector<mapping::MappingId> ids;
518+
for (size_t i = 0; i < numBlocks.view.size(); ++i) {
519+
ids.emplace_back(mapping::BlockId::makeId(i));
520+
}
521+
auto tupleId = isl::id(tree->ctx_, kGrid);
522+
return extractDomainToIds(scop_->scheduleRoot(), tree, ids, tupleId);
523+
}
524+
537525
Scop::SyncLevel MappedScop::findBestSync(
538526
detail::ScheduleTree* st1,
539527
detail::ScheduleTree* st2,
@@ -724,7 +712,7 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
724712

725713
auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), seq);
726714

727-
auto domainToThread = extractDomainToThread(seq, numThreads.view.size());
715+
auto domainToThread = threadMappingSchedule(seq);
728716
auto threadToWarp = constructThreadToWarp(seq->ctx_, 32, numThreads);
729717
auto domainToWarp = domainToThread.apply(threadToWarp);
730718

@@ -1080,7 +1068,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
10801068

10811069
// 9. Promote to registers below the loops mapped to threads.
10821070
if (cudaOptions.proto().use_private_memory()) {
1083-
promoteToRegistersBelowThreads(mappedScop->scop(), -1ull);
1071+
promoteToRegistersBelowThreads(*mappedScop, -1ull);
10841072
}
10851073

10861074
LOG_IF(INFO, FLAGS_debug_tc_mapper)

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,24 @@ class MappedScop {
186186
size_t nChildren,
187187
bool hasOuterSequentialMember);
188188

189+
// Extract a mapping from the domain elements active at "tree"
190+
// to the thread identifiers, where all branches in "tree"
191+
// are assumed to have been mapped to thread identifiers.
192+
// The result lives in a space of the form block[x, ...].
193+
//
194+
// Note: this function ignores statements introduced by extension nodes.
195+
isl::multi_union_pw_aff threadMappingSchedule(
196+
const detail::ScheduleTree* tree) const;
197+
198+
// Extract a mapping from the domain elements active at "tree"
199+
// to the block identifiers, where all branches in "tree"
200+
// are assumed to have been mapped to block identifiers.
201+
// The result lives in a space of the form grid[x, ...].
202+
//
203+
// Note: this function ignores statements introduced by extension nodes.
204+
isl::multi_union_pw_aff blockMappingSchedule(
205+
const detail::ScheduleTree* tree) const;
206+
189207
private:
190208
// Insert the optimal combination of synchronizations in the sequence
191209
void insertBestSyncInSeq(detail::ScheduleTree* seq);

0 commit comments

Comments
 (0)