@@ -91,13 +91,26 @@ isl::union_set makeFixRemainingZeroFilter(
bool anyNonCoincidentMember(const detail::ScheduleTreeElemBand* band) {
  return band->nOuterCoincident() < band->nMember();
}
+
+/*
+ * Return a reference to the mapping sizes
+ * for the mapping of type "MappingTypeId".
+ */
+template <typename MappingTypeId>
+const CudaDim& mappingSize(const MappedScop* mscop);
+template <>
+const CudaDim& mappingSize<mapping::BlockId>(const MappedScop* mscop) {
+  return mscop->numBlocks;
+}
+template <>
+const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
+  return mscop->numThreads;
+}
} // namespace

template <typename MappingTypeId>
-void MappedScop::mapRemaining(
-    detail::ScheduleTree* tree,
-    size_t nMapped,
-    size_t nToMap) {
+void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
+  size_t nToMap = mappingSize<MappingTypeId>(this).view.size();
  if (nMapped >= nToMap) {
    return;
  }
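The hunk above replaces the explicit `nToMap` argument with a tag-dispatched lookup: `mappingSize` is declared but only defined for the two mapping id types, so `mapRemaining` can recover the total size from its template parameter alone. A minimal standalone sketch of the same dispatch pattern, using hypothetical stand-ins (`Dim3`, `Scop`, the tag structs) rather than tc's real `CudaDim`/`MappedScop`:

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical stand-ins for CudaDim and MappedScop, for illustration only.
struct Dim3 {
  size_t s[3];
  size_t size() const { return 3; }
};
struct Scop {
  Dim3 numBlocks{{256, 128, 1}};
  Dim3 numThreads{{32, 8, 1}};
};
struct BlockTag {};
struct ThreadTag {};

// The primary template is declared but never defined: a call with an
// unsupported tag fails at link time instead of picking a default.
template <typename Tag>
const Dim3& mappingSize(const Scop* scop);
template <>
const Dim3& mappingSize<BlockTag>(const Scop* scop) {
  return scop->numBlocks;
}
template <>
const Dim3& mappingSize<ThreadTag>(const Scop* scop) {
  return scop->numThreads;
}

int main() {
  Scop s;
  // mapRemaining-style use: the total comes from the tag type alone.
  size_t nToMap = mappingSize<ThreadTag>(&s).size();
  std::cout << "thread dims to map: " << nToMap << "\n"; // 3
}
```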
@@ -140,52 +153,27 @@ void MappedScop::mapToBlocksAndScaleBand(
  for (size_t i = 0; i < nBlocksToMap; ++i) {
    band = map(band, i, mapping::BlockId::makeId(i));
  }
-  mapRemaining<mapping::BlockId>(band, nBlocksToMap, numBlocks.view.size());
+  mapRemaining<mapping::BlockId>(band, nBlocksToMap);
  bandScale(band, tileSizes);
}

/*
- * Given a filter node in the schedule tree of a mapped scop,
- * insert another filter underneath (if needed) that fixes
- * the thread identifiers in the range [begin, end) to zero.
+ * Given a node in the schedule tree of a mapped scop,
+ * insert a mapping filter underneath (if needed) that fixes
+ * the remaining thread identifiers starting at "begin" to zero.
 */
-void fixThreadsBelowFilter(
+void fixThreadsBelow(
    MappedScop& mscop,
-    detail::ScheduleTree* filterTree,
-    size_t begin,
-    size_t end) {
+    detail::ScheduleTree* tree,
+    size_t begin) {
+  size_t end = mscop.numThreads.view.size();
  if (begin == end) {
    return;
  }

-  std::unordered_set<mapping::ThreadId, mapping::ThreadId::Hash> ids;
-  for (size_t i = begin; i < end; ++i) {
-    ids.insert(mapping::ThreadId::makeId(i));
-  }
-  auto root = mscop.schedule();
-  auto domain = activeDomainPoints(root, filterTree);
-  auto mappingFilter = makeFixRemainingZeroFilter(domain, ids);
-  auto filter = filterTree->elemAs<detail::ScheduleTreeElemFilter>();
-  CHECK(filter) << "Not a filter: " << *filter;
-  // Active domain points will contain spaces for different statements
-  // When inserting below a leaf filter, this would break the tightening
-  // invariant that leaf mapping filters have a single space.
-  // So we intersect with the universe set of the filter to only keep the
-  // space for the legitimate statement.
-  mappingFilter = mappingFilter & filter->filter_.universe();
-  auto mapping = detail::ScheduleTree::makeMappingFilter(mappingFilter, ids);
-  insertNodeBelow(filterTree, std::move(mapping));
-
-  for (size_t i = begin; i < end; ++i) {
-    if (mapping::ThreadId::makeId(i) == mapping::ThreadId::x()) {
-      // Mapping happened below filterTree, so we need points active for its
-      // children. After insertion, filterTree is guaranteed to have at least
-      // one child.
-      mscop.threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-          activeDomainPoints(mscop.schedule(), filterTree->child({0})),
-          filterTree->scheduleDepth(mscop.schedule())));
-    }
-  }
+  auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
+  auto bandTree = insertNodeBelow(tree, std::move(band));
+  mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
}

bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
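The old `fixThreadsBelowFilter` assembled the zero-fixing mapping filter by hand and separately recorded `threadIdxXScheduleDepthState`; the new `fixThreadsBelow` inserts an empty band and delegates to `mapRemaining`, so both effects come from the one existing code path. A toy model of the resulting behaviour (names and types are illustrative, not tc's real schedule-tree API):

```cpp
#include <cstddef>
#include <iostream>

// An empty band has no members of its own, so mapRemaining fixes every
// thread id in [nMapped, nToMap) to zero, like the old explicit filter.
void mapRemainingToZero(size_t nMapped, size_t nToMap) {
  static const char* names[] = {"threadIdx.x", "threadIdx.y", "threadIdx.z"};
  for (size_t i = nMapped; i < nToMap; ++i) {
    // In the real code this becomes a mapping filter on the new band.
    std::cout << "fix " << names[i] << " = 0\n";
  }
}

int main() {
  // A child already mapped to one thread id, with a 3-dimensional block:
  mapRemainingToZero(/*begin=*/1, /*end=*/3); // fixes y and z to zero
}
```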
@@ -239,7 +227,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
  if (!inits.is_empty()) {
    orderBefore(scop_->scheduleRoot(), tree, inits);
  }
-  reductionBandUpdates_.emplace(tree, Reduction(updateIds, reductionDim));
+  reductionBandUpdates_.emplace(tree, Reduction(updateIds));
  return true;
}
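With the reduction member now known to sit immediately after the coincident members, its position can be recomputed as `nOuterCoincident()` on demand, so `Reduction` no longer needs to store it. Roughly, as a hypothetical before/after sketch with field names assumed from this diff rather than the full header:

```cpp
#include <cstddef>
#include <vector>

struct ReductionBefore {
  std::vector<int> ids; // update statement ids for this band
  size_t reductionDim;  // stored at detection time
};
struct ReductionAfter {
  std::vector<int> ids; // position now derived via nOuterCoincident()
};

int main() {
  ReductionBefore before{{0, 1}, 2};
  ReductionAfter after{{0, 1}};
  return before.ids.size() == after.ids.size() ? 0 : 1;
}
```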
@@ -261,11 +249,9 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
  // mapped to threads.
  auto reductionSchedule = reductionBand->mupa_;
  auto nMember = reductionBand->nMember();
-  auto reductionDim = reductionBandUpdates_.at(st).reductionDim;
-  auto nMappedThreads =
-      std::min(numThreads.view.size(), reductionBand->nOuterCoincident() + 1);
+  auto reductionDim = reductionBand->nOuterCoincident();
+  auto nMappedThreads = std::min(numThreads.view.size(), reductionDim + 1);
  CHECK_GE(nMember, reductionDim);
-  CHECK_GE(reductionDim + 1, nMappedThreads);
  reductionSchedule = reductionSchedule.drop_dims(
      isl::dim_type::set, reductionDim + 1, nMember - (reductionDim + 1));
  reductionSchedule = reductionSchedule.drop_dims(
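Deriving `reductionDim` from the band itself makes the member arithmetic easy to check by hand. A worked example with illustrative numbers (a 4-member reduction band, 2 outer coincident members, a 3-dimensional thread block):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  size_t nMember = 4, nOuterCoincident = 2, numThreadDims = 3;
  size_t reductionDim = nOuterCoincident; // member 2 is the reduction member
  size_t nMappedThreads = std::min(numThreadDims, reductionDim + 1); // 3
  // The first drop_dims removes the members after the reduction member:
  std::cout << "drop " << nMember - (reductionDim + 1) << " member(s) at "
            << reductionDim + 1 << "\n"; // drop 1 member(s) at 3
  std::cout << "map " << nMappedThreads << " member(s) to threads\n";
}
```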
@@ -332,45 +318,37 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
    return 0;
  }

-  size_t nMappedReductionThreads = 0;
-  if (reductionBandUpdates_.count(band) == 1) {
-    // A reduction is assumed to get mapped to threadIdx.x
-    CHECK(reductionBandUpdates_.at(band).separated);
-    auto reductionDim = reductionBandUpdates_.at(band).reductionDim;
-    threadIdxXScheduleDepthState.emplace_back(std::make_pair(
-        activeDomainPoints(schedule(), band),
-        band->scheduleDepth(schedule()) + reductionDim));
-    band = map(band, reductionDim, mapping::ThreadId::x());
-    nMappedReductionThreads = 1;
-  }
-
  // With current isl scheduler, if coincident dimensions exist in a band,
  // they are outermost.
  // If a band has more than 3 coincident dimensions,
  // then the innermost of those will be used.
-  auto nOuterCoincident = bandNode->nOuterCoincident();
-  if (nOuterCoincident < 1) {
-    return nMappedReductionThreads;
+  auto nCanMap = bandNode->nOuterCoincident();
+
+  auto isReduction = reductionBandUpdates_.count(band) == 1;
+  // If the band has a detected reduction, then the first member
+  // after the coincident members is the reduction member and
+  // this member has to be mapped as well.
+  // In particular, it will get mapped to threadIdx.x
+  if (isReduction) {
+    CHECK(reductionBandUpdates_.at(band).separated);
+    nCanMap++;
  }

-  auto nMappedThreads = std::min(
-      numThreads.view.size() - nMappedReductionThreads,
-      static_cast<size_t>(nOuterCoincident));
-
-  // Immediately return if mapping to one thread dimension only was requested
-  // and a reduction was already mapped. (Note that reduction is detected only
-  // if there are not enough outer coincident members, 0 in this case).
-  if (nMappedThreads == 0) {
-    return nMappedReductionThreads;
+  if (nCanMap < 1) {
+    return 0;
  }
-  CHECK_LE(nMappedThreads, 3 - nMappedReductionThreads)
-      << "mapping to too many threads";
+
+  auto nMappedThreads =
+      std::min(numThreads.view.size(), static_cast<size_t>(nCanMap));
+
+  CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
+  CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";

  // Map the coincident dimensions to threads starting from the innermost and
-  // from thread x unless it was already mapped to a reduction.
+  // from thread x.
  for (size_t i = 0; i < nMappedThreads; ++i) {
-    auto id = mapping::ThreadId::makeId(nMappedReductionThreads + i);
-    auto dim = nOuterCoincident - 1 - i;
+    auto id = mapping::ThreadId::makeId(i);
+    auto dim = nCanMap - 1 - i;
    if (id == mapping::ThreadId::x()) {
      threadIdxXScheduleDepthState.emplace_back(std::make_pair(
          activeDomainPoints(schedule(), band),
@@ -379,7 +357,11 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
    band = map(band, dim, id);
  }

-  return nMappedReductionThreads + nMappedThreads;
+  if (isReduction) {
+    splitOutReductionAndInsertSyncs(band, nCanMap - 1);
+  }
+
+  return nMappedThreads;
}

namespace {
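The rewritten loop maps members innermost-first, with the detected reduction member (at position `nCanMap - 1`) landing on threadIdx.x before the band is split and synchronized. A small standalone trace of the index arithmetic, with illustrative sizes only:

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
  // Illustrative: 2 coincident members plus a detected reduction member,
  // mapped against a 3-dimensional thread block.
  size_t nOuterCoincident = 2, numThreadDims = 3;
  bool isReduction = true;
  size_t nCanMap = nOuterCoincident + (isReduction ? 1 : 0); // 3
  size_t nMappedThreads = std::min(numThreadDims, nCanMap);  // 3
  const char* ids[] = {"threadIdx.x", "threadIdx.y", "threadIdx.z"};
  for (size_t i = 0; i < nMappedThreads; ++i) {
    size_t dim = nCanMap - 1 - i; // innermost mapped member first
    std::cout << "member " << dim << " -> " << ids[i] << "\n";
  }
  // Prints: member 2 -> threadIdx.x (the reduction member), then
  // member 1 -> threadIdx.y, member 0 -> threadIdx.z.
}
```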
@@ -419,21 +401,16 @@ bool hasOuterSequentialMember(
// If any separation is needed for mapping reductions to full blocks,
// then do so first.
//
-// If "st" has multiple children, then make sure they are mapped
-// to the same number of thread identifiers by fixing those
-// that are originally mapped to fewer identifiers to value zero
-// for the remaining thread identifiers.
+// If "st" has multiple children and if any of those children
+// is mapped to threads, then make sure the other children
+// are also mapped to threads, by fixing the thread identifiers to value zero.
// If, moreover, "st" is a sequence node and at least one of its
// children is mapped to threads, then introduce synchronization
// before and after children that are mapped to threads.
// Also add synchronization between the last child and
// the next iteration of the first child if there may be such
// a next iteration that is not already covered by synchronization
// on an outer node.
-// If any synchronization is introduced, then the mapping
-// to threads needs to be completed to all thread ids
-// because the synchronization needs to be introduced outside
-// any mapping to threads.
size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
  if (needReductionSeparation(st)) {
    st = separateReduction(st);
447
424
auto n = nChildren > 0 ? *std::max_element (nInner.begin (), nInner.end ()) : 0 ;
448
425
if (nChildren > 1 ) {
449
426
auto needSync = st->elemAs <detail::ScheduleTreeElemSequence>() && n > 0 ;
450
- if (needSync) {
451
- n = numThreads.view .size ();
452
- }
453
- for (size_t i = 0 ; i < nChildren; ++i) {
454
- fixThreadsBelowFilter (*this , children[i], nInner[i], n);
427
+ if (n > 0 ) {
428
+ for (size_t i = 0 ; i < nChildren; ++i) {
429
+ fixThreadsBelow (*this , children[i], nInner[i]);
430
+ }
455
431
}
456
432
if (needSync) {
457
433
auto outer = hasOuterSequentialMember (scop_->scheduleRoot (), st);
@@ -474,7 +450,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
474
450
// because we cannot map parent bands anyway.
475
451
auto nMapped = mapToThreads (st);
476
452
if (nMapped > 0 ) {
477
- mapRemaining<mapping::ThreadId>(st, nMapped, numThreads. view . size () );
453
+ mapRemaining<mapping::ThreadId>(st, nMapped);
478
454
markUnroll (scop_->scheduleRoot (), st, unroll);
479
455
return numThreads.view .size ();
480
456
}
@@ -594,19 +570,16 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
594
570
mappedScopForCodegen->numThreads );
595
571
}
596
572
597
- // Split out reduction loops into separate bands and insert reduction
598
- // synchronizations outside those bands.
599
- void MappedScop::splitOutReductionsAndInsertSyncs () {
573
+ // Split out reduction member at position "dim" in "band" and
574
+ // insert reduction synchronizations outside this split off band.
575
+ void MappedScop::splitOutReductionAndInsertSyncs (
576
+ detail::ScheduleTree* band,
577
+ int dim) {
600
578
using namespace polyhedral ::detail;
601
579
602
- for (auto bandUpdate : reductionBandUpdates_) {
603
- auto tree = bandSplitOut (
604
- scop_->scheduleRoot (),
605
- const_cast <ScheduleTree*>(bandUpdate.first ),
606
- bandUpdate.second .reductionDim );
607
- for (auto updateId : bandUpdate.second .ids ) {
608
- scop_->insertReductionSync1D (tree, updateId);
609
- }
580
+ auto tree = bandSplitOut (scop_->scheduleRoot (), band, dim);
581
+ for (auto updateId : reductionBandUpdates_.at (band).ids ) {
582
+ scop_->insertReductionSync1D (tree, updateId);
610
583
}
611
584
}
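Making the split per-band removes the late loop over `reductionBandUpdates_`: each reduction band is handled from `mapToThreads` as soon as it is mapped, with `nCanMap - 1` as the member to split out, which is why pipeline step 7 disappears below. A compressed sketch of the new call shape, with simplified signatures standing in for the real schedule-tree API:

```cpp
#include <iostream>
#include <vector>

struct ScheduleTree {}; // stand-in for detail::ScheduleTree

ScheduleTree* bandSplitOut(ScheduleTree* band, int dim) {
  std::cout << "split member " << dim << " into its own band\n";
  return band;
}
void insertReductionSync1D(ScheduleTree*, int updateId) {
  std::cout << "sync around update " << updateId << "\n";
}

// New per-band form: the caller already knows the reduction member.
void splitOutReductionAndInsertSyncs(ScheduleTree* band, int dim) {
  auto tree = bandSplitOut(band, dim);
  for (int updateId : std::vector<int>{0, 1}) { // ids recorded at detection
    insertReductionSync1D(tree, updateId);
  }
}

int main() {
  ScheduleTree band;
  splitOutReductionAndInsertSyncs(&band, /*dim=*/2); // as from mapToThreads
}
```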
@@ -660,8 +633,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
  auto child = outerBand->child({0});
  size_t numMappedInnerThreads =
      mappedScop->mapInnermostBandsToThreads(child);
-  mappedScop->mapRemaining<mapping::ThreadId>(
-      child, numMappedInnerThreads, mappedScop->numThreads.view.size());
+  mappedScop->mapRemaining<mapping::ThreadId>(child, numMappedInnerThreads);
  LOG_IF(INFO, FLAGS_debug_tc_mapper)
      << "After mapping to threads:" << std::endl
      << *mappedScop->schedule();
@@ -673,13 +645,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
  LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl
                                      << *mappedScop->schedule();

-  // 7. Insert reduction synchronizations if necessary.
-  mappedScop->splitOutReductionsAndInsertSyncs();
-  LOG_IF(INFO, FLAGS_debug_tc_mapper)
-      << "After inserting reduction synchronization:" << std::endl
-      << *mappedScop->schedule();
-
-  // 8. Promote to shared memory below the loops mapped to blocks.
+  // 7. Promote to shared memory below the loops mapped to blocks.
  // This may split the outer band, so find the new outer band after promotion.
  if (cudaOptions.proto().use_shared_memory()) {
    size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory()
@@ -726,13 +692,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
    }
  }

-  // 9. Promote to registers below the loops mapped to threads.
+  // 8. Promote to registers below the loops mapped to threads.
  if (cudaOptions.proto().use_private_memory()) {
    promoteToRegistersBelowThreads(
        mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
  }

-  // 10. Insert mapping context
+  // 9. Insert mapping context
  mappedScop->insertMappingContext();
  LOG_IF(INFO, FLAGS_debug_tc_mapper)
      << "After outerBlockInnerThread strategy:" << std::endl