
Commit 0f7c562

promoteToSharedBelow: extract out sortTensorGroupMap

This function will be reused in an upcoming commit to sort groups before register promotion.

1 parent 648cbe7 · commit 0f7c562

1 file changed, +33 -28 lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 33 additions & 28 deletions
@@ -446,6 +446,37 @@ bool isInThreadMappedScope(
   return false;
 }
 
+static std::vector<std::pair<isl::id, TensorGroupsInfo>> sortTensorGroupMap(
+    TensorGroups&& groupMap) {
+  // Prepare groups for sorting, to have specified order necessary for
+  // reproducibility and tests.
+  using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
+  std::vector<TensorGroupList> groupLists(
+      std::make_move_iterator(groupMap.begin()),
+      std::make_move_iterator(groupMap.end()));
+
+  // Computes the total number of references in all groups.
+  auto refsCount = [](const TensorGroupsInfo& info) {
+    size_t refs = 0;
+    for (auto const& group : info) {
+      refs += group->referenceIds().size();
+    }
+    return refs;
+  };
+
+  // Sort by the total number of references, then by name. Because names are
+  // guaranteed to be unique, the order is total.
+  std::sort(
+      groupLists.begin(),
+      groupLists.end(),
+      [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
+        auto r1 = refsCount(l1.second);
+        auto r2 = refsCount(l2.second);
+        return r1 == r2 ? l1.first.get_name() < l2.first.get_name() : r1 < r2;
+      });
+  return groupLists;
+}
+
 /*
  * Promote to shared memory in "scop" below "node". Use at most
  * "remainingMemory" bytes, and update the variable to reflect the amount of
@@ -474,37 +505,11 @@ void promoteToSharedBelow(
   auto partialSched = partialSchedule(root, node);
   auto mapping = collectMappingsTo<mapping::BlockId>(scop);
 
-  auto groupMap = TensorReferenceGroup::accessedWithin(
-      partialSched.intersect_domain(mapping), scop.body);
+  auto groupLists = sortTensorGroupMap(TensorReferenceGroup::accessedWithin(
+      partialSched.intersect_domain(mapping), scop.body));
   // Pure affine schedule without (mapping) filters.
   auto partialSchedMupa = partialScheduleMupa(root, node);
 
-  // Prepare groups for sorting, to have specified order necessary for
-  // reproducibility and tests.
-  using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
-  std::vector<TensorGroupList> groupLists(
-      std::make_move_iterator(groupMap.begin()),
-      std::make_move_iterator(groupMap.end()));
-
-  // Computes the total number of references in all groups.
-  auto refsCount = [](const TensorGroupsInfo& info) {
-    size_t refs = 0;
-    for (auto const& group : info) {
-      refs += group->referenceIds().size();
-    }
-    return refs;
-  };
-
-  // Sort by the total number of references, then by name. Because names are
-  // guaranteed to be unique, the order is total.
-  std::sort(
-      groupLists.begin(),
-      groupLists.end(),
-      [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
-        auto r1 = refsCount(l1.second);
-        auto r2 = refsCount(l2.second);
-        return r1 == r2 ? l1.first.get_name() < l2.first.get_name() : r1 < r2;
-      });
   for (auto& tensorGroups : groupLists) {
     auto tensorId = tensorGroups.first;
     // Sort the reference groups to prioritize groups with more references as
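
One detail of the new call site: the temporary map returned by TensorReferenceGroup::accessedWithin binds directly to the helper's TensorGroups&& parameter, so the group lists are moved into the sortable vector rather than copied. Below is a minimal sketch of that pattern under the assumption that the container behaves like a standard map with move-only (or expensive-to-copy) mapped values, as the group->... access in the diff suggests; Payload, Map, drain, and makeMap are hypothetical names used only for this illustration.

#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a map whose values are move-only.
using Payload = std::unique_ptr<int>;
using Map = std::map<std::string, Payload>;

// Accepts the map by rvalue reference and moves its entries into a vector,
// mirroring the shape of sortTensorGroupMap(TensorGroups&& groupMap).
std::vector<std::pair<std::string, Payload>> drain(Map&& m) {
  return std::vector<std::pair<std::string, Payload>>(
      std::make_move_iterator(m.begin()), std::make_move_iterator(m.end()));
}

Map makeMap() {
  Map m;
  m.emplace("x", std::make_unique<int>(1));
  m.emplace("y", std::make_unique<int>(2));
  return m;
}

int main() {
  // The temporary returned by makeMap() binds to Map&&, so the payloads are
  // never copied -- the same shape as the new call site in the diff.
  auto entries = drain(makeMap());
  (void)entries;
}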

0 commit comments