@@ -446,6 +446,37 @@ bool isInThreadMappedScope(
446
446
return false ;
447
447
}
448
448
449
+ static std::vector<std::pair<isl::id, TensorGroupsInfo>> sortTensorGroupMap (
450
+ TensorGroups&& groupMap) {
451
+ // Prepare groups for sorting, to have specified order necessary for
452
+ // reproducibility and tests.
453
+ using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
454
+ std::vector<TensorGroupList> groupLists (
455
+ std::make_move_iterator (groupMap.begin ()),
456
+ std::make_move_iterator (groupMap.end ()));
457
+
458
+ // Computes the total number of references in all groups.
459
+ auto refsCount = [](const TensorGroupsInfo& info) {
460
+ size_t refs = 0 ;
461
+ for (auto const & group : info) {
462
+ refs += group->referenceIds ().size ();
463
+ }
464
+ return refs;
465
+ };
466
+
467
+ // Sort by the total number of references, then by name. Because names are
468
+ // guarenteed to be unique, the order is total.
469
+ std::sort (
470
+ groupLists.begin (),
471
+ groupLists.end (),
472
+ [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
473
+ auto r1 = refsCount (l1.second );
474
+ auto r2 = refsCount (l2.second );
475
+ return r1 == r2 ? l1.first .get_name () < l2.first .get_name () : r1 < r2;
476
+ });
477
+ return groupLists;
478
+ }
479
+
449
480
/*
450
481
* Promote to shared memory in "scop" below "node". Use at most
451
482
* "remainingMemory" bytes, and update the variable to reflect the amount of
@@ -474,37 +505,11 @@ void promoteToSharedBelow(
474
505
auto partialSched = partialSchedule (root, node);
475
506
auto mapping = collectMappingsTo<mapping::BlockId>(scop);
476
507
477
- auto groupMap = TensorReferenceGroup::accessedWithin (
478
- partialSched.intersect_domain (mapping), scop.body );
508
+ auto groupLists = sortTensorGroupMap ( TensorReferenceGroup::accessedWithin (
509
+ partialSched.intersect_domain (mapping), scop.body )) ;
479
510
// Pure affine schedule without (mapping) filters.
480
511
auto partialSchedMupa = partialScheduleMupa (root, node);
481
512
482
- // Prepare groups for sorting, to have specified order necessary for
483
- // reproducibility and tests.
484
- using TensorGroupList = std::pair<isl::id, TensorGroupsInfo>;
485
- std::vector<TensorGroupList> groupLists (
486
- std::make_move_iterator (groupMap.begin ()),
487
- std::make_move_iterator (groupMap.end ()));
488
-
489
- // Computes the total number of references in all groups.
490
- auto refsCount = [](const TensorGroupsInfo& info) {
491
- size_t refs = 0 ;
492
- for (auto const & group : info) {
493
- refs += group->referenceIds ().size ();
494
- }
495
- return refs;
496
- };
497
-
498
- // Sort by the total number of references, then by name. Because names are
499
- // guarenteed to be unique, the order is total.
500
- std::sort (
501
- groupLists.begin (),
502
- groupLists.end (),
503
- [refsCount](const TensorGroupList& l1, const TensorGroupList& l2) {
504
- auto r1 = refsCount (l1.second );
505
- auto r2 = refsCount (l2.second );
506
- return r1 == r2 ? l1.first .get_name () < l2.first .get_name () : r1 < r2;
507
- });
508
513
for (auto & tensorGroups : groupLists) {
509
514
auto tensorId = tensorGroups.first ;
510
515
// Sort the reference groups to prioritize groups with more references as
0 commit comments