@@ -342,11 +342,25 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
342
342
return 0 ;
343
343
}
344
344
345
- auto nMappedThreads =
346
- std::min (numThreads.view .size (), static_cast <size_t >(nCanMap));
345
+ auto nMappedThreads = nCanMap;
346
+ if (nMappedThreads > numThreads.view .size ()) {
347
+ // Split band such that mapping filters get inserted
348
+ // right above the first member mapped to a thread identifier.
349
+ nMappedThreads = numThreads.view .size ();
350
+ bandSplit (scop_->scheduleRoot (), band, nCanMap - nMappedThreads);
351
+ auto child = band->child ({0 });
352
+ if (isReduction) {
353
+ // Update reductionBandUpdates_ such that splitOutReductionAndInsertSyncs
354
+ // can find the information it needs.
355
+ reductionBandUpdates_.emplace (child, reductionBandUpdates_.at (band));
356
+ reductionBandUpdates_.erase (band);
357
+ }
358
+ band = child;
359
+ bandNode = band->elemAs <ScheduleTreeElemBand>();
360
+ }
347
361
348
- if (nCanMap < bandNode->nMember ()) {
349
- bandSplit (scop_->scheduleRoot (), band, nCanMap );
362
+ if (nMappedThreads < bandNode->nMember ()) {
363
+ bandSplit (scop_->scheduleRoot (), band, nMappedThreads );
350
364
}
351
365
352
366
auto ctx = band->ctx_ ;
@@ -359,7 +373,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
359
373
// from thread x.
360
374
for (size_t i = 0 ; i < nMappedThreads; ++i) {
361
375
auto id = mapping::ThreadId::makeId (i);
362
- auto dim = nCanMap - 1 - i;
376
+ auto dim = nMappedThreads - 1 - i;
363
377
if (id == mapping::ThreadId::x ()) {
364
378
threadIdxXScheduleDepthState.emplace_back (std::make_pair (
365
379
activeDomainPoints (schedule (), band),
@@ -369,7 +383,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
369
383
}
370
384
371
385
if (isReduction) {
372
- splitOutReductionAndInsertSyncs (band, nCanMap - 1 );
386
+ splitOutReductionAndInsertSyncs (band, nMappedThreads - 1 );
373
387
}
374
388
375
389
return nMappedThreads;
0 commit comments