Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit f8b49cc

Browse files
committed
promoteToRegistersBelowThreads: extract out promoteToRegistersBelow
The extracted function will be extended to support register promotion below any node in the tree.
1 parent a3ad5b0 commit f8b49cc

File tree

1 file changed

+57
-45
lines changed

1 file changed

+57
-45
lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -649,60 +649,72 @@ void promoteGreedilyAtDepth(
649649
mapCopiesToThreads(mscop, unrollCopies);
650650
}
651651

652-
// Promote at the positions of the thread specific markers.
653-
void promoteToRegistersBelowThreads(MappedScop& mscop, size_t nRegisters) {
654-
using namespace tc::polyhedral::detail;
652+
namespace {
655653

654+
/*
655+
* Perform promotion to registers below the thread specific marker "marker"
656+
* in the schedule tree of "mscop".
657+
*/
658+
void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* marker) {
656659
auto& scop = mscop.scop();
657660
auto root = scop.scheduleRoot();
658661
auto threadMapping = mscop.threadMappingSchedule(root);
659662

660-
{
661-
auto markers = findThreadSpecificMarkers(root);
662-
663-
for (auto marker : markers) {
664-
auto partialSched = prefixSchedule(root, marker);
665-
// Pure affine schedule without (mapping) filters.
666-
auto partialSchedMupa = partialScheduleMupa(root, marker);
667-
668-
// Because this function is called below the thread mapping marker,
669-
// partialSched has been intersected with both the block and the thread
670-
// mapping filters. Therefore, groups will be computed relative to
671-
// blocks and threads.
672-
auto groupMap = TensorReferenceGroup::accessedWithin(
673-
partialSched, scop.reads, scop.writes);
674-
for (auto& tensorGroups : groupMap) {
675-
auto tensorId = tensorGroups.first;
676-
677-
// TODO: sorting of groups and counting the number of promoted elements
678-
679-
for (auto& group : tensorGroups.second) {
680-
auto sizes = group->approximationSizes();
681-
// No point in promoting a scalar that will go to a register anyway.
682-
if (sizes.size() == 0) {
683-
continue;
684-
}
685-
if (!isPromotableToRegistersBelow(
686-
*group, root, marker, partialSchedMupa, threadMapping)) {
687-
continue;
688-
}
689-
if (!hasReuseWithin(*group, partialSchedMupa)) {
690-
continue;
691-
}
692-
// TODO: if something is already in shared, but reuse it within one
693-
// thread only, there is no point in keeping it in shared _if_ it
694-
// gets promoted into a register.
695-
scop.promoteGroup(
696-
Scop::PromotedDecl::Kind::Register,
697-
tensorId,
698-
std::move(group),
699-
marker,
700-
partialSched);
701-
}
663+
auto partialSched = prefixSchedule(root, marker);
664+
// Pure affine schedule without (mapping) filters.
665+
auto partialSchedMupa = partialScheduleMupa(root, marker);
666+
667+
// Because this function is called below the thread mapping marker,
668+
// partialSched has been intersected with both the block and the thread
669+
// mapping filters. Therefore, groups will be computed relative to
670+
// blocks and threads.
671+
auto groupMap = TensorReferenceGroup::accessedWithin(
672+
partialSched, scop.reads, scop.writes);
673+
for (auto& tensorGroups : groupMap) {
674+
auto tensorId = tensorGroups.first;
675+
676+
// TODO: sorting of groups and counting the number of promoted elements
677+
678+
for (auto& group : tensorGroups.second) {
679+
auto sizes = group->approximationSizes();
680+
// No point in promoting a scalar that will go to a register anyway.
681+
if (sizes.size() == 0) {
682+
continue;
683+
}
684+
if (!isPromotableToRegistersBelow(
685+
*group, root, marker, partialSchedMupa, threadMapping)) {
686+
continue;
702687
}
688+
if (!hasReuseWithin(*group, partialSchedMupa)) {
689+
continue;
690+
}
691+
// TODO: if something is already in shared, but reuse it within one
692+
// thread only, there is no point in keeping it in shared _if_ it
693+
// gets promoted into a register.
694+
scop.promoteGroup(
695+
Scop::PromotedDecl::Kind::Register,
696+
tensorId,
697+
std::move(group),
698+
marker,
699+
partialSched);
703700
}
704701
}
705702
}
706703

704+
} // namespace
705+
706+
// Promote at the positions of the thread specific markers.
707+
void promoteToRegistersBelowThreads(MappedScop& mscop, size_t nRegisters) {
708+
using namespace tc::polyhedral::detail;
709+
710+
auto& scop = mscop.scop();
711+
auto root = scop.scheduleRoot();
712+
auto markers = findThreadSpecificMarkers(root);
713+
714+
for (auto marker : markers) {
715+
promoteToRegistersBelow(mscop, marker);
716+
}
717+
}
718+
707719
} // namespace polyhedral
708720
} // namespace tc

0 commit comments

Comments
 (0)