Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 2f98f2e

Browse files
authored
Merge pull request #336 from facebookresearch/pr/insert_node
extract out generic insertNode{Above,Below}
2 parents c8a4424 + 69e1860 commit 2f98f2e

File tree

5 files changed

+33
-80
lines changed

5 files changed

+33
-80
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ void MappedScop::mapRemaining(
109109
auto root = scop_->scheduleRoot();
110110
auto domain = activeDomainPoints(root, tree);
111111
auto filter = makeFixRemainingZeroFilter(domain, ids);
112-
insertMappingFilterAbove(root, tree, filter, ids);
112+
auto mapping = detail::ScheduleTree::makeMappingFilter(filter, ids);
113+
insertNodeAbove(root, tree, std::move(mapping));
113114

114115
for (size_t i = nMapped; i < nToMap; ++i) {
115116
if (MappingTypeId::makeId(i) == mapping::ThreadId::x()) {
@@ -171,8 +172,9 @@ void fixThreadsBelowFilter(
171172
// invariant that leaf mapping filters have a single space.
172173
// So we intersect with the universe set of the filter to only keep the
173174
// space for the legitimate statement.
174-
insertMappingFilterBelow(
175-
filterTree, mappingFilter & filter->filter_.universe(), ids);
175+
mappingFilter = mappingFilter & filter->filter_.universe();
176+
auto mapping = detail::ScheduleTree::makeMappingFilter(mappingFilter, ids);
177+
insertNodeBelow(filterTree, std::move(mapping));
176178

177179
for (size_t i = begin; i < end; ++i) {
178180
if (mapping::ThreadId::makeId(i) == mapping::ThreadId::x()) {

tc/core/polyhedral/schedule_transforms-inl.h

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,27 @@
1717

1818
namespace tc {
1919
namespace polyhedral {
20-
template <typename MappingIdType>
21-
inline detail::ScheduleTree* insertMappingFilterAbove(
20+
inline detail::ScheduleTree* insertNodeAbove(
2221
detail::ScheduleTree* root,
2322
detail::ScheduleTree* tree,
24-
isl::union_set filter,
25-
const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
26-
mappingIds) {
23+
ScheduleTreeUPtr&& node) {
24+
CHECK_EQ(node->numChildren(), 0u);
2725
auto parent = tree->ancestor(root, 1);
2826
auto childPos = tree->positionInParent(parent);
29-
parent->insertChild(
30-
childPos,
31-
detail::ScheduleTree::makeMappingFilter(
32-
filter, mappingIds, parent->detachChild(childPos)));
27+
node->appendChild(parent->detachChild(childPos));
28+
parent->insertChild(childPos, std::move(node));
3329
return parent->child({childPos});
3430
}
3531

36-
template <typename MappingIdType>
37-
inline void insertMappingFilterBelow(
32+
inline detail::ScheduleTree* insertNodeBelow(
3833
detail::ScheduleTree* tree,
39-
isl::union_set filter,
40-
const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
41-
mappingIds) {
34+
ScheduleTreeUPtr&& node) {
35+
CHECK_EQ(node->numChildren(), 0u);
4236
auto numChildren = tree->numChildren();
4337
CHECK_LE(numChildren, 1u);
44-
tree->appendChild(detail::ScheduleTree::makeMappingFilter(
45-
filter, mappingIds, tree->detachChildren()));
38+
node->appendChildren(tree->detachChildren());
39+
tree->appendChild(std::move(node));
40+
return tree->child({0});
4641
}
4742

4843
template <typename MappingIdType>
@@ -70,8 +65,9 @@ inline detail::ScheduleTree* mapToParameterWithExtent(
7065
band->mupa_.get_union_pw_aff(pos).mod_val(isl::val(tree->ctx_, extent));
7166
upa = upa.sub(isl::union_pw_aff::param_on_domain(domain, id));
7267
auto filter = upa.zero_union_set();
73-
return insertMappingFilterAbove<MappingIdType>(root, tree, filter, {id})
74-
->child({0});
68+
auto mapping =
69+
detail::ScheduleTree::makeMappingFilter<MappingIdType>(filter, {id});
70+
return insertNodeAbove(root, tree, std::move(mapping))->child({0});
7571
}
7672
} // namespace polyhedral
7773
} // namespace tc

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -462,26 +462,6 @@ isl::multi_union_pw_aff partialScheduleMupa(
462462
return prefixScheduleMupa(root, tree).flat_range_product(band->mupa_);
463463
}
464464

465-
ScheduleTree* insertBandAbove(
466-
ScheduleTree* root,
467-
ScheduleTree* tree,
468-
isl::multi_union_pw_aff mupa) {
469-
auto parent = tree->ancestor(root, 1);
470-
auto childPos = tree->positionInParent(parent);
471-
auto child = parent->detachChild(childPos);
472-
parent->insertChild(childPos, ScheduleTree::makeBand(mupa, std::move(child)));
473-
return parent->child({childPos});
474-
}
475-
476-
ScheduleTree* insertBandBelow(
477-
detail::ScheduleTree* tree,
478-
isl::multi_union_pw_aff mupa) {
479-
auto numChildren = tree->numChildren();
480-
CHECK_LE(numChildren, 1u);
481-
tree->appendChild(ScheduleTree::makeBand(mupa, tree->detachChildren()));
482-
return tree->child({0});
483-
}
484-
485465
void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
486466
if (!matchOne(tc::polyhedral::domain(tc::polyhedral::context(any())), root)) {
487467
root->appendChild(ScheduleTree::makeContext(

tc/core/polyhedral/schedule_transforms.h

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -131,27 +131,6 @@ detail::ScheduleTree* mapToParameterWithExtent(
131131
MappingIdType id,
132132
size_t extent);
133133

134-
// In a tree starting at a (relative) "root", insert a band node with the
135-
// given partial schedule above the node identified by "tree".
136-
//
137-
// The tree is modified in place.
138-
// Return a non-owning pointer to the inserted band node
139-
// for call chaining purposes.
140-
detail::ScheduleTree* insertBandAbove(
141-
detail::ScheduleTree* root,
142-
detail::ScheduleTree* tree,
143-
isl::multi_union_pw_aff mupa);
144-
145-
// Insert a band node with the given partial schedule below node "tree",
146-
// which is assumed to have at most one child.
147-
//
148-
// The tree is modified in place.
149-
// Return a non-owning pointer to the inserted band node
150-
// for call chaining purposes.
151-
detail::ScheduleTree* insertBandBelow(
152-
detail::ScheduleTree* tree,
153-
isl::multi_union_pw_aff mupa);
154-
155134
// Update the top-level conext node by intersecting it with "context". The
156135
// top-level context node must be located directly under the root of the tree.
157136
// If there is no such node, insert one with universe context first.
@@ -178,31 +157,26 @@ detail::ScheduleTree* insertExtensionAbove(
178157
detail::ScheduleTree* tree,
179158
isl::union_map extension);
180159

181-
// In a tree starting at a (relative) "root", insert a mapping filter node
182-
// with the given filter above the node identified by "tree".
160+
// In a tree starting at a (relative) "root", insert the given node
161+
// above the node identified by "tree".
183162
//
184163
// The tree is modified in place.
185-
// Return a non-owning pointer to the inserted filter node
164+
// Return a non-owning pointer to the inserted node
186165
// for call chaining purposes.
187-
template <typename MappingIdType>
188-
inline detail::ScheduleTree* insertMappingFilterAbove(
166+
inline detail::ScheduleTree* insertNodeAbove(
189167
detail::ScheduleTree* root,
190168
detail::ScheduleTree* tree,
191-
isl::union_set filter,
192-
const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
193-
mappingIds);
169+
ScheduleTreeUPtr&& node);
194170

195-
// Insert a mapping filter node below node "tree", which is assumed to have at
196-
// most one child. The underlying isl::union_set filter is constructed from
197-
// the arguments.
171+
// Insert the given node below node "tree", which is assumed to have at
172+
// most one child.
198173
//
199174
// The tree is modified in place.
200-
template <typename MappingIdType>
201-
inline void insertMappingFilterBelow(
175+
// Return a non-owning pointer to the inserted node
176+
// for call chaining purposes.
177+
inline detail::ScheduleTree* insertNodeBelow(
202178
detail::ScheduleTree* tree,
203-
isl::union_set filter,
204-
const std::unordered_set<MappingIdType, typename MappingIdType::Hash>&
205-
mappingIds);
179+
ScheduleTreeUPtr&& node);
206180

207181
// Given a sequence node in the schedule tree, insert
208182
// a zero-dimensional extension statement with the given identifier

tc/core/polyhedral/scop.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,11 @@ detail::ScheduleTree* obtainOuterBand(detail::ScheduleTree* root) {
492492
CHECK(domain);
493493
auto space = domain->domain_.get_space().set_from_params();
494494
auto zero = isl::multi_union_pw_aff::zero(space);
495+
auto band = ScheduleTree::makeBand(zero);
495496
if (n == 0) {
496-
return setPermutable(insertBandBelow(tree, zero));
497+
return setPermutable(insertNodeBelow(tree, std::move(band)));
497498
} else {
498-
return setPermutable(insertBandAbove(root, tree, zero));
499+
return setPermutable(insertNodeAbove(root, tree, std::move(band)));
499500
}
500501
}
501502
return tree;

0 commit comments

Comments
 (0)