27
27
#include < algorithm>
28
28
#include < numeric>
29
29
#include < sstream>
30
+ #include < type_traits>
30
31
31
32
namespace tc {
32
33
namespace polyhedral {
@@ -128,6 +129,21 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
128
129
}
129
130
}
130
131
132
+ /*
133
+ * Collect the thread-specific marker nodes found when starting from "root".
134
+ * DFS preorder is used so the order is well-defined and consistent for tests.
135
+ */
136
+ template <typename T>
137
+ std::vector<T> findThreadSpecificMarkers (T root) {
138
+ using namespace tc ::polyhedral::detail;
139
+ static_assert ( // compile-time check: reject any T not convertible to const ScheduleTree*
140
+ std::is_convertible<T, const ScheduleTree*>::value,
141
+ " expecting ScheduleTree" );
142
+
143
+ return ScheduleTree::collectDFSPreorder (
144
+ root, ScheduleTreeType::ThreadSpecificMarker);
145
+ }
146
+
131
147
/*
132
148
* Transform schedule bands into a union_map.
133
149
* Takes all partial schedules at leaves as MUPAs (without accounting for
@@ -555,51 +571,28 @@ void promoteGreedilyAtDepth(
555
571
mapCopiesToThreads (mscop, unrollCopies);
556
572
}
557
573
558
- // Assuming the mapping to threads happens in inverse order, i.e. the innermost
559
- // loop is mapped to thread x, promote below that depth.
560
- void promoteToRegistersBelowThreads (
561
- Scop& scop,
562
- const ThreadIdxXScheduleDepthState& threadIdxXScheduleDepthState,
563
- size_t nRegisters) {
574
+ // Promote at the positions of the thread specific markers.
575
+ void promoteToRegistersBelowThreads (Scop& scop, size_t nRegisters) {
564
576
using namespace tc ::polyhedral::detail;
565
577
566
578
auto root = scop.scheduleRoot ();
567
579
568
580
auto fullSched = fullSchedule (root);
569
- for (const auto & kvp : threadIdxXScheduleDepthState) {
570
- auto depth = kvp.second + 1 ;
571
- auto subdomain = kvp.first ;
572
-
573
- // Collect all bands where a member is located at the given depth.
574
- auto bands = bandsContainingScheduleDepth (root, depth);
575
- // We may have no band members mapped to thread x in case when we
576
- // force-mapped everything to one thread.
577
- if (bands.size () == 0 ) {
578
- continue ;
579
- }
580
-
581
- // Keep only those bands for which this depth was recorded.
582
- std::function<bool (ScheduleTree*)> keepActive =
583
- [root, subdomain](const ScheduleTree* tree) {
584
- isl::union_set active = activeDomainPoints (root, tree);
585
- return !active.intersect (subdomain).is_empty ();
586
- };
587
- bands = functional::Filter (keepActive, bands);
588
-
589
- // Make sure the band ends at thread x depth so we can promote below it.
590
- bands = bandsSplitAfterDepth (bands, root, depth);
581
+ {
582
+ auto markers = findThreadSpecificMarkers (root);
591
583
592
- for (auto band : bands ) {
584
+ for (auto marker : markers ) {
593
585
// Find out how many threads are actually mapped. Active domain points
594
586
// will involve all mapping parameters when we take them below the
595
587
// mapping. Skip mapping parameters obviously mapped to 0, because they
596
588
// do not correspond to band members that should be fixed to obtain
597
589
// per-thread-group access relations.
598
- auto points = activeDomainPoints (root, band );
599
- auto partialSched = partialSchedule (root, band );
590
+ auto points = activeDomainPoints (root, marker );
591
+ auto partialSched = prefixSchedule (root, marker );
600
592
// Pure affine schedule without (mapping) filters.
601
- auto partialSchedMupa = partialScheduleMupa (root, band );
593
+ auto partialSchedMupa = prefixScheduleMupa (root, marker );
602
594
595
+ auto depth = marker->scheduleDepth (root);
603
596
size_t nMappedThreads = 0 ;
604
597
for (unsigned j = 0 ; j < points.dim (isl::dim_type::param); ++j) {
605
598
auto id = points.get_space ().get_dim_id (isl::dim_type::param, j);
@@ -616,7 +609,7 @@ void promoteToRegistersBelowThreads(
616
609
}
617
610
}
618
611
619
- auto groupMap = TensorReferenceGroup::accessedBySubtree (band , scop);
612
+ auto groupMap = TensorReferenceGroup::accessedBySubtree (marker , scop);
620
613
for (auto & tensorGroups : groupMap) {
621
614
auto tensorId = tensorGroups.first ;
622
615
@@ -642,7 +635,7 @@ void promoteToRegistersBelowThreads(
642
635
Scop::PromotedDecl::Kind::Register,
643
636
tensorId,
644
637
std::move (group),
645
- band ,
638
+ marker ,
646
639
partialSched);
647
640
}
648
641
}
0 commit comments