Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 5e170c2

Browse files
authored
Merge pull request #283 from facebookresearch/dont-split-reductions
Don't split reductions to remove nested thread mapping
2 parents 8ffd5cd + 497d3c4 commit 5e170c2

File tree

2 files changed

+118

-102

lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 106 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,11 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
197197
return found;
198198
}
199199

200+
// Only reductions that appear in permutable bands are mapped to threads.
201+
if (!band->permutable_) {
202+
return false;
203+
}
204+
200205
// For now, only support reductions with a sufficient number
201206
// of coincident outer band members for the remaining thread identifiers.
202207
auto nCoincident = band->nOuterCoincident();
@@ -225,65 +230,50 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
225230
if (!isReductionMember(member, updates, scop())) {
226231
return false;
227232
}
228-
auto reductionTree = bandSplitOut(scop_->scheduleRoot(), tree, reductionDim);
229233
// Order the init statements (if any) before the update statements
230234
// to ensure the band from which the reduction band has been split off
231235
// only contains update statements.
232236
// Note that this relies on the outer members being coincident.
233237
if (!inits.is_empty()) {
234238
orderBefore(scop_->scheduleRoot(), tree, inits);
235239
}
236-
reductionFromParent_.emplace(tree, reductionTree);
237-
reductionBandUpdates_.emplace(reductionTree, updateIds);
240+
reductionBandUpdates_.emplace(tree, Reduction(updateIds, reductionDim));
238241
return true;
239242
}
240243

241244
bool MappedScop::needReductionSeparation(const detail::ScheduleTree* st) {
242-
// It is the parent band of the reduction band that needs to be separated.
243-
if (reductionFromParent_.count(st) != 1) {
245+
if (reductionBandUpdates_.count(st) != 1) {
244246
return false;
245247
}
246-
st = reductionFromParent_.at(st);
247-
CHECK(reductionBandUpdates_.count(st) == 1);
248248
// No need to separate if already separated.
249249
return !reductionBandUpdates_.at(st).separated;
250250
}
251251

252252
isl::multi_union_pw_aff MappedScop::reductionMapSchedule(
253253
const detail::ScheduleTree* st) {
254-
CHECK(reductionFromParent_.count(st) == 1);
255-
auto parent = st;
256-
st = reductionFromParent_.at(st);
257254
CHECK(reductionBandUpdates_.count(st) == 1);
258-
259255
auto reductionBand = st->elemAs<detail::ScheduleTreeElemBand>();
260256
CHECK(reductionBand);
261-
// Start with the schedule of the reduction band (in last position).
262-
auto reductionSchedule = reductionBand->mupa_;
263257

264-
// Total size of returned schedule needs to be equal
265-
// to the number of thread identifiers.
266-
if (numThreads.view.size() > 1) {
267-
CHECK(parent != st);
268-
}
269-
// Prepend last members of parent band (if any).
270-
if (parent != st) {
271-
auto parentBand = parent->elemAs<detail::ScheduleTreeElemBand>();
272-
CHECK(parentBand);
273-
auto parentSchedule = parentBand->mupa_;
274-
auto nMember = parentBand->nMember();
275-
CHECK_GE(nMember, numThreads.view.size() - 1);
276-
parentSchedule = parentSchedule.drop_dims(
277-
isl::dim_type::set, 0, nMember - (numThreads.view.size() - 1));
278-
reductionSchedule = parentSchedule.flat_range_product(reductionSchedule);
279-
}
258+
// Drop band members following the reduction dimension and preceding those
259+
// mapped to threads.
260+
auto reductionSchedule = reductionBand->mupa_;
261+
auto nMember = reductionBand->nMember();
262+
auto reductionDim = reductionBandUpdates_.at(st).reductionDim;
263+
auto nMappedThreads =
264+
std::min(numThreads.view.size(), reductionBand->nOuterCoincident() + 1);
265+
CHECK_GE(nMember, reductionDim);
266+
CHECK_GE(reductionDim + 1, nMappedThreads);
267+
reductionSchedule = reductionSchedule.drop_dims(
268+
isl::dim_type::set, reductionDim + 1, nMember - (reductionDim + 1));
269+
reductionSchedule = reductionSchedule.drop_dims(
270+
isl::dim_type::set, 0, reductionDim - nMappedThreads + 1);
280271

281272
return reductionSchedule;
282273
}
283274

284275
detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
285-
CHECK(reductionFromParent_.count(st) == 1);
286-
auto reduction = reductionFromParent_.at(st);
276+
auto reduction = st;
287277
// This function either separates full blocks (if needed) or
288278
// disables the reduction handling.
289279
reductionBandUpdates_.at(reduction).separated = true;
@@ -331,59 +321,54 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
331321
return st->ancestor(root, 2);
332322
}
333323

334-
size_t MappedScop::mapToThreads(detail::ScheduleTree* band, size_t nInner) {
324+
size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
335325
using namespace tc::polyhedral::detail;
336326

337-
if (nInner >= numThreads.view.size()) {
338-
return nInner;
327+
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
328+
// Cannot map non-permutable bands.
329+
if (!bandNode->permutable_) {
330+
return 0;
339331
}
332+
333+
int nMappedReductionThreads = 0;
340334
if (reductionBandUpdates_.count(band) == 1) {
341335
// A reduction is assumed to get mapped to threadIdx.x
342-
if (nInner != 0) {
343-
reductionBandUpdates_.erase(band);
344-
return nInner;
345-
}
346336
CHECK(reductionBandUpdates_.at(band).separated);
347337
threadIdxXScheduleDepthState.emplace_back(std::make_pair(
348338
activeDomainPoints(schedule(), band),
349339
band->scheduleDepth(schedule()) + 0));
350-
band = map(band, 0, mapping::ThreadId::x());
351-
markUnroll(scop_->scheduleRoot(), band, unroll);
352-
return 1;
353-
}
354-
auto bandNode = band->elemAs<ScheduleTreeElemBand>();
355-
// If any inner node was mapped to threads and
356-
// the current node has a non-coincident member,
357-
// then synchronization needs to be introduced.
358-
// This also implies that the mapping needs to be completed first.
359-
if (anyNonCoincidentMember(bandNode) && nInner > 0) {
360-
// Since some thread identifiers were mapped already (nInner > 0),
361-
// the band should have descendants. Double check.
362-
CHECK_EQ(band->numChildren(), 1);
363-
mapRemaining<mapping::ThreadId>(
364-
band->child({0}), nInner, numThreads.view.size());
365-
scop_->insertSyncAfter(band->child({0}));
366-
return numThreads.view.size();
340+
auto reductionDim = reductionBandUpdates_.at(band).reductionDim;
341+
band = map(band, reductionDim, mapping::ThreadId::x());
342+
nMappedReductionThreads = 1;
367343
}
344+
368345
// With current isl scheduler, if coincident dimensions exist in a band,
369346
// they are outermost.
370-
// If a band has more than 3 coincident dimensions, this will choose
371-
// outermost, but we may also want innermost.
347+
// If a band has more than 3 coincident dimensions,
348+
// then the innermost of those will be used.
372349
auto nOuterCoincident = bandNode->nOuterCoincident();
373-
if (!bandNode->permutable_ || nOuterCoincident < 1) {
374-
return nInner;
350+
if (nOuterCoincident < 1) {
351+
return nMappedReductionThreads;
375352
}
376353

377354
auto nMappedThreads = std::min(
378-
numThreads.view.size() - nInner, static_cast<size_t>(nOuterCoincident));
379-
CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
380-
CHECK_LE(nMappedThreads, 3 - nInner) << "mapping to too many threads";
355+
numThreads.view.size() - nMappedReductionThreads,
356+
static_cast<size_t>(nOuterCoincident));
357+
358+
// Immediately return if mapping to one thread dimension only was requested
359+
// and a reduction was already mapped. (Note that reduction is detected only
360+
// if there are not enough outer coincident members, 0 in this case).
361+
if (nMappedThreads == 0) {
362+
return nMappedReductionThreads;
363+
}
364+
CHECK_LE(nMappedThreads, 3 - nMappedReductionThreads)
365+
<< "mapping to too many threads";
381366

382367
// Map the coincident dimensions to threads starting from the innermost and
383-
// from thread x.
384-
for (int i = 0, dim = nOuterCoincident - 1; i < nMappedThreads && dim >= 0;
385-
++i, --dim) {
386-
auto id = mapping::ThreadId::makeId(nInner + i);
368+
// from thread x unless it was already mapped to a reduction.
369+
for (int i = 0; i < nMappedThreads; ++i) {
370+
auto id = mapping::ThreadId::makeId(nMappedReductionThreads + i);
371+
auto dim = nOuterCoincident - 1 - i;
387372
if (id == mapping::ThreadId::x()) {
388373
threadIdxXScheduleDepthState.emplace_back(std::make_pair(
389374
activeDomainPoints(schedule(), band),
@@ -392,11 +377,7 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band, size_t nInner) {
392377
band = map(band, dim, id);
393378
}
394379

395-
if (nInner == 0) {
396-
markUnroll(scop_->scheduleRoot(), band, unroll);
397-
}
398-
399-
return nInner + nMappedThreads;
380+
return nMappedReductionThreads + nMappedThreads;
400381
}
401382

402383
namespace {
@@ -426,8 +407,12 @@ bool hasOuterSequentialMember(
426407
}
427408
} // namespace
428409

429-
// Maps bands to threads in DFS postorder, keeping track of
430-
// the (maximal) number of threads already mapped by descendants.
410+
// Maps bands to threads in DFS postorder.
411+
// Mapping is only allowed if descendants are not already mapped to threads.
412+
// Mapping nested bands to threads is invalid because members of those bands
413+
// are not necessarily permutable, and there is no guaranteed nesting between
414+
// thread dimensions (e.g., there is no guarantee that all threads with
415+
// threadIdx.y=0 will be executed before any thread with threadIdx.y=1).
431416
//
432417
// If any separation is needed for mapping reductions to full blocks,
433418
// then do so first.
@@ -479,8 +464,28 @@ size_t MappedScop::mapInnermostBandsToThreads(detail::ScheduleTree* st) {
479464
}
480465
}
481466

482-
if (st->elemAs<detail::ScheduleTreeElemBand>()) {
483-
n = mapToThreads(st, n);
467+
if (auto band = st->elemAs<detail::ScheduleTreeElemBand>()) {
468+
if (n == 0) {
469+
// If children were not mapped to threads, the current band can be mapped.
470+
// First, map the coincidence and reduction dimension to threads.
471+
// Then, if some threads were mapped, fix unused thread dimensions to 0
472+
// because we cannot map parent bands anyway.
473+
auto nMapped = mapToThreads(st);
474+
if (nMapped > 0) {
475+
mapRemaining<mapping::ThreadId>(st, nMapped, numThreads.view.size());
476+
markUnroll(scop_->scheduleRoot(), st, unroll);
477+
return numThreads.view.size();
478+
}
479+
} else if (anyNonCoincidentMember(band)) {
480+
// If children were mapped to threads, and this band has a non-coincident
481+
// member, insert a synchronization after its last child.
482+
// The node must have children if some of them were mapped to threads,
483+
// double-check. Note that a band node has at most one child.
484+
CHECK_EQ(st->numChildren(), 1);
485+
// The mapping should be always complete, double-check.
486+
CHECK_EQ(n, numThreads.view.size());
487+
scop_->insertSyncAfter(st->child({0}));
488+
}
484489
}
485490

486491
return n;
@@ -587,6 +592,22 @@ std::tuple<std::string, tc::Grid, tc::Block> MappedScop::codegen(
587592
mappedScopForCodegen->numThreads);
588593
}
589594

595+
// Split out reduction loops into separate bands and insert reduction
596+
// synchronizations outside those bands.
597+
void MappedScop::splitOutReductionsAndInsertSyncs() {
598+
using namespace polyhedral::detail;
599+
600+
for (auto bandUpdate : reductionBandUpdates_) {
601+
auto tree = bandSplitOut(
602+
scop_->scheduleRoot(),
603+
const_cast<ScheduleTree*>(bandUpdate.first),
604+
bandUpdate.second.reductionDim);
605+
for (auto updateId : bandUpdate.second.ids) {
606+
scop_->insertReductionSync1D(tree, updateId);
607+
}
608+
}
609+
}
610+
590611
std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
591612
std::unique_ptr<Scop>&& scopUPtr,
592613
const CudaMappingOptions& cudaOptions) {
@@ -650,7 +671,13 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
650671
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "After mapping to blocks:" << std::endl
651672
<< *mappedScop->schedule();
652673

653-
// 7. Promote to shared memory below the loops mapped to blocks.
674+
// 7. Insert reduction synchronizations if necessary.
675+
mappedScop->splitOutReductionsAndInsertSyncs();
676+
LOG_IF(INFO, FLAGS_debug_tc_mapper)
677+
<< "After inserting reduction synchronization:" << std::endl
678+
<< *mappedScop->schedule();
679+
680+
// 8. Promote to shared memory below the loops mapped to blocks.
654681
// This may split the outer band, so find the new outer band after promotion.
655682
if (cudaOptions.proto().use_shared_memory()) {
656683
size_t sharedMemorySize = cudaOptions.proto().has_max_shared_memory()
@@ -697,24 +724,16 @@ std::unique_ptr<MappedScop> MappedScop::makeWithOuterBlockInnerThreadStrategy(
697724
}
698725
}
699726

700-
// 8. Promote to registers below the loops mapped to threads.
727+
// 9. Promote to registers below the loops mapped to threads.
701728
if (cudaOptions.proto().use_private_memory()) {
702729
promoteToRegistersBelowThreads(
703730
mappedScop->scop(), mappedScop->threadIdxXScheduleDepthState, -1ull);
704731
}
705732

706-
// 9. Insert mapping context
733+
// 10. Insert mapping context
707734
mappedScop->insertMappingContext();
708-
709-
// 10. Optionally insert reduction synchronizations
710-
for (auto bandUpdate : mappedScop->reductionBandUpdates_) {
711-
for (auto updateId : bandUpdate.second.ids) {
712-
scop->insertReductionSync1D(
713-
const_cast<ScheduleTree*>(bandUpdate.first), updateId);
714-
}
715-
}
716735
LOG_IF(INFO, FLAGS_debug_tc_mapper)
717-
<< "After inserting reduction synchronization:" << std::endl
736+
<< "After outerBlockInnerThread strategy:" << std::endl
718737
<< *mappedScop->schedule();
719738

720739
return mappedScop;

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ class MappedScop {
142142
detail::ScheduleTree* band,
143143
std::vector<size_t> tileSizes);
144144
// Look for innermost reduction band members.
145-
// Store them in reductionBandUpdates_ and their parents
146-
// in reductionFromParent_. Return true if any were found.
145+
// Store them in reductionBandUpdates_.
146+
// Return true if any were found.
147147
bool detectReductions(detail::ScheduleTree* band);
148148
// Does separateReduction need to be called on this node?
149149
bool needReductionSeparation(const detail::ScheduleTree* st);
@@ -155,13 +155,12 @@ class MappedScop {
155155
// The remaining parts, if any, are no longer considered for replacement
156156
// by a library call.
157157
detail::ScheduleTree* separateReduction(detail::ScheduleTree* band);
158-
// Map "band" to thread identifiers, assuming "nInner" thread identifiers
159-
// have already been used and using as many remaining blockSizes values as
160-
// outer coincident dimensions,
161-
// unroll band members that execute at most "unroll" instances
162-
// (if nInner == 0) and
163-
// return the updated number of mapped thread identifiers.
164-
size_t mapToThreads(detail::ScheduleTree* band, size_t nInner);
158+
// Split out reduction bands and insert reduction synchronizations.
159+
void splitOutReductionsAndInsertSyncs();
160+
// Map "band" to thread identifiers using as many blockSizes values as outer
161+
// coincident dimensions, unroll band members that execute at most "unroll"
162+
// instances and return the number of mapped thread identifiers.
163+
size_t mapToThreads(detail::ScheduleTree* band);
165164
// Map innermost bands to thread identifiers and
166165
// return the number of mapped thread identifiers.
167166
size_t mapInnermostBandsToThreads(detail::ScheduleTree* st);
@@ -185,17 +184,15 @@ class MappedScop {
185184
// Information about a detected reduction that can potentially
186185
// be mapped to a library call.
187186
struct Reduction {
188-
Reduction(std::vector<isl::id> ids) : ids(ids), separated(false) {}
187+
Reduction(std::vector<isl::id> ids, size_t index)
188+
: ids(ids), separated(false), reductionDim(index) {}
189189
// The statement identifiers of the reduction update statements.
190190
std::vector<isl::id> ids;
191191
// Has the reduction been separated out as a full block?
192192
bool separated;
193+
// Index of the band member in which the reduction was detected.
194+
int reductionDim;
193195
};
194-
// Map parent band of reduction band to the reduction band.
195-
// As a special case, the parent band may be missing,
196-
// in which case it is the reduction band that gets mapped to itself.
197-
std::unordered_map<const detail::ScheduleTree*, const detail::ScheduleTree*>
198-
reductionFromParent_;
199196
// Map isolated innermost reduction band members to information
200197
// about the detected reduction.
201198
std::map<const detail::ScheduleTree*, Reduction> reductionBandUpdates_;

0 commit comments

Comments (0)