@@ -91,13 +91,26 @@ isl::union_set makeFixRemainingZeroFilter(
91
91
bool anyNonCoincidentMember (const detail::ScheduleTreeElemBand* band) {
92
92
return band->nOuterCoincident () < band->nMember ();
93
93
}
94
+
95
+ /*
96
+ * Return a reference to the mapping sizes
97
+ * for the mapping of type "MappingTypeId".
98
+ */
99
+ template <typename MappingTypeId>
100
+ const CudaDim& mappingSize (const MappedScop* mscop);
101
+ template <>
102
+ const CudaDim& mappingSize<mapping::BlockId>(const MappedScop* mscop) {
103
+ return mscop->numBlocks ;
104
+ }
105
+ template <>
106
+ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
107
+ return mscop->numThreads ;
108
+ }
94
109
} // namespace
95
110
96
111
template <typename MappingTypeId>
97
- void MappedScop::mapRemaining (
98
- detail::ScheduleTree* tree,
99
- size_t nMapped,
100
- size_t nToMap) {
112
+ void MappedScop::mapRemaining (detail::ScheduleTree* tree, size_t nMapped) {
113
+ size_t nToMap = mappingSize<MappingTypeId>(this ).view .size ();
101
114
if (nMapped >= nToMap) {
102
115
return ;
103
116
}
@@ -140,7 +153,7 @@ void MappedScop::mapToBlocksAndScaleBand(
140
153
for (size_t i = 0 ; i < nBlocksToMap; ++i) {
141
154
band = map (band, i, mapping::BlockId::makeId (i));
142
155
}
143
- mapRemaining<mapping::BlockId>(band, nBlocksToMap, numBlocks. view . size () );
156
+ mapRemaining<mapping::BlockId>(band, nBlocksToMap);
144
157
bandScale (band, tileSizes);
145
158
}
146
159
@@ -462,7 +475,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
462
475
// because we cannot map parent bands anyway.
463
476
auto nMapped = mapToThreads (st);
464
477
if (nMapped > 0 ) {
465
- mapRemaining<mapping::ThreadId>(st, nMapped, numThreads. view . size () );
478
+ mapRemaining<mapping::ThreadId>(st, nMapped);
466
479
markUnroll (scop_->scheduleRoot (), st, unroll);
467
480
return numThreads.view .size ();
468
481
}
@@ -645,8 +658,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
645
658
auto child = outerBand->child ({0 });
646
659
size_t numMappedInnerThreads =
647
660
mappedScop->mapInnermostBandsToThreads (child);
648
- mappedScop->mapRemaining <mapping::ThreadId>(
649
- child, numMappedInnerThreads, mappedScop->numThreads .view .size ());
661
+ mappedScop->mapRemaining <mapping::ThreadId>(child, numMappedInnerThreads);
650
662
LOG_IF (INFO, FLAGS_debug_tc_mapper)
651
663
<< " After mapping to threads:" << std::endl
652
664
<< *mappedScop->schedule ();
0 commit comments