Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit fadace1

Browse files
Merge pull request #343 from facebookresearch/pr/clean-up
clean-ups in the mapper
2 parents c5c1066 + fe8da9b commit fadace1

File tree

7 files changed

+112
-158
lines changed

7 files changed

+112
-158
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 77 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,26 @@ isl::union_set makeFixRemainingZeroFilter(
9191
bool anyNonCoincidentMember(const detail::ScheduleTreeElemBand* band) {
9292
return band->nOuterCoincident() < band->nMember();
9393
}
94+
95+
/*
96+
* Return a reference to the mapping sizes
97+
* for the mapping of type "MappingTypeId".
98+
*/
99+
template <typename MappingTypeId>
100+
const CudaDim& mappingSize(const MappedScop* mscop);
101+
template <>
102+
const CudaDim& mappingSize<mapping::BlockId>(const MappedScop* mscop) {
103+
return mscop->numBlocks;
104+
}
105+
template <>
106+
const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
107+
return mscop->numThreads;
108+
}
94109
} // namespace
95110

96111
template <typename MappingTypeId>
97-
void MappedScop::mapRemaining(
98-
detail::ScheduleTree* tree,
99-
size_t nMapped,
100-
size_t nToMap) {
112+
void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
113+
size_t nToMap = mappingSize<MappingTypeId>(this).view.size();
101114
if (nMapped >= nToMap) {
102115
return;
103116
}
@@ -140,52 +153,27 @@ void MappedScop::mapToBlocksAndScaleBand(
140153
for (size_t i = 0; i < nBlocksToMap; ++i) {
141154
band = map(band, i, mapping::BlockId::makeId(i));
142155
}
143-
mapRemaining<mapping::BlockId>(band, nBlocksToMap, numBlocks.view.size());
156+
mapRemaining<mapping::BlockId>(band, nBlocksToMap);
144157
bandScale(band, tileSizes);
145158
}
146159

147160
/*
148-
* Given a filter node in the schedule tree of a mapped scop,
149-
* insert another filter underneath (if needed) that fixes
150-
* the thread identifiers in the range [begin, end) to zero.
161+
* Given a node in the schedule tree of a mapped scop,
162+
* insert a mapping filter underneath (if needed) that fixes
163+
* the remaining thread identifiers starting at "begin" to zero.
151164
*/
152-
void fixThreadsBelowFilter(
165+
void fixThreadsBelow(
153166
MappedScop& mscop,
154-
detail::ScheduleTree* filterTree,
155-
size_t begin,
156-
size_t end) {
167+
detail::ScheduleTree* tree,
168+
size_t begin) {
169+
size_t end = mscop.numThreads.view.size();
157170
if (begin == end) {
158171
return;
159172
}
160173

161-
std::unordered_set<mapping::ThreadId, mapping::ThreadId::Hash> ids;
162-
for (size_t i = begin; i < end; ++i) {
163-
ids.insert(mapping::ThreadId::makeId(i));
164-
}
165-
auto root = mscop.schedule();
166-
auto domain = activeDomainPoints(root, filterTree);
167-
auto mappingFilter = makeFixRemainingZeroFilter(domain, ids);
168-
auto filter = filterTree->elemAs<detail::ScheduleTreeElemFilter>();
169-
CHECK(filter) << "Not a filter: " << *filter;
170-
// Active domain points will contain spaces for different statements
171-
// When inserting below a leaf filter, this would break the tightening
172-
// invariant that leaf mapping filters have a single space.
173-
// So we intersect with the universe set of the filter to only keep the
174-
// space for the legitimate statement.
175-
mappingFilter = mappingFilter & filter->filter_.universe();
176-
auto mapping = detail::ScheduleTree::makeMappingFilter(mappingFilter, ids);
177-
insertNodeBelow(filterTree, std::move(mapping));
178-
179-
for (size_t i = begin; i < end; ++i) {
180-
if (mapping::ThreadId::makeId(i) == mapping::ThreadId::x()) {
181-
// Mapping happened below filterTree, so we need points active for its
182-
// children. After insertion, filterTree is guaranteed to have at least
183-
// one child.
184-
mscop.threadIdxXScheduleDepthState.emplace_back(std::make_pair(
185-
activeDomainPoints(mscop.schedule(), filterTree->child({0})),
186-
filterTree->scheduleDepth(mscop.schedule())));
187-
}
188-
}
174+
auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
175+
auto bandTree = insertNodeBelow(tree, std::move(band));
176+
mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
189177
}
190178

191179
bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
@@ -239,7 +227,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
239227
if (!inits.is_empty()) {
240228
orderBefore(scop_->scheduleRoot(), tree, inits);
241229
}
242-
reductionBandUpdates_.emplace(tree, Reduction(updateIds, reductionDim));
230+
reductionBandUpdates_.emplace(tree, Reduction(updateIds));
243231
return true;
244232
}
245233

@@ -261,11 +249,9 @@ isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
261249
// mapped to threads.
262250
auto reductionSchedule = reductionBand->mupa_;
263251
auto nMember = reductionBand->nMember();
264-
auto reductionDim = reductionBandUpdates_.at(st).reductionDim;
265-
auto nMappedThreads =
266-
std::min(numThreads.view.size(), reductionBand->nOuterCoincident() + 1);
252+
auto reductionDim = reductionBand->nOuterCoincident();
253+
auto nMappedThreads = std::min(numThreads.view.size(), reductionDim + 1);
267254
CHECK_GE(nMember, reductionDim);
268-
CHECK_GE(reductionDim + 1, nMappedThreads);
269255
reductionSchedule = reductionSchedule.drop_dims(
270256
isl::dim_type::set, reductionDim + 1, nMember - (reductionDim + 1));
271257
reductionSchedule = reductionSchedule.drop_dims(
@@ -332,45 +318,37 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
332318
return 0;
333319
}
334320

335-
size_t nMappedReductionThreads = 0;
336-
if (reductionBandUpdates_.count(band) == 1) {
337-
// A reduction is assumed to get mapped to threadIdx.x
338-
CHECK(reductionBandUpdates_.at(band).separated);
339-
auto reductionDim = reductionBandUpdates_.at(band).reductionDim;
340-
threadIdxXScheduleDepthState.emplace_back(std::make_pair(
341-
activeDomainPoints(schedule(), band),
342-
band->scheduleDepth(schedule()) + reductionDim));
343-
band = map(band, reductionDim, mapping::ThreadId::x());
344-
nMappedReductionThreads = 1;
345-
}
346-
347321
// With current isl scheduler, if coincident dimensions exist in a band,
348322
// they are outermost.
349323
// If a band has more than 3 coincident dimensions,
350324
// then the innermost of those will be used.
351-
auto nOuterCoincident = bandNode->nOuterCoincident();
352-
if (nOuterCoincident < 1) {
353-
return nMappedReductionThreads;
325+
auto nCanMap = bandNode->nOuterCoincident();
326+
327+
auto isReduction = reductionBandUpdates_.count(band) == 1;
328+
// If the band has a detected reduction, then the first member
329+
// after the coincident members is the reduction member and
330+
// this member has to be mapped as well.
331+
// In particular, it will get mapped to threadIdx.x
332+
if (isReduction) {
333+
CHECK(reductionBandUpdates_.at(band).separated);
334+
nCanMap++;
354335
}
355336

356-
auto nMappedThreads = std::min(
357-
numThreads.view.size() - nMappedReductionThreads,
358-
static_cast<size_t>(nOuterCoincident));
359-
360-
// Immediately return if mapping to one thread dimension only was requested
361-
// and a reduction was already mapped. (Note that reduction is detected only
362-
// if there are not enough outer coincident members, 0 in this case).
363-
if (nMappedThreads == 0) {
364-
return nMappedReductionThreads;
337+
if (nCanMap < 1) {
338+
return 0;
365339
}
366-
CHECK_LE(nMappedThreads, 3 - nMappedReductionThreads)
367-
<< "mapping to too many threads";
340+
341+
auto nMappedThreads =
342+
std::min(numThreads.view.size(), static_cast<size_t>(nCanMap));
343+
344+
CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
345+
CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";
368346

369347
// Map the coincident dimensions to threads starting from the innermost and
370-
// from thread x unless it was already mapped to a reduction.
348+
// from thread x.
371349
for (size_t i = 0; i < nMappedThreads; ++i) {
372-
auto id = mapping::ThreadId::makeId(nMappedReductionThreads + i);
373-
auto dim = nOuterCoincident - 1 - i;
350+
auto id = mapping::ThreadId::makeId(i);
351+
auto dim = nCanMap - 1 - i;
374352
if (id == mapping::ThreadId::x()) {
375353
threadIdxXScheduleDepthState.emplace_back(std::make_pair(
376354
activeDomainPoints(schedule(), band),
@@ -379,7 +357,11 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
379357
band = map(band, dim, id);
380358
}
381359

382-
return nMappedReductionThreads + nMappedThreads;
360+
if (isReduction) {
361+
splitOutReductionAndInsertSyncs(band, nCanMap - 1);
362+
}
363+
364+
return nMappedThreads;
383365
}
384366

385367
namespace {
@@ -419,21 +401,16 @@ bool hasOuterSequentialMember(
419401
// If any separation is needed for mapping reductions to full blocks,
420402
// then do so first.
421403
//
422-
// If "st" has multiple children, then make sure they are mapped
423-
// to the same number of thread identifiers by fixing those
424-
// that are originally mapped to fewer identifiers to value zero
425-
// for the remaining thread identifiers.
404+
// If "st" has multiple children and if any of those children
405+
// is mapped to threads, then make sure the other children
406+
// are also mapped to threads, by fixing the thread identifiers to value zero.
426407
// If, moreover, "st" is a sequence node and at least one of its
427408
// children is mapped to threads, then introduce synchronization
428409
// before and after children that are mapped to threads.
429410
// Also add synchronization between the last child and
430411
// the next iteration of the first child if there may be such
431412
// a next iteration that is not already covered by synchronization
432413
// on an outer node.
433-
// If any synchronization is introduced, then the mapping
434-
// to threads needs to be completed to all thread ids
435-
// because the synchronization needs to be introduced outside
436-
// any mapping to threads.
437414
size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
438415
if (needReductionSeparation(st)) {
439416
st = separateReduction(st);
@@ -447,11 +424,10 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
447424
auto n = nChildren > 0 ? *std::max_element(nInner.begin(), nInner.end()) : 0;
448425
if (nChildren > 1) {
449426
auto needSync = st->elemAs<detail::ScheduleTreeElemSequence>() && n > 0;
450-
if (needSync) {
451-
n = numThreads.view.size();
452-
}
453-
for (size_t i = 0; i < nChildren; ++i) {
454-
fixThreadsBelowFilter(*this, children[i], nInner[i], n);
427+
if (n > 0) {
428+
for (size_t i = 0; i < nChildren; ++i) {
429+
fixThreadsBelow(*this, children[i], nInner[i]);
430+
}
455431
}
456432
if (needSync) {
457433
auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), st);
@@ -474,7 +450,7 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
474450
// because we cannot map parent bands anyway.
475451
auto nMapped = mapToThreads(st);
476452
if (nMapped > 0) {
477-
mapRemaining<mapping::ThreadId>(st, nMapped, numThreads.view.size());
453+
mapRemaining<mapping::ThreadId>(st, nMapped);
478454
markUnroll(scop_->scheduleRoot(), st, unroll);
479455
return numThreads.view.size();
480456
}
@@ -594,19 +570,16 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
594570
mappedScopForCodegen->numThreads);
595571
}
596572

597-
// Split out reduction loops into separate bands and insert reduction
598-
// synchronizations outside those bands.
599-
void MappedScop::splitOutReductionsAndInsertSyncs() {
573+
// Split out reduction member at position "dim" in "band" and
574+
// insert reduction synchronizations outside this split off band.
575+
void MappedScop::splitOutReductionAndInsertSyncs(
576+
detail::ScheduleTree* band,
577+
int dim) {
600578
using namespace polyhedral::detail;
601579

602-
for (auto bandUpdate : reductionBandUpdates_) {
603-
auto tree = bandSplitOut(
604-
scop_->scheduleRoot(),
605-
const_cast<ScheduleTree*>(bandUpdate.first),
606-
bandUpdate.second.reductionDim);
607-
for (auto updateId : bandUpdate.second.ids) {
608-
scop_->insertReductionSync1D(tree, updateId);
609-
}
580+
auto tree = bandSplitOut(scop_->scheduleRoot(), band, dim);
581+
for (auto updateId : reductionBandUpdates_.at(band).ids) {
582+
scop_->insertReductionSync1D(tree, updateId);
610583
}
611584
}
612585

@@ -660,8 +633,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
660633
auto child = outerBand->child({0});
661634
size_t numMappedInnerThreads =
662635
mappedScop->mapInnermostBandsToThreads(child);
663-
mappedScop->mapRemaining<mapping::ThreadId>(
664-
child, numMappedInnerThreads, mappedScop->numThreads.view.size());
636+
mappedScop->mapRemaining<mapping::ThreadId>(child, numMappedInnerThreads);
665637
LOG_IF(INFO, FLAGS_debug_tc_mapper)
666638
<< "After mapping to threads:" << std::endl
667639
<< *mappedScop->schedule();
@@ -673,13 +645,7 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
673645
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl
674646
<< *mappedScop->schedule();
675647

676-
// 7. Insert reduction synchronizations if necessary.
677-
mappedScop->splitOutReductionsAndInsertSyncs();
678-
LOG_IF(INFO, FLAGS_debug_tc_mapper)
679-
<< "After inserting reduction synchronization:" << std::endl
680-
<< *mappedScop->schedule();
681-
682-
// 8. Promote to shared memory below the loops mapped to blocks.
648+
// 7. Promote to shared memory below the loops mapped to blocks.
683649
// This may split the outer band, so find the new outer band after promotion.
684650
if (cudaOptions.proto().use_shared_memory()) {
685651
size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory()
@@ -726,13 +692,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
726692
}
727693
}
728694

729-
// 9. Promote to registers below the loops mapped to threads.
695+
// 8. Promote to registers below the loops mapped to threads.
730696
if (cudaOptions.proto().use_private_memory()) {
731697
promoteToRegistersBelowThreads(
732698
mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
733699
}
734700

735-
// 10. Insert mapping context
701+
// 9. Insert mapping context
736702
mappedScop->insertMappingContext();
737703
LOG_IF(INFO, FLAGS_debug_tc_mapper)
738704
<< "After outerBlockInnerThread strategy:" << std::endl

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ class MappedScop {
102102
}
103103

104104
// Given that "nMapped" identifiers of type "MappingTypeId" have already
105-
// been mapped, map the remaining ones (up to "nToMap") to zero
105+
// been mapped, map the remaining ones to zero
106106
// for all statement instances.
107107
template <typename MappingTypeId>
108-
void mapRemaining(detail::ScheduleTree* tree, size_t nMapped, size_t nToMap);
108+
void mapRemaining(detail::ScheduleTree* tree, size_t nMapped);
109109

110110
// Fix the values of the specified parameters in the context
111111
// to the corresponding specified values.
@@ -155,13 +155,16 @@ class MappedScop {
155155
// The remaining parts, if any, are no longer considered for replacement
156156
// by a library call.
157157
detail::ScheduleTree* separateReduction(detail::ScheduleTree* band);
158-
// Split out reduction bands and insert reduction synchronizations.
159-
void splitOutReductionsAndInsertSyncs();
158+
// Split out reduction member at position "dim" in "band" and
159+
// insert reduction synchronizations.
160+
void splitOutReductionAndInsertSyncs(detail::ScheduleTree* band, int dim);
160161
// Map "band" to thread identifiers using as many blockSizes values as outer
161-
// coincident dimensions, unroll band members that execute at most "unroll"
162-
// instances and return the number of mapped thread identifiers.
162+
// coincident dimensions (plus reduction dimension, if any),
163+
// insert synchronization in case of a reduction, and
164+
// return the number of mapped thread identifiers.
163165
size_t mapToThreads(detail::ScheduleTree* band);
164-
// Map innermost bands to thread identifiers and
166+
// Map innermost bands to thread identifiers,
167+
// inserting synchronization in case of a reduction, and
165168
// return the number of mapped thread identifiers.
166169
size_t mapInnermostBandsToThreads(detail::ScheduleTree* st);
167170

@@ -184,14 +187,11 @@ class MappedScop {
184187
// Information about a detected reduction that can potentially
185188
// be mapped to a library call.
186189
struct Reduction {
187-
Reduction(std::vector<isl::id> ids, size_t index)
188-
: ids(ids), separated(false), reductionDim(index) {}
190+
Reduction(std::vector<isl::id> ids) : ids(ids), separated(false) {}
189191
// The statement identifiers of the reduction update statements.
190192
std::vector<isl::id> ids;
191193
// Has the reduction been separated out as a full block?
192194
bool separated;
193-
// Index of the band member in which the reduction was detected.
194-
size_t reductionDim;
195195
};
196196
// Map isolated innermost reduction band members to information
197197
// about the detected reduction.

0 commit comments

Comments
 (0)