Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 78399a0

Browse files
authored
Merge pull request #360 from facebookresearch/pr/single
create a single block mapping filter and a single thread mapping filter per branch
2 parents f54c5dd + f84119c commit 78399a0

File tree

9 files changed

+114
-211
lines changed

9 files changed

+114
-211
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 78 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -108,22 +108,69 @@ const CudaDim& mappingSize<mapping::ThreadId>(const MappedScop* mscop) {
108108
}
109109
} // namespace
110110

111+
// Map the elements in "list" to successive blocks or thread identifiers,
112+
// with the first element mapped to identifier X. The extents are obtained
113+
// from the initial elements of numBlocks or numThreads. The identifiers
114+
// must not be present in the space of the partial schedules in "list" and
115+
// extents must be non-zero. The mapping corresponds to inserting a filter
116+
// node with condition 'list % extent = ids'.
117+
// The mapping is inserted above "tree".
118+
//
119+
// Return a pointer to the updated node (below the inserted filter)
120+
// for call chaining purposes.
111121
template <typename MappingTypeId>
112-
void MappedScop::mapRemaining(detail::ScheduleTree* tree, size_t nMapped) {
113-
size_t nToMap = mappingSize<MappingTypeId>(this).view.size();
114-
if (nMapped >= nToMap) {
115-
return;
122+
detail::ScheduleTree* MappedScop::map(
123+
detail::ScheduleTree* tree,
124+
isl::union_pw_aff_list list) {
125+
size_t nToMap = list.n();
126+
const auto& extent = mappingSize<MappingTypeId>(this).view;
127+
CHECK_LE(nToMap, extent.size()) << "dimension overflow";
128+
129+
auto root = scop_->scheduleRoot();
130+
auto domain = activeDomainPoints(root, tree).universe();
131+
auto filter = domain;
132+
133+
std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> idSet;
134+
for (size_t i = 0; i < nToMap; ++i) {
135+
auto id = MappingTypeId::makeId(i);
136+
auto upa = list.get(i);
137+
// Introduce the "mapping" parameter after checking it is not already
138+
// present in the schedule space.
139+
CHECK(not upa.involves_param(id));
140+
CHECK_NE(extent[i], 0u) << "NYI: mapping to 0";
141+
142+
// Create mapping filter by equating the newly introduced
143+
// parameter ids[i] to the "i"-th affine function modulo its extent.
144+
upa = upa.mod_val(isl::val(tree->ctx_, extent[i]));
145+
upa = upa.sub(isl::union_pw_aff::param_on_domain(domain, id));
146+
filter = filter.intersect(upa.zero_union_set());
147+
148+
idSet.emplace(id);
116149
}
117150

118-
std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> ids;
119-
for (size_t i = nMapped; i < nToMap; ++i) {
120-
ids.insert(MappingTypeId::makeId(i));
151+
std::unordered_set<MappingTypeId, typename MappingTypeId::Hash> unmapped;
152+
for (size_t i = nToMap; i < extent.size(); ++i) {
153+
auto id = MappingTypeId::makeId(i);
154+
unmapped.emplace(id);
155+
idSet.emplace(id);
121156
}
122-
auto root = scop_->scheduleRoot();
123-
auto domain = activeDomainPoints(root, tree);
124-
auto filter = makeFixRemainingZeroFilter(domain, ids);
125-
auto mapping = detail::ScheduleTree::makeMappingFilter(filter, ids);
126-
insertNodeAbove(root, tree, std::move(mapping));
157+
filter = filter.intersect(makeFixRemainingZeroFilter(domain, unmapped));
158+
159+
auto mapping = detail::ScheduleTree::makeMappingFilter(filter, idSet);
160+
tree = insertNodeAbove(root, tree, std::move(mapping))->child({0});
161+
162+
return tree;
163+
}
164+
165+
detail::ScheduleTree* MappedScop::mapBlocksForward(
166+
detail::ScheduleTree* band,
167+
size_t nToMap) {
168+
auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
169+
CHECK(bandNode) << "expected a band, got " << *band;
170+
171+
auto list = bandNode->mupa_.get_union_pw_aff_list();
172+
list = list.drop(nToMap, list.n() - nToMap);
173+
return map<mapping::BlockId>(band, list);
127174
}
128175

129176
// Uses as many blockSizes elements as outer coincident dimensions in the
@@ -142,10 +189,7 @@ void MappedScop::mapToBlocksAndScaleBand(
142189
// and no more than block dimensions to be mapped
143190
nBlocksToMap = std::min(nBlocksToMap, numBlocks.view.size());
144191

145-
for (size_t i = 0; i < nBlocksToMap; ++i) {
146-
band = map(band, i, mapping::BlockId::makeId(i));
147-
}
148-
mapRemaining<mapping::BlockId>(band, nBlocksToMap);
192+
mapBlocksForward(band, nBlocksToMap);
149193
bandScale(band, tileSizes);
150194
}
151195

@@ -166,10 +210,7 @@ void fixThreadsBelow(
166210

167211
auto band = detail::ScheduleTree::makeEmptyBand(mscop.scop().scheduleRoot());
168212
auto bandTree = insertNodeBelow(tree, std::move(band));
169-
auto ctx = tree->ctx_;
170-
insertNodeBelow(
171-
bandTree, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
172-
mscop.mapRemaining<mapping::ThreadId>(bandTree, begin);
213+
mscop.mapThreadsBackward(bandTree);
173214
}
174215

175216
bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
@@ -305,6 +346,22 @@ detail::ScheduleTree* MappedScop::separateReduction(detail::ScheduleTree* st) {
305346
return st->ancestor(root, 2);
306347
}
307348

349+
detail::ScheduleTree* MappedScop::mapThreadsBackward(
350+
detail::ScheduleTree* band) {
351+
auto bandNode = band->elemAs<detail::ScheduleTreeElemBand>();
352+
CHECK(bandNode);
353+
auto nMember = bandNode->nMember();
354+
auto nToMap = std::min(nMember, numThreads.view.size());
355+
CHECK_LE(nToMap, 3) << "mapping to too many threads";
356+
357+
auto ctx = band->ctx_;
358+
insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
359+
360+
auto list = bandNode->mupa_.get_union_pw_aff_list().reverse();
361+
list = list.drop(nToMap, list.n() - nToMap);
362+
return map<mapping::ThreadId>(band, list);
363+
}
364+
308365
size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
309366
using namespace tc::polyhedral::detail;
310367

@@ -355,20 +412,9 @@ size_t MappedScop::mapToThreads(detail::ScheduleTree* band) {
355412
bandSplit(scop_->scheduleRoot(), band, nMappedThreads);
356413
}
357414

358-
auto ctx = band->ctx_;
359-
insertNodeBelow(band, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
360-
361415
CHECK_GT(nMappedThreads, 0) << "not mapping to threads";
362-
CHECK_LE(nMappedThreads, 3) << "mapping to too many threads";
363416

364-
// Map the coincident dimensions to threads starting from the innermost and
365-
// from thread x.
366-
for (size_t i = 0; i < nMappedThreads; ++i) {
367-
auto id = mapping::ThreadId::makeId(i);
368-
auto dim = nMappedThreads - 1 - i;
369-
band = map(band, dim, id);
370-
}
371-
mapRemaining<mapping::ThreadId>(band, nMappedThreads);
417+
mapThreadsBackward(band);
372418

373419
if (isReduction) {
374420
splitOutReductionAndInsertSyncs(band, nMappedThreads - 1);

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,25 +87,17 @@ class MappedScop {
8787
std::unique_ptr<Scop>&& scopUPtr,
8888
const CudaMappingOptions& mappingOptions);
8989

90-
// Map a particular "pos"-th dimension in a _band_ node identified by "tree"
91-
// to the block or thread dimension. Ancestors or descendants of "tree" must
92-
// not have a dimension already mapped to the same block or thread.
93-
inline detail::ScheduleTree*
94-
map(detail::ScheduleTree* tree, int pos, const mapping::BlockId& id) {
95-
return mapToParameterWithExtent(
96-
scop_->scheduleRoot(), tree, pos, id, id.mappingSize(numBlocks));
97-
}
98-
inline detail::ScheduleTree*
99-
map(detail::ScheduleTree* tree, int pos, const mapping::ThreadId& id) {
100-
return mapToParameterWithExtent(
101-
scop_->scheduleRoot(), tree, pos, id, id.mappingSize(numThreads));
102-
}
103-
104-
// Given that "nMapped" identifiers of type "MappingTypeId" have already
105-
// been mapped, map the remaining ones to zero
106-
// for all statement instances.
107-
template <typename MappingTypeId>
108-
void mapRemaining(detail::ScheduleTree* tree, size_t nMapped);
90+
// Map the initial (up to "nToMap") band members of "band"
91+
// to successive block identifiers.
92+
// This function can only be called once on the entire tree.
93+
detail::ScheduleTree* mapBlocksForward(
94+
detail::ScheduleTree* band,
95+
size_t nToMap);
96+
// Map the final band members of "band"
97+
// to successive thread identifiers, with the last member mapped
98+
// to thread identifier X.
99+
// This function can only be called once in any branch of the tree.
100+
detail::ScheduleTree* mapThreadsBackward(detail::ScheduleTree* band);
109101

110102
// Fix the values of the specified parameters in the context
111103
// to the corresponding specified values.
@@ -136,6 +128,14 @@ class MappedScop {
136128
}
137129

138130
private:
131+
// Map the elements in "list" to successive blocks or thread identifiers,
132+
// with the first element mapped to identifier X.
133+
// Return a pointer to the updated node (below the inserted filter)
134+
// for call chaining purposes.
135+
template <typename MappingTypeId>
136+
detail::ScheduleTree* map(
137+
detail::ScheduleTree* tree,
138+
isl::union_pw_aff_list list);
139139
// Map "band" to block identifiers and then scale
140140
// the band members by "tileSizes".
141141
void mapToBlocksAndScaleBand(

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
8383
throw promotion::PromotionLogicError("no copy band");
8484
}
8585

86-
auto ctx = node->ctx_;
87-
insertNodeBelow(
88-
bandNode, detail::ScheduleTree::makeThreadSpecificMarker(ctx));
89-
9086
// Check that we are not mapping to threads below other thread mappings.
9187
std::unordered_set<mapping::ThreadId, mapping::ThreadId::Hash> usedThreads;
9288
for (auto n : node->ancestors(root)) {
@@ -97,20 +93,7 @@ void mapCopiesToThreads(MappedScop& mscop, bool unroll) {
9793
}
9894
}
9995

100-
// Map band dimensions to threads, in inverse order since the last member
101-
// iterates over the last subscript and is likely to result in coalescing.
102-
// If not all available thread ids are used, fix remaining to 1 thread.
103-
auto nToMap = std::min(band->nMember(), mscop.numThreads.view.size());
104-
for (size_t t = 0; t < nToMap; ++t) {
105-
auto pos = band->nMember() - 1 - t;
106-
mapToParameterWithExtent(
107-
root,
108-
bandNode,
109-
pos,
110-
mapping::ThreadId::makeId(t),
111-
mscop.numThreads.view[t]);
112-
}
113-
mscop.mapRemaining<mapping::ThreadId>(bandNode, nToMap);
96+
mscop.mapThreadsBackward(bandNode);
11497

11598
// Unroll if requested.
11699
if (unroll) {

tc/core/polyhedral/schedule_transforms-inl.h

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,35 +39,5 @@ inline detail::ScheduleTree* insertNodeBelow(
3939
tree->appendChild(std::move(node));
4040
return tree->child({0});
4141
}
42-
43-
template <typename MappingIdType>
44-
inline detail::ScheduleTree* mapToParameterWithExtent(
45-
detail::ScheduleTree* root,
46-
detail::ScheduleTree* tree,
47-
size_t pos,
48-
MappingIdType id,
49-
size_t extent) {
50-
auto band = tree->elemAs<detail::ScheduleTreeElemBand>();
51-
CHECK(band) << "expected a band, got " << *tree;
52-
CHECK_GE(pos, 0u) << "dimension underflow";
53-
CHECK_LT(pos, band->nMember()) << "dimension overflow";
54-
CHECK_NE(extent, 0u) << "NYI: mapping to 0";
55-
56-
auto domain = activeDomainPoints(root, tree).universe();
57-
58-
// Introduce the "mapping" parameter after checking it is not already present
59-
// in the schedule space.
60-
CHECK(not band->mupa_.involves_param(id));
61-
62-
// Create mapping filter by equating the newly introduced
63-
// parameter "id" to the "pos"-th schedule dimension modulo its extent.
64-
auto upa =
65-
band->mupa_.get_union_pw_aff(pos).mod_val(isl::val(tree->ctx_, extent));
66-
upa = upa.sub(isl::union_pw_aff::param_on_domain(domain, id));
67-
auto filter = upa.zero_union_set();
68-
auto mapping =
69-
detail::ScheduleTree::makeMappingFilter<MappingIdType>(filter, {id});
70-
return insertNodeAbove(root, tree, std::move(mapping))->child({0});
71-
}
7242
} // namespace polyhedral
7343
} // namespace tc

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -754,52 +754,5 @@ void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) {
754754
parent->insertChild(childPos, std::move(seq));
755755
}
756756

757-
detail::ScheduleTree* mergeConsecutiveMappingFilters(
758-
detail::ScheduleTree* root,
759-
detail::ScheduleTree* node) {
760-
CHECK(
761-
root->elemAs<ScheduleTreeElemDomain>() ||
762-
root->elemAs<ScheduleTreeElemExtension>());
763-
bool changed = true;
764-
while (changed) {
765-
changed = false;
766-
auto filterNodes = detail::ScheduleTree::collect(
767-
node, detail::ScheduleTreeType::MappingFilter);
768-
769-
for (auto f : filterNodes) {
770-
auto p = f->ancestor(root, 1);
771-
auto parentFilter = p->elemAs<ScheduleTreeElemMappingFilter>();
772-
if (!parentFilter) {
773-
continue;
774-
}
775-
auto filter = f->elemAs<ScheduleTreeElemMappingFilter>();
776-
auto merged = parentFilter->filter_ & filter->filter_;
777-
// We can only merge filters that have the same number of tuples
778-
if (merged.n_set() != parentFilter->filter_.n_set() ||
779-
merged.n_set() != filter->filter_.n_set()) {
780-
continue;
781-
}
782-
p->elemAs<ScheduleTreeElemMappingFilter>()->filter_ = merged;
783-
// const cast to replace in place rather than construct a new
784-
// ScheduleTree object (which would not be more functional-style anyway)
785-
auto& ids = const_cast<std::unordered_set<
786-
mapping::MappingId,
787-
typename mapping::MappingId::Hash>&>(
788-
p->elemAs<ScheduleTreeElemMappingFilter>()->mappingIds);
789-
for (auto id : filter->mappingIds) {
790-
CHECK_EQ(0u, ids.count(id))
791-
<< "Error when merging filters\n"
792-
<< *f << "\nand\n"
793-
<< *p << "\nid: " << id << " mapped in both!";
794-
ids.insert(id);
795-
}
796-
p->replaceChild(f->positionInParent(p), f->detachChild(0));
797-
changed = true;
798-
break;
799-
}
800-
}
801-
return node;
802-
}
803-
804757
} // namespace polyhedral
805758
} // namespace tc

tc/core/polyhedral/schedule_transforms.h

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -114,23 +114,6 @@ detail::ScheduleTree* bandScale(
114114
detail::ScheduleTree* tree,
115115
const std::vector<size_t>& scales);
116116

117-
// Map "pos"-th schedule dimension of the band node identified by "tree" to a
118-
// _new_ parameter identified by "id" and limited by 0 <= id < extent. The
119-
// parameter must not be present in the space of partial schedule of "tree" and
120-
// extent must be non-zero. The mapping corresponds to inserting a filter
121-
// node with condition 'dim % extent = id' where dim is "pos"-th
122-
// schedule dimension.
123-
//
124-
// Returns a pointer to the updated band (below the inserted filter)
125-
// for call chaining purposes.
126-
template <typename MappingIdType>
127-
detail::ScheduleTree* mapToParameterWithExtent(
128-
detail::ScheduleTree* root,
129-
detail::ScheduleTree* tree,
130-
size_t pos,
131-
MappingIdType id,
132-
size_t extent);
133-
134117
// Update the top-level conext node by intersecting it with "context". The
135118
// top-level context node must be located directly under the root of the tree.
136119
// If there is no such node, insert one with universe context first.
@@ -314,15 +297,6 @@ isl::union_set activeDomainPointsBelow(
314297
const detail::ScheduleTree* root,
315298
const detail::ScheduleTree* node);
316299

317-
////////////////////////////////////////////////////////////////////////////////
318-
// Experimental
319-
////////////////////////////////////////////////////////////////////////////////
320-
// Mapping filters are introduced one mapping dimension at a time.
321-
// This merges consecutive filters.
322-
detail::ScheduleTree* mergeConsecutiveMappingFilters(
323-
detail::ScheduleTree* root,
324-
detail::ScheduleTree* node);
325-
326300
} // namespace polyhedral
327301
} // namespace tc
328302

0 commit comments

Comments
 (0)