Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 339491a

Browse files
committed
Separate out schedule_utils.{cc,h} from schedule_transforms.{cc,h}
Many of the functions in schedule_transforms.{cc,h} are not transformations but some higher-level functions that operate on (parts of) schedule trees rather than on individual nodes. These functions were introduced into schedule_transforms.{cc,h} for historical reasons. The file has grown large. Separate the non-transformation functions into separate files, schedule_utils.{cc,h}.
1 parent 2aaa675 commit 339491a

11 files changed

+467
-395
lines changed

tc/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_library(
2929
polyhedral/schedule_tree.cc
3030
polyhedral/schedule_tree_elem.cc
3131
polyhedral/schedule_print.cc
32+
polyhedral/schedule_utils.cc
3233
polyhedral/scop.cc
3334
polyhedral/separation.cc
3435
polyhedral/unroll.cc

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "tc/core/polyhedral/functional.h"
3636
#include "tc/core/polyhedral/schedule_transforms.h"
3737
#include "tc/core/polyhedral/schedule_tree_matcher.h"
38+
#include "tc/core/polyhedral/schedule_utils.h"
3839
#include "tc/core/polyhedral/scop.h"
3940
#include "tc/core/polyhedral/separation.h"
4041
#include "tc/core/polyhedral/unroll.h"

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "tc/core/polyhedral/memory_promotion.h"
2424
#include "tc/core/polyhedral/schedule_tree.h"
2525
#include "tc/core/polyhedral/schedule_tree_matcher.h"
26+
#include "tc/core/polyhedral/schedule_utils.h"
2627
#include "tc/core/polyhedral/unroll.h"
2728

2829
#include <algorithm>

tc/core/polyhedral/cuda/tighten_launch_bounds.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include "tc/core/polyhedral/cuda/mapping_types.h"
2121
#include "tc/core/polyhedral/exceptions.h"
2222
#include "tc/core/polyhedral/functional.h"
23-
#include "tc/core/polyhedral/schedule_transforms.h"
2423
#include "tc/core/polyhedral/schedule_tree.h"
24+
#include "tc/core/polyhedral/schedule_utils.h"
2525

2626
namespace tc {
2727
namespace polyhedral {

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 1 addition & 287 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "tc/core/polyhedral/mapping_types.h"
3535
#include "tc/core/polyhedral/schedule_tree_elem.h"
3636
#include "tc/core/polyhedral/schedule_tree_matcher.h"
37+
#include "tc/core/polyhedral/schedule_utils.h"
3738
#include "tc/core/scope_guard.h"
3839
#include "tc/external/isl.h"
3940

@@ -47,164 +48,6 @@ namespace tc {
4748
namespace polyhedral {
4849
using namespace detail;
4950

50-
isl::union_map extendSchedule(
51-
const ScheduleTree* node,
52-
isl::union_map schedule) {
53-
if (auto bandElem = node->elemAs<ScheduleTreeElemBand>()) {
54-
if (bandElem->nMember() > 0) {
55-
schedule =
56-
schedule.flat_range_product(isl::union_map::from(bandElem->mupa_));
57-
}
58-
} else if (auto filterElem = node->elemAs<ScheduleTreeElemFilter>()) {
59-
schedule = schedule.intersect_domain(filterElem->filter_);
60-
} else if (auto extensionElem = node->elemAs<ScheduleTreeElemExtension>()) {
61-
// FIXME: we may need to restrict the range of reversed extension map to
62-
// schedule values that correspond to active domain elements at this
63-
// point.
64-
schedule = schedule.unite(
65-
extensionElem->extension_.reverse().intersect_range(schedule.range()));
66-
}
67-
68-
return schedule;
69-
}
70-
71-
namespace {
72-
isl::union_map partialScheduleImpl(
73-
const ScheduleTree* root,
74-
const ScheduleTree* node,
75-
bool useNode) {
76-
auto nodes = node->ancestors(root);
77-
if (useNode) {
78-
nodes.push_back(node);
79-
}
80-
TC_CHECK_GT(nodes.size(), 0u) << "root node does not have a prefix schedule";
81-
auto domain = root->elemAs<ScheduleTreeElemDomain>();
82-
TC_CHECK(domain);
83-
auto schedule = isl::union_map::from_domain(domain->domain_);
84-
for (auto anc : nodes) {
85-
if (anc->elemAs<ScheduleTreeElemDomain>()) {
86-
TC_CHECK(anc == root);
87-
} else {
88-
schedule = extendSchedule(anc, schedule);
89-
}
90-
}
91-
return schedule;
92-
}
93-
} // namespace
94-
95-
isl::union_map prefixSchedule(
96-
const ScheduleTree* root,
97-
const ScheduleTree* node) {
98-
return partialScheduleImpl(root, node, false);
99-
}
100-
101-
isl::union_map partialSchedule(
102-
const ScheduleTree* root,
103-
const ScheduleTree* node) {
104-
return partialScheduleImpl(root, node, true);
105-
}
106-
107-
namespace {
108-
/*
109-
* If "node" is any filter, then intersect "domain" with that filter.
110-
*/
111-
isl::union_set applyFilter(isl::union_set domain, const ScheduleTree* node) {
112-
if (auto filterElem = node->elemAs<ScheduleTreeElemFilter>()) {
113-
return domain.intersect(filterElem->filter_);
114-
}
115-
return domain;
116-
}
117-
118-
/*
119-
* If "node" is a mapping, then intersect "domain" with its filter.
120-
*/
121-
isl::union_set applyMapping(isl::union_set domain, const ScheduleTree* node) {
122-
if (auto filterElem = node->elemAs<ScheduleTreeElemMapping>()) {
123-
return domain.intersect(filterElem->filter_);
124-
}
125-
return domain;
126-
}
127-
128-
// Get the set of domain elements that are active below
129-
// the given branch of nodes, filtered using "filter".
130-
//
131-
// Domain elements are introduced by the root domain node. Some nodes
132-
// refine this set of elements based on "filter". Extension nodes
133-
// are considered to introduce additional domain points.
134-
isl::union_set collectDomain(
135-
const ScheduleTree* root,
136-
const vector<const ScheduleTree*>& nodes,
137-
isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) {
138-
auto domainElem = root->elemAs<ScheduleTreeElemDomain>();
139-
TC_CHECK(domainElem) << "root must be a Domain node" << *root;
140-
141-
auto domain = domainElem->domain_;
142-
143-
for (auto anc : nodes) {
144-
domain = filter(domain, anc);
145-
if (auto extensionElem = anc->elemAs<ScheduleTreeElemExtension>()) {
146-
auto parentSchedule = prefixSchedule(root, anc);
147-
auto extension = extensionElem->extension_;
148-
TC_CHECK(parentSchedule) << "missing root domain node";
149-
parentSchedule = parentSchedule.intersect_domain(domain);
150-
domain = domain.unite(parentSchedule.range().apply(extension));
151-
}
152-
}
153-
return domain;
154-
}
155-
156-
// Get the set of domain elements that are active below
157-
// the given branch of nodes.
158-
isl::union_set activeDomainPointsHelper(
159-
const ScheduleTree* root,
160-
const vector<const ScheduleTree*>& nodes) {
161-
return collectDomain(root, nodes, &applyFilter);
162-
}
163-
164-
} // namespace
165-
166-
isl::union_set prefixMappingFilter(
167-
const ScheduleTree* root,
168-
const ScheduleTree* node) {
169-
return collectDomain(root, node->ancestors(root), &applyMapping);
170-
}
171-
172-
isl::union_set activeDomainPoints(
173-
const ScheduleTree* root,
174-
const ScheduleTree* node) {
175-
return activeDomainPointsHelper(root, node->ancestors(root));
176-
}
177-
178-
isl::union_set activeDomainPointsBelow(
179-
const ScheduleTree* root,
180-
const ScheduleTree* node) {
181-
auto ancestors = node->ancestors(root);
182-
ancestors.emplace_back(node);
183-
return activeDomainPointsHelper(root, ancestors);
184-
}
185-
186-
vector<ScheduleTree*> collectScheduleTreesPath(
187-
std::function<ScheduleTree*(ScheduleTree*)> next,
188-
ScheduleTree* start) {
189-
vector<ScheduleTree*> res{start};
190-
auto n = start;
191-
while ((n = next(n)) != nullptr) {
192-
res.push_back(n);
193-
}
194-
return res;
195-
}
196-
197-
vector<const ScheduleTree*> collectScheduleTreesPath(
198-
std::function<const ScheduleTree*(const ScheduleTree*)> next,
199-
const ScheduleTree* start) {
200-
vector<const ScheduleTree*> res{start};
201-
auto n = start;
202-
while ((n = next(n)) != nullptr) {
203-
res.push_back(n);
204-
}
205-
return res;
206-
}
207-
20851
// Replace "tree" in the list of its parent's children with newTree.
20952
// Returns the pointer to newTree for call chaining purposes.
21053
ScheduleTree* swapSubtree(
@@ -432,85 +275,6 @@ ScheduleTree* bandScale(ScheduleTree* tree, const vector<size_t>& scales) {
432275
return tree;
433276
}
434277

435-
namespace {
436-
437-
template <typename T>
438-
vector<T> reversed(const vector<T>& vec) {
439-
vector<T> result;
440-
result.reserve(vec.size());
441-
result.insert(result.begin(), vec.rbegin(), vec.rend());
442-
return result;
443-
}
444-
445-
template <typename T>
446-
vector<const ScheduleTree*> filterType(const vector<const ScheduleTree*>& vec) {
447-
vector<const ScheduleTree*> result;
448-
for (auto e : vec) {
449-
if (e->elemAs<T>()) {
450-
result.push_back(e);
451-
}
452-
}
453-
return result;
454-
}
455-
456-
template <typename T, typename Func>
457-
T foldl(const vector<const ScheduleTree*> vec, Func op, T init = T()) {
458-
T value = init;
459-
for (auto st : vec) {
460-
value = op(st, value);
461-
}
462-
return value;
463-
}
464-
465-
template <typename... Args>
466-
ostream& operator<<(ostream& os, const vector<Args...>& v) {
467-
os << "[";
468-
bool first = true;
469-
for (auto const& ve : v) {
470-
if (!first) {
471-
os << ", ";
472-
}
473-
os << ve;
474-
first = true;
475-
}
476-
os << "]";
477-
return os;
478-
}
479-
} // namespace
480-
481-
isl::multi_union_pw_aff infixScheduleMupa(
482-
const ScheduleTree* root,
483-
const ScheduleTree* relativeRoot,
484-
const ScheduleTree* tree) {
485-
auto domainElem = root->elemAs<ScheduleTreeElemDomain>();
486-
TC_CHECK(domainElem);
487-
auto domain = domainElem->domain_.universe();
488-
auto zero = isl::multi_val::zero(domain.get_space().set_from_params());
489-
auto prefix = isl::multi_union_pw_aff(domain, zero);
490-
prefix = foldl(
491-
filterType<ScheduleTreeElemBand>(tree->ancestors(relativeRoot)),
492-
[](const ScheduleTree* st, isl::multi_union_pw_aff pref) {
493-
auto mupa = st->elemAs<ScheduleTreeElemBand>()->mupa_;
494-
return pref.flat_range_product(mupa);
495-
},
496-
prefix);
497-
return prefix;
498-
}
499-
500-
isl::multi_union_pw_aff prefixScheduleMupa(
501-
const ScheduleTree* root,
502-
const ScheduleTree* tree) {
503-
return infixScheduleMupa(root, root, tree);
504-
}
505-
506-
isl::multi_union_pw_aff partialScheduleMupa(
507-
const detail::ScheduleTree* root,
508-
const detail::ScheduleTree* tree) {
509-
auto prefix = prefixScheduleMupa(root, tree);
510-
auto band = tree->elemAs<ScheduleTreeElemBand>();
511-
return band ? prefix.flat_range_product(band->mupa_) : prefix;
512-
}
513-
514278
void updateTopLevelContext(detail::ScheduleTree* root, isl::set context) {
515279
if (!matchOne(tc::polyhedral::domain(tc::polyhedral::context(any())), root)) {
516280
root->appendChild(ScheduleTree::makeContext(
@@ -832,55 +596,5 @@ void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) {
832596
seq->insertChild(0, gistedFilter(other, parent->detachChild(childPos)));
833597
parent->insertChild(childPos, std::move(seq));
834598
}
835-
836-
/*
837-
* Extract a mapping from the domain elements active at "tree"
838-
* to identifiers "ids", where all branches in "tree"
839-
* are assumed to have been mapped to these identifiers.
840-
* The result lives in a space of the form "tupleId"["ids"...].
841-
*/
842-
isl::multi_union_pw_aff extractDomainToIds(
843-
const detail::ScheduleTree* root,
844-
const detail::ScheduleTree* tree,
845-
const std::vector<mapping::MappingId>& ids,
846-
isl::id tupleId) {
847-
using namespace polyhedral::detail;
848-
849-
auto space = isl::space(tree->ctx_, 0);
850-
auto empty = isl::union_set::empty(space);
851-
space = space.named_set_from_params_id(tupleId, ids.size());
852-
auto zero = isl::multi_val::zero(space);
853-
auto domainToIds = isl::multi_union_pw_aff(empty, zero);
854-
855-
for (auto mapping : tree->collect(tree, ScheduleTreeType::Mapping)) {
856-
auto mappingNode = mapping->elemAs<ScheduleTreeElemMapping>();
857-
auto list = isl::union_pw_aff_list(tree->ctx_, ids.size());
858-
for (auto id : ids) {
859-
if (mappingNode->mapping.count(id) == 0) {
860-
break;
861-
}
862-
auto idMap = mappingNode->mapping.at(id);
863-
list = list.add(idMap);
864-
}
865-
// Ignore this node if it does not map to all required ids.
866-
if (static_cast<size_t>(list.size()) != ids.size()) {
867-
continue;
868-
}
869-
auto nodeToIds = isl::multi_union_pw_aff(space, list);
870-
auto active = activeDomainPoints(root, mapping);
871-
TC_CHECK(active.intersect(domainToIds.domain()).is_empty())
872-
<< "conflicting mappings; are the filters in the tree disjoint?";
873-
nodeToIds = nodeToIds.intersect_domain(active);
874-
domainToIds = domainToIds.union_add(nodeToIds);
875-
}
876-
877-
auto active = activeDomainPoints(root, tree);
878-
TC_CHECK(active.is_subset(domainToIds.domain()))
879-
<< "not all domain points of\n"
880-
<< active << "\nwere mapped to the required ids";
881-
882-
return domainToIds;
883-
}
884-
885599
} // namespace polyhedral
886600
} // namespace tc

0 commit comments

Comments
 (0)