Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2d73ed1

Browse files
Sven Verdoolaege authored and Theodoros Theodoridis committed
MappedScop::findBestSync: use proper mapping to threads and warps
In particular, use mappings to threads and warps to determine whether all dependences are within a thread or warp. Do so instead of considering pairs of thread identifier parameters.

Duplicating thread identifier parameters is confusing at best and relies on renaming parameters, which is frowned upon and which will not be exported in the mainline isl C++ interface. Duplicate parameters are confusing because there is only one value for any given thread identifier at a given point in the execution. Furthermore, the filter assigning the thread mapping to parameters is only relevant during AST generation. Prior to that, it is more natural to think in terms of the thread mapping itself. This is what this commit does.

It is not only simpler, but also shorter.

It is not entirely clear why the intersection with the context is needed, but this is what the original code does and it is preserved by this commit.
1 parent c5c4365 commit 2d73ed1

File tree

2 files changed

+78
-129
lines changed

2 files changed

+78
-129
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 73 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -431,119 +431,84 @@ bool hasOuterSequentialMember(
431431
return false;
432432
}
433433

434-
// Intersect the union set with all the mapping
435-
// filters params in the given schedule tree
436-
isl::union_set intersectMappingFilterParams(
437-
detail::ScheduleTree* st,
438-
isl::union_set us) {
439-
if (auto filter = st->elemAsBase<detail::ScheduleTreeElemFilter>()) {
440-
us = us.intersect(filter->filter_);
441-
}
434+
// Name of the space of threads inside a block
435+
constexpr auto kBlock = "block";
436+
// Name of the space of warps
437+
constexpr auto kWarp = "warp";
442438

443-
auto children = st->children();
444-
auto nChildren = children.size();
445-
if (nChildren == 1) {
446-
us = intersectMappingFilterParams(children[0], us);
447-
} else if (nChildren > 1) {
448-
auto usParent = us;
449-
us = intersectMappingFilterParams(children[0], us);
450-
for (size_t i = 1; i < nChildren; ++i) {
451-
us = us.unite(intersectMappingFilterParams(children[i], usParent));
452-
}
453-
}
454-
455-
return us;
456-
}
439+
/*
440+
* Extract a mapping from the domain elements active at "tree"
441+
* to the thread identifiers, where all branches in "tree"
442+
* are assumed to have been mapped to thread identifiers.
443+
* "nThread" is the number of thread identifiers.
444+
* The result lives in a space of the form block[x, ...].
445+
*/
446+
isl::multi_union_pw_aff extractDomainToThread(
447+
const detail::ScheduleTree* tree,
448+
size_t nThread) {
449+
using namespace polyhedral::detail;
457450

458-
// Change the name of the isl ids tied to threads and blocks
459-
// by adding a suffix
460-
isl::union_set modifyMappingNames(
461-
isl::union_set set,
462-
const std::string suffix) {
463-
USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
464-
std::unordered_set<isl::id, isl::IslIdIslHash> identifiers{
465-
BX, BY, BZ, TX, TY, TZ};
466-
467-
auto space = set.get_space();
468-
for (auto id : identifiers) {
469-
auto name = id.get_name();
470-
auto dim = space.find_dim_by_name(isl::dim_type::param, id.get_name());
471-
CHECK_LE(0, dim);
472-
space = space.set_dim_name(isl::dim_type::param, dim, name + suffix);
473-
}
474-
auto newSet = isl::union_set::empty(space);
475-
set.foreach_set([&newSet, &identifiers, &suffix](isl::set setInFun) {
476-
for (auto id : identifiers) {
477-
auto name = id.get_name();
478-
auto dim =
479-
setInFun.get_space().find_dim_by_name(isl::dim_type::param, name);
480-
CHECK_LE(0, dim);
481-
setInFun =
482-
setInFun.set_dim_name(isl::dim_type::param, dim, name + suffix);
451+
auto space = isl::space(tree->ctx_, 0);
452+
auto empty = isl::union_set::empty(space);
453+
auto id = isl::id(tree->ctx_, kBlock);
454+
space = space.named_set_from_params_id(id, nThread);
455+
auto zero = isl::multi_val::zero(space);
456+
auto domainToThread = isl::multi_union_pw_aff(empty, zero);
457+
458+
for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) {
459+
auto mappingNode = mapping->elemAs<ScheduleTreeElemMappingFilter>();
460+
auto list = isl::union_pw_aff_list(tree->ctx_, nThread);
461+
for (size_t i = 0; i < nThread; ++i) {
462+
auto threadId = mapping::ThreadId::makeId(i);
463+
auto threadMap = mappingNode->mapping.at(threadId);
464+
list = list.add(threadMap);
483465
}
484-
newSet = newSet.unite(setInFun);
485-
});
486-
return newSet;
487-
}
488-
489-
// Get the formula computing the linearized index of a thread in a block.
490-
isl::aff getLinearizedThreadIdxFormula(
491-
isl::space space,
492-
const Block& block,
493-
const std::string& suffix = "") {
494-
USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
495-
std::vector<std::pair<isl::id, unsigned>> mappingIds{
496-
{TX, TX.mappingSize(block)},
497-
{TY, TY.mappingSize(block)},
498-
{TZ, TZ.mappingSize(block)}};
499-
500-
isl::aff formula = isl::aff(isl::local_space(space));
501-
502-
for (int i = (int)mappingIds.size() - 1; i >= 0; --i) {
503-
auto name = mappingIds[i].first.to_str();
504-
auto dim = space.find_dim_by_name(isl::dim_type::param, name + suffix);
505-
CHECK_LE(0, dim);
506-
auto id = space.get_dim_id(isl::dim_type::param, dim);
507-
isl::aff aff(isl::aff::param_on_domain_space(space, id));
508-
formula = formula * mappingIds[i].second + aff;
466+
auto nodeToThread = isl::multi_union_pw_aff(space, list);
467+
domainToThread = domainToThread.union_add(nodeToThread);
509468
}
510469

511-
return formula;
470+
return domainToThread;
512471
}
513472

514-
// Return the constraints ensuring that the points with parameters
515-
// [t0,t1,t2] and [t0',t1',t2'] are in the same warp.
516-
// (where t0 is "t0" + suffix1 and t0' is "t0" + suffix2)
517-
// if suffix1 is "_1" and suffix2 is "_2", the constraint is in the form
518-
// ((t0_1 + a * t1_1 + b * t2_1) / warpSize).floor()
519-
// == ((t0_2 + a' * t1_2 + b' * t1_2) / warpSize).floor()
520-
// with t0_1 + a * t1_1 + b * t2_1 the linearized formula of the thread index.
521-
// This function returns a set because it might change in the future,
522-
// and take into account the blocks.
523-
isl::set getSameWarpConstraints(
524-
isl::space space,
525-
const std::string& suffix1,
526-
const std::string& suffix2,
473+
/*
474+
* Construct a mapping
475+
*
476+
* block[x] -> warp[floor((x)/warpSize)]
477+
* block[x, y] -> warp[floor((x + s_x * (y))/warpSize)]
478+
* block[x, y, z] -> warp[floor((x + s_x * (y + s_y * (z)))/warpSize)]
479+
*
480+
* uniquely mapping thread identifiers that belong to the same warp
481+
* (of size "warpSize") to a warp identifier,
482+
* based on the thread sizes s_x, s_y up to s_z in "block".
483+
*/
484+
isl::multi_aff constructThreadToWarp(
485+
isl::ctx ctx,
527486
const unsigned warpSize,
528487
const Block& block) {
529-
USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
530-
std::vector<std::pair<isl::id, unsigned>> mappingIds{
531-
{TX, TX.mappingSize(block)},
532-
{TY, TY.mappingSize(block)},
533-
{TZ, TZ.mappingSize(block)}};
534-
535-
auto formula1 = getLinearizedThreadIdxFormula(space, block, suffix1);
536-
auto formula2 = getLinearizedThreadIdxFormula(space, block, suffix2);
488+
auto space = isl::space(ctx, 0);
489+
auto id = isl::id(ctx, kBlock);
490+
auto blockSpace = space.named_set_from_params_id(id, block.view.size());
491+
auto warpSpace = space.named_set_from_params_id(isl::id(ctx, kWarp), 1);
492+
auto aff = isl::aff::zero_on_domain(blockSpace);
493+
494+
auto nThread = block.view.size();
495+
auto identity = isl::multi_aff::identity(blockSpace.map_from_set());
496+
for (int i = nThread - 1; i >= 0; --i) {
497+
aff = aff.scale(isl::val(ctx, block.view[i]));
498+
aff = aff.add(identity.get_aff(i));
499+
}
537500

538-
return (
539-
isl::aff_set((formula1 / warpSize).floor()) ==
540-
(formula2 / warpSize).floor());
501+
aff = aff.scale_down(isl::val(ctx, warpSize)).floor();
502+
auto mapSpace = blockSpace.product(warpSpace).unwrap();
503+
return isl::multi_aff(mapSpace, isl::aff_list(aff));
541504
}
542505
} // namespace
543506

544507
Scop::SyncLevel MappedScop::findBestSync(
545508
detail::ScheduleTree* st1,
546-
detail::ScheduleTree* st2) {
509+
detail::ScheduleTree* st2,
510+
isl::multi_union_pw_aff domainToThread,
511+
isl::multi_union_pw_aff domainToWarp) {
547512
// Active points in the two schedule trees
548513
auto stRoot = scop_->scheduleRoot();
549514
auto activePoints1 = activeDomainPointsBelow(stRoot, st1);
@@ -557,41 +522,16 @@ Scop::SyncLevel MappedScop::findBestSync(
557522
return Scop::SyncLevel::None;
558523
}
559524

560-
// The domain and the context of the root schedule tree
561-
auto domainAndContext = scop_->domain();
562525
CHECK_LE(1u, scop_->scheduleRoot()->children().size());
563526
auto contextSt = scop_->scheduleRoot()->children()[0];
564527
auto contextElem = contextSt->elemAs<detail::ScheduleTreeElemContext>();
565528
CHECK(nullptr != contextElem);
566-
domainAndContext = domainAndContext.intersect_params(contextElem->context_);
529+
dependences = dependences.intersect_params(contextElem->context_);
567530

568-
// The domain of both schedule trees filtered by mapping filters,
569-
// and then modified to have different threads and blocks names.
570-
auto domain1 = intersectMappingFilterParams(st1, domainAndContext);
571-
auto domain2 = intersectMappingFilterParams(st2, domainAndContext);
572-
auto suffix1 = "_1";
573-
auto suffix2 = "_2";
574-
domain1 = modifyMappingNames(domain1, suffix1);
575-
domain2 = modifyMappingNames(domain2, suffix2);
576-
577-
// The dependences between the two schedule trees
578-
// with mapping from threads and blocks
579-
auto mappedDependences =
580-
isl::union_map::from_domain_and_range(domain1, domain2);
581-
mappedDependences = mappedDependences.intersect(dependences);
582-
583-
auto space = mappedDependences.get_space();
584-
auto sameThreadConstraint =
585-
getSameWarpConstraints(space, suffix1, suffix2, 1, numThreads);
586-
auto sameWarpConstraints =
587-
getSameWarpConstraints(space, suffix1, suffix2, 32, numThreads);
588-
589-
if (mappedDependences ==
590-
mappedDependences.intersect_params(sameThreadConstraint)) {
531+
if (dependences.is_subset(dependences.eq_at(domainToThread))) {
591532
return Scop::SyncLevel::None;
592-
} else if (
593-
mappedDependences ==
594-
mappedDependences.intersect_params(sameWarpConstraints)) {
533+
}
534+
if (dependences.is_subset(dependences.eq_at(domainToWarp))) {
595535
return Scop::SyncLevel::Warp;
596536
}
597537
return Scop::SyncLevel::Block;
@@ -754,6 +694,10 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
754694

755695
auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), seq);
756696

697+
auto domainToThread = extractDomainToThread(seq, numThreads.view.size());
698+
auto threadToWarp = constructThreadToWarp(seq->ctx_, 32, numThreads);
699+
auto domainToWarp = domainToThread.apply(threadToWarp);
700+
757701
std::vector<std::vector<int>> bestSync(
758702
nChildren, std::vector<int>(nChildren + 1));
759703
// Get the synchronization needed between children[i] and children[i+k]
@@ -765,7 +709,8 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
765709
for (size_t i = 0; i < nChildren; ++i) {
766710
for (size_t k = 0; k < nChildren; ++k) {
767711
auto ik = (i + k) % nChildren;
768-
bestSync[i][k] = (int)findBestSync(children[i], children[ik]);
712+
bestSync[i][k] = (int)findBestSync(
713+
children[i], children[ik], domainToThread, domainToWarp);
769714
}
770715
}
771716

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,13 @@ class MappedScop {
163163
// st1.
164164
// This function assumes that it is called before block mapping
165165
// and that st1 and st2 are already mapped to threads.
166+
// "domainToThread" and "domainToWarp" map the domain elements
167+
// of st1 and st2 to thread and warp identifiers, respectively.
166168
Scop::SyncLevel findBestSync(
167169
detail::ScheduleTree* st1,
168-
detail::ScheduleTree* st2);
170+
detail::ScheduleTree* st2,
171+
isl::multi_union_pw_aff domainToThread,
172+
isl::multi_union_pw_aff domainToWarp);
169173

170174
public:
171175
// Find best configuration of synchronizations in a sequence, minimizing

0 commit comments

Comments
 (0)