This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1e5ad91

promoteToRegistersBelow: sort tensor reference groups
Follow the same strategy as for shared memory promotion: first, sort tensors in decreasing order of their total number of references; then, for each tensor, sort its groups by the number of references in each group. Tensor groups with more references are expected to benefit more from promotion, since more global memory accesses may be avoided thanks to explicit caching in faster layers of the memory hierarchy. Note that since there is currently no limit on the number of registers to use, all groups that can be promoted into registers are promoted, so the sorting has no effect on the outcome yet. Such a limit will be introduced next.
1 parent 4aa377d commit 1e5ad91

1 file changed

+5
-5
lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 5 additions & 5 deletions
@@ -656,8 +656,8 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
   auto mapping =
       collectMappingsTo<mapping::ThreadId>(scop).intersect(blockMapping);
   auto schedule = partialSchedule(scop.scheduleRoot(), scope);
-  auto groupMap = TensorReferenceGroup::accessedWithin(
-      schedule.intersect_domain(mapping), scop.body);
+  auto groupLists = sortTensorGroupMap(TensorReferenceGroup::accessedWithin(
+      schedule.intersect_domain(mapping), scop.body));

   auto threadSchedule = mscop.threadMappingSchedule(mscop.schedule());
   auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
@@ -673,10 +673,10 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
   // identical dimensions without affecting the result of the checks.
   partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);

-  for (auto& tensorGroups : groupMap) {
+  for (auto& tensorGroups : groupLists) {
     auto tensorId = tensorGroups.first;
-
-    // TODO: sorting of groups and counting the number of promoted elements
+    sortTensorGroups(tensorGroups.second);
+    // TODO: counting the number of promoted elements

     for (auto& group : tensorGroups.second) {
       auto sizes = group->approximationSizes();
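
The bodies of sortTensorGroupMap and sortTensorGroups are not part of this diff. The two-level ordering described in the commit message can be sketched roughly as follows. Every name below (GroupSketch, GroupList, GroupMap, totalReferences, sortTensorsSketch, sortGroupsSketch) is a hypothetical stand-in, not the repository's API. The sketch also assumes that groups within a tensor are sorted in decreasing order of reference count, which the message implies but does not state outright.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor reference group: only the number of
// references matters for the ordering illustrated here.
struct GroupSketch {
  std::string label;
  std::size_t nReferences;
};

using GroupList = std::vector<GroupSketch>;
// Tensor name -> groups of references to that tensor (assumed layout).
using GroupMap = std::map<std::string, GroupList>;
using TensorGroupsPair = std::pair<std::string, GroupList>;

// Total number of references across all groups of one tensor.
std::size_t totalReferences(const GroupList& groups) {
  std::size_t total = 0;
  for (const auto& group : groups) {
    total += group.nReferences;
  }
  return total;
}

// Step 1: turn the map into a list of (tensor, groups) pairs ordered by
// decreasing total number of references.
std::vector<TensorGroupsPair> sortTensorsSketch(const GroupMap& groupMap) {
  std::vector<TensorGroupsPair> result(groupMap.begin(), groupMap.end());
  auto moreReferences = [](const TensorGroupsPair& a,
                           const TensorGroupsPair& b) {
    return totalReferences(a.second) > totalReferences(b.second);
  };
  std::stable_sort(result.begin(), result.end(), moreReferences);
  // Postcondition: the list is in decreasing order of total references.
  assert(std::is_sorted(result.begin(), result.end(), moreReferences));
  return result;
}

// Step 2: within one tensor, order groups by decreasing reference count
// (the direction is an assumption of this sketch).
void sortGroupsSketch(GroupList& groups) {
  std::stable_sort(
      groups.begin(), groups.end(),
      [](const GroupSketch& a, const GroupSketch& b) {
        return a.nReferences > b.nReferences;
      });
}
```

With this ordering, a loop that later enforces a register budget would visit the most frequently referenced tensors and groups first. Without such a budget the visiting order does not change which groups get promoted, which matches the note in the commit message.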
