@@ -124,14 +124,6 @@ void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
   auto filter = makeFixRemainingZeroFilter(domain, ids);
   auto mapping = detail::ScheduleTree::makeMappingFilter(filter, ids);
   insertNodeAbove(root, tree, std::move(mapping));
-
-  for (size_t i = nMapped; i < nToMap; ++i) {
-    if (MappingTypeId::makeId(i) == mapping::ThreadId::x()) {
-      threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-          activeDomainPoints(schedule(), tree),
-          tree->scheduleDepth(schedule())));
-    }
-  }
 }
 
 // Uses as many blockSizes elements as outer coincident dimensions in the
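For context on what `mapRemaining` is left doing after this hunk: it only builds a filter that fixes every thread identifier from `nMapped` onwards to zero and inserts it above the given node; the `threadIdxXScheduleDepthState` bookkeeping is gone. A minimal standalone sketch of that "fix the remaining identifiers to zero" step, using plain strings instead of the real `mapping::ThreadId` type:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Simplified illustration only, not the real MappedScop::mapRemaining: given
// how many thread identifiers were actually mapped, list the ones that the
// inserted mapping filter pins to the constant 0.
int main() {
  const std::vector<const char*> threadIds = {
      "threadIdx.x", "threadIdx.y", "threadIdx.z"};
  const std::size_t nMapped = 1; // e.g. only one member mapped, to threadIdx.x

  for (std::size_t i = nMapped; i < threadIds.size(); ++i) {
    std::cout << threadIds[i] << " fixed to 0\n";
  }
  // Prints:
  //   threadIdx.y fixed to 0
  //   threadIdx.z fixed to 0
}
```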
@@ -161,6 +153,7 @@ void MappedScop::mapToBlocksAndScaleBand(
  * Given a node in the schedule tree of a mapped scop,
  * insert a mapping filter underneath (if needed) that fixes
  * the remaining thread identifiers starting at "begin" to zero.
+ * Add a marker underneath that marks the subtree that is thread specific.
  */
 void fixThreadsBelow(
     MappedScop& mscop,
@@ -173,6 +166,9 @@ void fixThreadsBelow(
 
   auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
   auto bandTree = insertNodeBelow(tree, std::move(band));
+  auto ctx = tree->ctx_;
+  insertNodeBelow(
+      bandTree, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
   mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
 }
 
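To make the new marker easier to picture, here is a toy model (not the real `detail::ScheduleTree` API) of the nesting that the patched `fixThreadsBelow` appears to produce: `mapRemaining` inserts its mapping filter above the band it is given (see the first hunk), the empty band sits below the original node, and the thread-specific marker sits below the band. The node labels are descriptive strings only.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a schedule tree node; only used to print the nesting order.
struct ToyNode {
  std::string label;
  std::vector<std::unique_ptr<ToyNode>> children;
};

std::unique_ptr<ToyNode> makeNode(std::string label) {
  auto n = std::make_unique<ToyNode>();
  n->label = std::move(label);
  return n;
}

void print(const ToyNode& n, int depth = 0) {
  std::cout << std::string(2 * depth, ' ') << n.label << "\n";
  for (const auto& c : n.children) {
    print(*c, depth + 1);
  }
}

int main() {
  // Nesting after fixThreadsBelow(mscop, tree, begin), as suggested by the
  // diff: the mapping filter added by mapRemaining ends up between "tree" and
  // the empty band, and the thread-specific marker is the band's child.
  auto marker = makeNode("thread_specific marker");
  auto band = makeNode("band (empty)");
  band->children.push_back(std::move(marker));
  auto filter =
      makeNode("mapping filter: threadIdx ids from 'begin' on fixed to 0");
  filter->children.push_back(std::move(band));
  auto tree = makeNode("tree (node passed to fixThreadsBelow)");
  tree->children.push_back(std::move(filter));
  print(*tree);
}
```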
@@ -338,8 +334,29 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
     return 0;
   }
 
-  auto nMappedThreads =
-      std::min(numThreads.view.size(), static_cast<size_t>(nCanMap));
+  auto nMappedThreads = nCanMap;
+  if (nMappedThreads > numThreads.view.size()) {
+    // Split band such that mapping filters get inserted
+    // right above the first member mapped to a thread identifier.
+    nMappedThreads = numThreads.view.size();
+    bandSplit(scop_->scheduleRoot(), band, nCanMap - nMappedThreads);
+    auto child = band->child({0});
+    if (isReduction) {
+      // Update reductionBandUpdates_ such that splitOutReductionAndInsertSyncs
+      // can find the information it needs.
+      reductionBandUpdates_.emplace(child, reductionBandUpdates_.at(band));
+      reductionBandUpdates_.erase(band);
+    }
+    band = child;
+    bandNode = band->elemAs<ScheduleTreeElemBand>();
+  }
+
+  if (nMappedThreads < bandNode->nMember()) {
+    bandSplit(scop_->scheduleRoot(), band, nMappedThreads);
+  }
+
+  auto ctx = band->ctx_;
+  insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
 
   CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
   CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";
@@ -348,20 +365,16 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
   // from thread x.
   for (size_t i = 0; i < nMappedThreads; ++i) {
     auto id = mapping::ThreadId::makeId(i);
-    auto dim = nCanMap - 1 - i;
-    if (id == mapping::ThreadId::x()) {
-      threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-          activeDomainPoints(schedule(), band),
-          band->scheduleDepth(schedule()) + dim));
-    }
+    auto dim = nMappedThreads - 1 - i;
     band = map(band, dim, id);
   }
+  mapRemaining<mapping::ThreadId>(band, nMappedThreads);
 
   if (isReduction) {
-    splitOutReductionAndInsertSyncs(band, nCanMap - 1);
+    splitOutReductionAndInsertSyncs(band, nMappedThreads - 1);
   }
 
-  return nMappedThreads;
+  return numThreads.view.size();
 }
 
 namespace {
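The splitting logic in the reworked `mapToThreads` is easiest to follow as plain arithmetic: start from the number of coincident members that can be mapped (`nCanMap`), clamp it to the number of thread dimensions in the launch configuration (at most 3), split off the extra outer members so the mapping filter lands right above the first mapped member, and split again if unmapped members remain below. The sketch below models only that decision with plain integers instead of the real band and `numThreads` objects; names like `planThreadMapping` are made up for illustration.

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical helper mirroring the decisions the patched mapToThreads makes;
// it is not part of the Tensor Comprehensions API.
struct ThreadMappingPlan {
  std::size_t nMappedThreads; // members mapped to threadIdx.{x,y,z}
  bool splitAboveMapped;      // split off outer members that cannot be mapped
  bool splitBelowMapped;      // split off trailing unmapped members
};

ThreadMappingPlan planThreadMapping(
    std::size_t nCanMap,       // coincident members that could be mapped
    std::size_t nMember,       // total members of the band
    std::size_t nThreadDims) { // thread dimensions in the launch config (<= 3)
  ThreadMappingPlan plan{nCanMap, false, false};
  if (plan.nMappedThreads > nThreadDims) {
    // Keep only the innermost nThreadDims mappable members; the outer
    // nCanMap - nThreadDims members stay in a separate parent band.
    plan.nMappedThreads = nThreadDims;
    plan.splitAboveMapped = true;
    nMember -= nCanMap - plan.nMappedThreads; // members left in the child band
  }
  // If members remain below the mapped ones, split them off so the
  // thread-specific marker sits directly below the mapped band.
  plan.splitBelowMapped = plan.nMappedThreads < nMember;
  return plan;
}

int main() {
  auto plan = planThreadMapping(/*nCanMap=*/4, /*nMember=*/5, /*nThreadDims=*/3);
  std::cout << plan.nMappedThreads << " " << plan.splitAboveMapped << " "
            << plan.splitBelowMapped << "\n"; // prints "3 1 1"
}
```

Note also that `mapToThreads` now calls `mapRemaining` itself and returns `numThreads.view.size()`, which is why the caller in the next hunk simply propagates `nMapped` instead of calling `mapRemaining` again.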
@@ -450,9 +463,8 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
     // because we cannot map parent bands anyway.
     auto nMapped = mapToThreads(st);
     if (nMapped > 0) {
-      mapRemaining<mapping::ThreadId>(st, nMapped);
       markUnroll(scop_->scheduleRoot(), st, unroll);
-      return numThreads.view.size();
+      return nMapped;
     }
   } else if (anyNonCoincidentMember(band)) {
     // If children were mapped to threads, and this band has a non-coincident
@@ -633,7 +645,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
   auto child = outerBand->child({0});
   size_t numMappedInnerThreads =
       mappedScop->mapInnermostBandsToThreads(child);
-  mappedScop->mapRemaining<mapping::ThreadId>(child, numMappedInnerThreads);
+  fixThreadsBelow(*mappedScop, outerBand, numMappedInnerThreads);
   LOG_IF(INFO, FLAGS_debug_tc_mapper)
       << "After mapping to threads:" << std::endl
       << *mappedScop->schedule();
@@ -677,7 +689,6 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
 
     promoteGreedilyAtDepth(
         *mappedScop,
-        mappedScop->threadIdxXScheduleDepthState,
         std::min(band->nOuterCoincident(), mappedScop->numBlocks.view.size()),
         sharedMemorySize,
         cudaOptions.proto().unroll_copy_shared() &&
@@ -694,8 +705,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
 
   // 8. Promote to registers below the loops mapped to threads.
   if (cudaOptions.proto().use_private_memory()) {
-    promoteToRegistersBelowThreads(
-        mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
+    promoteToRegistersBelowThreads(mappedScop->scop(), -1ull);
   }
 
   // 9. Insert mapping context