This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 9fb6a37

Merge pull request #444 from facebookresearch/pr/sync
MappedScop::findBestSync: use proper mapping to threads and warps
2 parents: 86c1f8c + e3323ed

6 files changed (+111, -164 lines)

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 76 additions & 131 deletions
@@ -431,167 +431,107 @@ bool hasOuterSequentialMember(
   return false;
 }
 
-// Intersect the union set with all the mapping
-// filters params in the given schedule tree
-isl::union_set intersectMappingFilterParams(
-    detail::ScheduleTree* st,
-    isl::union_set us) {
-  if (auto filter = st->elemAsBase<detail::ScheduleTreeElemFilter>()) {
-    us = us.intersect(filter->filter_);
-  }
+// Name of the space of threads inside a block
+constexpr auto kBlock = "block";
+// Name of the space of warps
+constexpr auto kWarp = "warp";
 
-  auto children = st->children();
-  auto nChildren = children.size();
-  if (nChildren == 1) {
-    us = intersectMappingFilterParams(children[0], us);
-  } else if (nChildren > 1) {
-    auto usParent = us;
-    us = intersectMappingFilterParams(children[0], us);
-    for (size_t i = 1; i < nChildren; ++i) {
-      us = us.unite(intersectMappingFilterParams(children[i], usParent));
-    }
-  }
-
-  return us;
-}
+/*
+ * Extract a mapping from the domain elements active at "tree"
+ * to the thread identifiers, where all branches in "tree"
+ * are assumed to have been mapped to thread identifiers.
+ * "nThread" is the number of thread identifiers.
+ * The result lives in a space of the form block[x, ...].
+ */
+isl::multi_union_pw_aff extractDomainToThread(
+    const detail::ScheduleTree* tree,
+    size_t nThread) {
+  using namespace polyhedral::detail;
 
-// Change the name of the isl ids tied to threads and blocks
-// by adding a suffix
-isl::union_set modifyMappingNames(
-    isl::union_set set,
-    const std::string suffix) {
-  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
-  std::unordered_set<isl::id, isl::IslIdIslHash> identifiers{
-      BX, BY, BZ, TX, TY, TZ};
-
-  auto space = set.get_space();
-  for (auto id : identifiers) {
-    auto name = id.get_name();
-    auto dim = space.find_dim_by_name(isl::dim_type::param, id.get_name());
-    CHECK_LE(0, dim);
-    space = space.set_dim_name(isl::dim_type::param, dim, name + suffix);
-  }
-  auto newSet = isl::union_set::empty(space);
-  set.foreach_set([&newSet, &identifiers, &suffix](isl::set setInFun) {
-    for (auto id : identifiers) {
-      auto name = id.get_name();
-      auto dim =
-          setInFun.get_space().find_dim_by_name(isl::dim_type::param, name);
-      CHECK_LE(0, dim);
-      setInFun =
-          setInFun.set_dim_name(isl::dim_type::param, dim, name + suffix);
+  auto space = isl::space(tree->ctx_, 0);
+  auto empty = isl::union_set::empty(space);
+  auto id = isl::id(tree->ctx_, kBlock);
+  space = space.named_set_from_params_id(id, nThread);
+  auto zero = isl::multi_val::zero(space);
+  auto domainToThread = isl::multi_union_pw_aff(empty, zero);
+
+  for (auto mapping : tree->collect(tree, ScheduleTreeType::MappingFilter)) {
+    auto mappingNode = mapping->elemAs<ScheduleTreeElemMappingFilter>();
+    auto list = isl::union_pw_aff_list(tree->ctx_, nThread);
+    for (size_t i = 0; i < nThread; ++i) {
+      auto threadId = mapping::ThreadId::makeId(i);
+      auto threadMap = mappingNode->mapping.at(threadId);
+      list = list.add(threadMap);
     }
-    newSet = newSet.unite(setInFun);
-  });
-  return newSet;
-}
-
-// Get the formula computing the linearized index of a thread in a block.
-isl::aff getLinearizedThreadIdxFormula(
-    isl::space space,
-    const Block& block,
-    const std::string& suffix = "") {
-  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
-  std::vector<std::pair<isl::id, unsigned>> mappingIds{
-      {TX, TX.mappingSize(block)},
-      {TY, TY.mappingSize(block)},
-      {TZ, TZ.mappingSize(block)}};
-
-  isl::aff formula = isl::aff(isl::local_space(space));
-
-  for (int i = (int)mappingIds.size() - 1; i >= 0; --i) {
-    auto name = mappingIds[i].first.to_str();
-    auto dim = space.find_dim_by_name(isl::dim_type::param, name + suffix);
-    CHECK_LE(0, dim);
-    auto id = space.get_dim_id(isl::dim_type::param, dim);
-    isl::aff aff(isl::aff::param_on_domain_space(space, id));
-    formula = formula * mappingIds[i].second + aff;
+    auto nodeToThread = isl::multi_union_pw_aff(space, list);
+    domainToThread = domainToThread.union_add(nodeToThread);
   }
 
-  return formula;
+  return domainToThread;
 }
 
-// Return the constraints ensuring that the points with parameters
-// [t0,t1,t2] and [t0',t1',t2'] are in the same warp.
-// (where t0 is "t0" + suffix1 and t0' is "t0" + suffix2)
-// if suffix1 is "_1" and suffix2 is "_2", the constraint is in the form
-// ((t0_1 + a * t1_1 + b * t2_1) / warpSize).floor()
-// == ((t0_2 + a' * t1_2 + b' * t1_2) / warpSize).floor()
-// with t0_1 + a * t1_1 + b * t2_1 the linearized formula of the thread index.
-// This function returns a set because it might change in the future,
-// and take into account the blocks.
-isl::set getSameWarpConstraints(
-    isl::space space,
-    const std::string& suffix1,
-    const std::string& suffix2,
+/*
+ * Construct a mapping
+ *
+ * block[x] -> warp[floor((x)/warpSize)]
+ * block[x, y] -> warp[floor((x + s_x * (y))/warpSize)]
+ * block[x, y, z] -> warp[floor((x + s_x * (y + s_y * (z)))/warpSize)]
+ *
+ * uniquely mapping thread identifiers that belong to the same warp
+ * (of size "warpSize") to a warp identifier,
+ * based on the thread sizes s_x, s_y up to s_z in "block".
+ */
+isl::multi_aff constructThreadToWarp(
+    isl::ctx ctx,
     const unsigned warpSize,
     const Block& block) {
-  USING_MAPPING_SHORT_NAMES(BX, BY, BZ, TX, TY, TZ);
-  std::vector<std::pair<isl::id, unsigned>> mappingIds{
-      {TX, TX.mappingSize(block)},
-      {TY, TY.mappingSize(block)},
-      {TZ, TZ.mappingSize(block)}};
-
-  auto formula1 = getLinearizedThreadIdxFormula(space, block, suffix1);
-  auto formula2 = getLinearizedThreadIdxFormula(space, block, suffix2);
+  auto space = isl::space(ctx, 0);
+  auto id = isl::id(ctx, kBlock);
+  auto blockSpace = space.named_set_from_params_id(id, block.view.size());
+  auto warpSpace = space.named_set_from_params_id(isl::id(ctx, kWarp), 1);
+  auto aff = isl::aff::zero_on_domain(blockSpace);
+
+  auto nThread = block.view.size();
+  auto identity = isl::multi_aff::identity(blockSpace.map_from_set());
+  for (int i = nThread - 1; i >= 0; --i) {
+    aff = aff.scale(isl::val(ctx, block.view[i]));
+    aff = aff.add(identity.get_aff(i));
+  }
 
-  return (
-      isl::aff_set((formula1 / warpSize).floor()) ==
-      (formula2 / warpSize).floor());
+  aff = aff.scale_down(isl::val(ctx, warpSize)).floor();
+  auto mapSpace = blockSpace.product(warpSpace).unwrap();
+  return isl::multi_aff(mapSpace, isl::aff_list(aff));
 }
 } // namespace
 
 Scop::SyncLevel MappedScop::findBestSync(
     detail::ScheduleTree* st1,
-    detail::ScheduleTree* st2) {
+    detail::ScheduleTree* st2,
+    isl::multi_union_pw_aff domainToThread,
+    isl::multi_union_pw_aff domainToWarp) {
   // Active points in the two schedule trees
   auto stRoot = scop_->scheduleRoot();
   auto activePoints1 = activeDomainPointsBelow(stRoot, st1);
   auto activePoints2 = activeDomainPointsBelow(stRoot, st2);
 
   // The dependences between the two schedule trees
-  auto dependences =
-      isl::union_map::from_domain_and_range(activePoints1, activePoints2);
-  dependences = dependences.intersect(scop_->dependences);
+  auto dependences = scop_->dependences;
+  dependences = dependences.intersect_domain(activePoints1);
+  dependences = dependences.intersect_range(activePoints2);
   if (dependences.is_empty()) {
     return Scop::SyncLevel::None;
   }
 
-  // The domain and the context of the root schedule tree
-  auto domainAndContext = scop_->domain();
   CHECK_LE(1u, scop_->scheduleRoot()->children().size());
   auto contextSt = scop_->scheduleRoot()->children()[0];
   auto contextElem = contextSt->elemAs<detail::ScheduleTreeElemContext>();
   CHECK(nullptr != contextElem);
-  domainAndContext = domainAndContext.intersect_params(contextElem->context_);
+  dependences = dependences.intersect_params(contextElem->context_);
 
-  // The domain of both schedule trees filtered by mapping filters,
-  // and then modified to have different threads and blocks names.
-  auto domain1 = intersectMappingFilterParams(st1, domainAndContext);
-  auto domain2 = intersectMappingFilterParams(st2, domainAndContext);
-  auto suffix1 = "_1";
-  auto suffix2 = "_2";
-  domain1 = modifyMappingNames(domain1, suffix1);
-  domain2 = modifyMappingNames(domain2, suffix2);
-
-  // The dependences between the two schedule trees
-  // with mapping from threads and blocks
-  auto mappedDependences =
-      isl::union_map::from_domain_and_range(domain1, domain2);
-  mappedDependences = mappedDependences.intersect(dependences);
-
-  auto space = mappedDependences.get_space();
-  auto sameThreadConstraint =
-      getSameWarpConstraints(space, suffix1, suffix2, 1, numThreads);
-  auto sameWarpConstraints =
-      getSameWarpConstraints(space, suffix1, suffix2, 32, numThreads);
-
-  if (mappedDependences ==
-      mappedDependences.intersect_params(sameThreadConstraint)) {
+  if (dependences.is_subset(dependences.eq_at(domainToThread))) {
     return Scop::SyncLevel::None;
-  } else if (
-      mappedDependences ==
-      mappedDependences.intersect_params(sameWarpConstraints)) {
+  }
+  if (dependences.is_subset(dependences.eq_at(domainToWarp))) {
     return Scop::SyncLevel::Warp;
   }
   return Scop::SyncLevel::Block;
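
To make the warp criterion concrete, the following standalone sketch (not part of the commit; block sizes, thread coordinates, and names are made up) spells out the arithmetic that constructThreadToWarp encodes as an isl::multi_aff and that the eq_at checks above rely on: no synchronization if both endpoints of a dependence run on the same thread, a warp-level sync if their linearized thread indices fall into the same warp of warpSize threads, and a block-level sync otherwise.

// Standalone sketch, not TC code: classify the synchronization needed for a
// single dependence whose endpoints run on threads "a" and "b" of a block
// with sizes s = {s_x, s_y, s_z}.
#include <array>
#include <cstdio>

enum class SyncLevel { None, Warp, Block };

// Linearized thread index x + s_x * (y + s_y * z), matching the
// block[x, y, z] -> warp[floor(.../warpSize)] comment in constructThreadToWarp.
unsigned linearize(
    const std::array<unsigned, 3>& t,
    const std::array<unsigned, 3>& s) {
  unsigned linear = 0;
  for (int i = 2; i >= 0; --i) {
    linear = linear * s[i] + t[i];
  }
  return linear;
}

SyncLevel classify(
    const std::array<unsigned, 3>& a,
    const std::array<unsigned, 3>& b,
    const std::array<unsigned, 3>& s,
    unsigned warpSize = 32) {
  if (a == b) {
    return SyncLevel::None; // same thread: no synchronization needed
  }
  if (linearize(a, s) / warpSize == linearize(b, s) / warpSize) {
    return SyncLevel::Warp; // same warp: warp-level synchronization suffices
  }
  return SyncLevel::Block; // different warps: block-level synchronization
}

int main() {
  std::array<unsigned, 3> sizes{{32, 4, 1}}; // hypothetical thread block sizes
  // Threads (5,0,0) and (17,0,0) linearize to 5 and 17: same warp.
  std::printf("%d\n", static_cast<int>(classify({{5, 0, 0}}, {{17, 0, 0}}, sizes)));
  // Threads (5,0,0) and (5,1,0) linearize to 5 and 37: different warps.
  std::printf("%d\n", static_cast<int>(classify({{5, 0, 0}}, {{5, 1, 0}}, sizes)));
}

The commit performs this test for all dependences at once, as isl relations over the mapped domains, rather than per pair of thread coordinates.
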
@@ -754,6 +694,10 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
 
   auto outer = hasOuterSequentialMember(scop_->scheduleRoot(), seq);
 
+  auto domainToThread = extractDomainToThread(seq, numThreads.view.size());
+  auto threadToWarp = constructThreadToWarp(seq->ctx_, 32, numThreads);
+  auto domainToWarp = domainToThread.apply(threadToWarp);
+
   std::vector<std::vector<int>> bestSync(
       nChildren, std::vector<int>(nChildren + 1));
   // Get the synchronization needed between children[i] and children[i+k]
@@ -765,7 +709,8 @@ void MappedScop::insertBestSyncInSeq(detail::ScheduleTree* seq) {
   for (size_t i = 0; i < nChildren; ++i) {
     for (size_t k = 0; k < nChildren; ++k) {
       auto ik = (i + k) % nChildren;
-      bestSync[i][k] = (int)findBestSync(children[i], children[ik]);
+      bestSync[i][k] = (int)findBestSync(
+          children[i], children[ik], domainToThread, domainToWarp);
     }
   }
 
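
As a worked example of the composition above (the statement and its mapping are hypothetical, not taken from the commit): suppose the sequence's only statement is mapped as

    domainToThread: { S[i] -> block[i mod 64, 0] }

for thread sizes s_x = 64, s_y = 1. Then constructThreadToWarp(seq->ctx_, 32, numThreads) builds

    threadToWarp: { block[x, y] -> warp[floor((x + 64 * y) / 32)] }

and domainToThread.apply(threadToWarp) yields

    domainToWarp: { S[i] -> warp[floor((i mod 64) / 32)] }

findBestSync then only has to compare these per-instance thread and warp coordinates across each dependence: if every dependence relates instances with equal block[...] values it returns None, if only the warp[...] values agree it returns Warp, and otherwise Block.
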

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 5 additions & 1 deletion
@@ -163,9 +163,13 @@ class MappedScop {
   // st1.
   // This function assumes that it is called before block mapping
   // and that st1 and st2 are already mapped to threads.
+  // "domainToThread" and "domainToWarp" map the domain elements
+  // of st1 and st2 to thread and warp identifiers, respectively.
   Scop::SyncLevel findBestSync(
       detail::ScheduleTree* st1,
-      detail::ScheduleTree* st2);
+      detail::ScheduleTree* st2,
+      isl::multi_union_pw_aff domainToThread,
+      isl::multi_union_pw_aff domainToWarp);
 
  public:
   // Find best configuration of synchronizations in a sequence, minimizing

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 2 additions & 2 deletions
@@ -46,8 +46,8 @@ bool isThreadMapping(const detail::ScheduleTree* tree) {
   using namespace detail;
 
   if (auto filterNode = tree->elemAs<ScheduleTreeElemMappingFilter>()) {
-    for (auto id : filterNode->mappingIds) {
-      if (isThreadId(id)) {
+    for (auto& kvp : filterNode->mapping) {
+      if (isThreadId(kvp.first)) {
         return true;
       }
     }
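
The updated loop only inspects the keys of the mapping; the mapped affine functions are irrelevant to this predicate. A standalone sketch of the same pattern with std::any_of and hypothetical types (plain std::string keys instead of mapping::MappingId, int values instead of isl::union_pw_aff):

// Standalone sketch, hypothetical types: a filter counts as a thread mapping
// as soon as any key of its mapping is a thread identifier.
#include <algorithm>
#include <string>
#include <unordered_map>

bool isThreadId(const std::string& id) {
  return id == "t0" || id == "t1" || id == "t2"; // stand-in for TC's isThreadId
}

bool isThreadMapping(const std::unordered_map<std::string, int>& mapping) {
  return std::any_of(
      mapping.begin(),
      mapping.end(),
      [](const std::pair<const std::string, int>& kvp) {
        return isThreadId(kvp.first);
      });
}
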

tc/core/polyhedral/schedule_print.cc

Lines changed: 2 additions & 2 deletions
@@ -202,8 +202,8 @@ std::ostream& ScheduleTreeElemFilter::write(std::ostream& os) const {
 std::ostream& ScheduleTreeElemMappingFilter::write(std::ostream& os) const {
   WS w;
   os << w.tab() << "mapping_filter(ids(";
-  for (const auto& id : mappingIds) {
-    os << id << ", ";
+  for (auto& kvp : mapping) {
+    os << kvp.first << ", ";
   }
   os << ")";
   for (const auto& u : filter_.get_set_list()) {

tc/core/polyhedral/schedule_tree-inl.h

Lines changed: 9 additions & 5 deletions
@@ -23,15 +23,19 @@ inline ScheduleTreeUPtr ScheduleTree::makeMappingFilter(
     const std::vector<MappingIdType>& mappedIds,
     isl::union_pw_aff_list mappedAffs,
     std::vector<ScheduleTreeUPtr>&& children) {
-  std::vector<mapping::MappingId> ids;
-  for (auto id : mappedIds) {
-    ids.push_back(id);
+  CHECK_EQ(mappedIds.size(), static_cast<size_t>(mappedAffs.n()))
+      << "expected as many mapped ids as affs";
+  ScheduleTreeElemMappingFilter::Mapping mapping;
+  for (size_t i = 0, n = mappedAffs.n(); i < n; ++i) {
+    mapping.emplace(mappedIds.at(i), mappedAffs.get(i));
   }
-  CHECK_GE(ids.size(), 1u) << "empty mapping";
+  CHECK_GE(mapping.size(), 1u) << "empty mapping";
+  CHECK_EQ(mappedIds.size(), mapping.size())
+      << "some id is used more than once in the mapping";
   auto ctx = mappedIds[0].get_ctx();
   ScheduleTreeUPtr res(new ScheduleTree(ctx));
   res->elem_ = std::unique_ptr<ScheduleTreeElemMappingFilter>(
-      new ScheduleTreeElemMappingFilter(ids, mappedAffs));
+      new ScheduleTreeElemMappingFilter(mapping));
   res->type_ = ScheduleTreeType::MappingFilter;
   res->appendChildren(std::move(children));
   return res;
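
The duplicate check added here relies on std::unordered_map::emplace silently skipping a key that is already present: if some id occurs twice in mappedIds, the resulting map is smaller than the input vector and the CHECK_EQ fires. A standalone sketch of that idiom with hypothetical types (std::string keys, int values):

// Standalone sketch, hypothetical types: zip two parallel containers into a
// map and detect duplicated keys by comparing sizes afterwards, as
// makeMappingFilter now does.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<std::string> ids{"t0", "t1", "t1"}; // "t1" duplicated on purpose
  std::vector<int> affs{0, 1, 2};
  assert(ids.size() == affs.size());

  std::unordered_map<std::string, int> mapping;
  for (size_t i = 0; i < ids.size(); ++i) {
    mapping.emplace(ids.at(i), affs.at(i)); // no-op if ids.at(i) is already mapped
  }

  // Mirrors CHECK_EQ(mappedIds.size(), mapping.size()) above.
  bool hasDuplicates = mapping.size() != ids.size();
  assert(hasDuplicates); // only {"t0", "t1"} ended up in the map
  return 0;
}
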

tc/core/polyhedral/schedule_tree_elem.h

Lines changed: 17 additions & 23 deletions
@@ -17,7 +17,7 @@
 
 #include <memory>
 #include <sstream>
-#include <unordered_set>
+#include <unordered_map>
 #include <vector>
 
 #include "tc/external/isl.h"
@@ -139,34 +139,29 @@ struct ScheduleTreeElemFilter : public ScheduleTreeElemBase {
 };
 
 struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter {
+  using Mapping = std::unordered_map<
+      mapping::MappingId,
+      isl::union_pw_aff,
+      typename mapping::MappingId::Hash>;
   static constexpr std::initializer_list<detail::ScheduleTreeType>
       NodeDerivedTypes{detail::ScheduleTreeType::None};
   static constexpr detail::ScheduleTreeType NodeType =
       detail::ScheduleTreeType::MappingFilter;
   ScheduleTreeElemMappingFilter() = delete;
   ScheduleTreeElemMappingFilter(const ScheduleTreeElemMappingFilter& eb)
-      : ScheduleTreeElemFilter(eb.filter_), mappingIds(eb.mappingIds) {}
-  ScheduleTreeElemMappingFilter(
-      const std::vector<mapping::MappingId>& mappedIds,
-      isl::union_pw_aff_list mappedAffs)
-      : ScheduleTreeElemFilter(isl::union_set()),
-        mappingIds(mappedIds.begin(), mappedIds.end()) {
-    // Check that ids are unique.
-    CHECK_EQ(mappedIds.size(), mappingIds.size())
-        << "some id is used more than once in the mapping";
+      : ScheduleTreeElemFilter(eb.filter_), mapping(eb.mapping) {}
+  ScheduleTreeElemMappingFilter(const Mapping& mapping)
+      : ScheduleTreeElemFilter(isl::union_set()), mapping(mapping) {
+    CHECK_GT(mapping.size(), 0u) << "empty mapping filter";
 
-    CHECK_EQ(mappedIds.size(), static_cast<size_t>(mappedAffs.n()))
-        << "expected as many mapped ids as affs";
-    CHECK_GT(mappedIds.size(), 0u) << "empty mapping filter";
-
-    auto domain = mappedAffs.get(0).domain();
-    for (size_t i = 1, n = mappedAffs.n(); i < n; ++i) {
-      CHECK(domain.is_equal(mappedAffs.get(i).domain()));
+    auto domain = mapping.cbegin()->second.domain();
+    for (auto& kvp : mapping) {
+      CHECK(domain.is_equal(kvp.second.domain()));
     }
     filter_ = domain.universe();
-    for (size_t i = 0, n = mappedAffs.n(); i < n; ++i) {
-      auto upa = mappedAffs.get(i);
-      auto id = mappedIds.at(i);
+    for (auto& kvp : mapping) {
+      auto upa = kvp.second;
+      auto id = kvp.first;
       // Create mapping filter by equating the
       // parameter mappedIds[i] to the "i"-th affine function.
       upa = upa.sub(isl::union_pw_aff::param_on_domain(domain.universe(), id));
@@ -183,9 +178,8 @@ struct ScheduleTreeElemMappingFilter : public ScheduleTreeElemFilter {
     return NodeType;
   }
 
-  const std::
-      unordered_set<mapping::MappingId, typename mapping::MappingId::Hash>
-          mappingIds;
+  // Mapping from identifiers to affine functions on domain elements.
+  const Mapping mapping;
 };
 
 struct ScheduleTreeElemSequence : public ScheduleTreeElemBase {
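
The new Mapping alias is an unordered_map keyed by mapping::MappingId, which requires the user-supplied hash functor passed as its third template argument. A standalone sketch of the same shape with a hypothetical Id type in place of mapping::MappingId and a string placeholder in place of isl::union_pw_aff:

// Standalone sketch, hypothetical types: an unordered_map keyed by a custom
// type needs a hash functor, which the third template argument supplies.
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

struct Id {
  std::string name;
  bool operator==(const Id& other) const {
    return name == other.name;
  }
  struct Hash {
    size_t operator()(const Id& id) const {
      return std::hash<std::string>()(id.name);
    }
  };
};

// Analogous in shape to ScheduleTreeElemMappingFilter::Mapping.
using Mapping = std::unordered_map<Id, std::string, Id::Hash>;

int main() {
  Mapping mapping;
  mapping.emplace(Id{"t0"}, "affine function placeholder");
  return mapping.count(Id{"t0"}) == 1 ? 0 : 1;
}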
