Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 4aa377d

Browse files
committed
promoteToSharedBelow: extract out sortTensorGroups
This function will be reused in an upcoming commit.
1 parent 0f7c562 commit 4aa377d

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,20 @@ static std::vector<std::pair<isl::id, TensorGroupsInfo>> sortTensorGroupMap(
477477
return groupLists;
478478
}
479479

480+
/* Sorts the given vector of tensor groups in place following the number of
481+
* references in the group in decreasing order. This prioritize groups with
482+
* more references as they are more likely to benefit from promotion.
483+
*/
484+
static void sortTensorGroups(TensorGroupsInfo& tensorGroups) {
485+
std::sort(
486+
tensorGroups.begin(),
487+
tensorGroups.end(),
488+
[](const std::unique_ptr<TensorReferenceGroup>& group1,
489+
const std::unique_ptr<TensorReferenceGroup>& group2) {
490+
return group1->referenceIds().size() > group2->referenceIds().size();
491+
});
492+
}
493+
480494
/*
481495
* Promote to shared memory in "scop" below "node". Use at most
482496
* "remainingMemory" bytes, and update the variable to reflect the amount of
@@ -512,15 +526,7 @@ void promoteToSharedBelow(
512526

513527
for (auto& tensorGroups : groupLists) {
514528
auto tensorId = tensorGroups.first;
515-
// Sort the reference groups to prioritize groups with more references as
516-
// they are more likely to benefit from promotion.
517-
std::sort(
518-
tensorGroups.second.begin(),
519-
tensorGroups.second.end(),
520-
[](const std::unique_ptr<TensorReferenceGroup>& group1,
521-
const std::unique_ptr<TensorReferenceGroup>& group2) {
522-
return group1->referenceIds().size() > group2->referenceIds().size();
523-
});
529+
sortTensorGroups(tensorGroups.second);
524530

525531
for (auto& group : tensorGroups.second) {
526532
auto sizes = group->approximationSizes();

0 commit comments

Comments
 (0)