@@ -649,60 +649,72 @@ void promoteGreedilyAtDepth(
   mapCopiesToThreads(mscop, unrollCopies);
 }

-// Promote at the positions of the thread specific markers.
-void promoteToRegistersBelowThreads(MappedScop& mscop, size_t nRegisters) {
-  using namespace tc::polyhedral::detail;
+namespace {

+/*
+ * Perform promotion to registers below the thread specific marker "marker"
+ * in the schedule tree of "mscop".
+ */
+void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* marker) {
   auto& scop = mscop.scop();
   auto root = scop.scheduleRoot();
   auto threadMapping = mscop.threadMappingSchedule(root);

-  {
-    auto markers = findThreadSpecificMarkers(root);
-
-    for (auto marker : markers) {
-      auto partialSched = prefixSchedule(root, marker);
-      // Pure affine schedule without (mapping) filters.
-      auto partialSchedMupa = partialScheduleMupa(root, marker);
-
-      // Because this function is called below the thread mapping marker,
-      // partialSched has been intersected with both the block and the thread
-      // mapping filters. Therefore, groups will be computed relative to
-      // blocks and threads.
-      auto groupMap = TensorReferenceGroup::accessedWithin(
-          partialSched, scop.reads, scop.writes);
-      for (auto& tensorGroups : groupMap) {
-        auto tensorId = tensorGroups.first;
-
-        // TODO: sorting of groups and counting the number of promoted elements
-
-        for (auto& group : tensorGroups.second) {
-          auto sizes = group->approximationSizes();
-          // No point in promoting a scalar that will go to a register anyway.
-          if (sizes.size() == 0) {
-            continue;
-          }
-          if (!isPromotableToRegistersBelow(
-                  *group, root, marker, partialSchedMupa, threadMapping)) {
-            continue;
-          }
-          if (!hasReuseWithin(*group, partialSchedMupa)) {
-            continue;
-          }
-          // TODO: if something is already in shared, but reuse it within one
-          // thread only, there is no point in keeping it in shared _if_ it
-          // gets promoted into a register.
-          scop.promoteGroup(
-              Scop::PromotedDecl::Kind::Register,
-              tensorId,
-              std::move(group),
-              marker,
-              partialSched);
-        }
+  auto partialSched = prefixSchedule(root, marker);
+  // Pure affine schedule without (mapping) filters.
+  auto partialSchedMupa = partialScheduleMupa(root, marker);
+
+  // Because this function is called below the thread mapping marker,
+  // partialSched has been intersected with both the block and the thread
+  // mapping filters. Therefore, groups will be computed relative to
+  // blocks and threads.
+  auto groupMap = TensorReferenceGroup::accessedWithin(
+      partialSched, scop.reads, scop.writes);
+  for (auto& tensorGroups : groupMap) {
+    auto tensorId = tensorGroups.first;
+
+    // TODO: sorting of groups and counting the number of promoted elements
+
+    for (auto& group : tensorGroups.second) {
+      auto sizes = group->approximationSizes();
+      // No point in promoting a scalar that will go to a register anyway.
+      if (sizes.size() == 0) {
+        continue;
+      }
+      if (!isPromotableToRegistersBelow(
+              *group, root, marker, partialSchedMupa, threadMapping)) {
+        continue;
       }
+      if (!hasReuseWithin(*group, partialSchedMupa)) {
+        continue;
+      }
+      // TODO: if something is already in shared, but reuse it within one
+      // thread only, there is no point in keeping it in shared _if_ it
+      // gets promoted into a register.
+      scop.promoteGroup(
+          Scop::PromotedDecl::Kind::Register,
+          tensorId,
+          std::move(group),
+          marker,
+          partialSched);
     }
   }
 }

+} // namespace
+
+// Promote at the positions of the thread specific markers.
+void promoteToRegistersBelowThreads(MappedScop& mscop, size_t nRegisters) {
+  using namespace tc::polyhedral::detail;
+
+  auto& scop = mscop.scop();
+  auto root = scop.scheduleRoot();
+  auto markers = findThreadSpecificMarkers(root);
+
+  for (auto marker : markers) {
+    promoteToRegistersBelow(mscop, marker);
+  }
+}
+
 } // namespace polyhedral
 } // namespace tc
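
The net effect of this hunk is that the per-marker promotion body moves into a file-local helper in an anonymous namespace, and the public entry point becomes a thin loop over the thread-specific markers. Below is a minimal stand-alone sketch of that shape; all names in it (Marker, promoteOneMarker, promoteAll) are invented for illustration and are not part of the TC API.

// Illustrative sketch of the refactoring pattern only; none of these names
// exist in Tensor Comprehensions.
#include <vector>

namespace {

struct Marker {
  int id;
};

// Per-marker work lives in a file-local helper, mirroring the new
// promoteToRegistersBelow above.
void promoteOneMarker(const Marker& marker) {
  // ... promotion logic for a single marker would go here ...
  (void)marker;
}

} // namespace

// The public entry point is reduced to a loop over markers, mirroring the new
// promoteToRegistersBelowThreads above.
void promoteAll(const std::vector<Marker>& markers) {
  for (const auto& marker : markers) {
    promoteOneMarker(marker);
  }
}

int main() {
  std::vector<Marker> markers = {{0}, {1}, {2}};
  promoteAll(markers);
  return 0;
}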