Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 934a4ba

Browse files
author
Sven Verdoolaege
committed
activeDomainPointsHelper: extract out collectDomain
This will allow collectDomain to be reused to implement a prefixMappingFilter in the next commit.
1 parent 14ad694 commit 934a4ba

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,24 +102,34 @@ isl::union_map partialSchedule(
102102
}
103103

104104
namespace {
105+
/*
106+
* If "node" is any filter, then intersect "domain" with that filter.
107+
*/
108+
isl::union_set applyAnyFilter(isl::union_set domain, const ScheduleTree* node) {
109+
if (auto filterElem = node->elemAsBase<ScheduleTreeElemFilter>()) {
110+
return domain.intersect(filterElem->filter_);
111+
}
112+
return domain;
113+
}
114+
105115
// Get the set of domain elements that are active below
106-
// the given branch of nodes.
116+
// the given branch of nodes, filtered using "filter".
107117
//
108-
// Domain elements are introduced by the root domain node. Filter nodes
109-
// disable the points that do not intersect with the filter. Extension nodes
118+
// Domain elements are introduced by the root domain node. Some nodes
119+
// refine this set of elements based on "filter". Extension nodes
110120
// are considered to introduce additional domain points.
111-
isl::union_set activeDomainPointsHelper(
121+
isl::union_set collectDomain(
112122
const ScheduleTree* root,
113-
const vector<const ScheduleTree*>& nodes) {
123+
const vector<const ScheduleTree*>& nodes,
124+
isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) {
114125
auto domainElem = root->elemAs<ScheduleTreeElemDomain>();
115126
TC_CHECK(domainElem) << "root must be a Domain node" << *root;
116127

117128
auto domain = domainElem->domain_;
118129

119130
for (auto anc : nodes) {
120-
if (auto filterElem = anc->elemAsBase<ScheduleTreeElemFilter>()) {
121-
domain = domain.intersect(filterElem->filter_);
122-
} else if (auto extensionElem = anc->elemAs<ScheduleTreeElemExtension>()) {
131+
domain = filter(domain, anc);
132+
if (auto extensionElem = anc->elemAs<ScheduleTreeElemExtension>()) {
123133
auto parentSchedule = prefixSchedule(root, anc);
124134
auto extension = extensionElem->extension_;
125135
TC_CHECK(parentSchedule) << "missing root domain node";
@@ -129,6 +139,15 @@ isl::union_set activeDomainPointsHelper(
129139
}
130140
return domain;
131141
}
142+
143+
// Get the set of domain elements that are active below
144+
// the given branch of nodes.
145+
isl::union_set activeDomainPointsHelper(
146+
const ScheduleTree* root,
147+
const vector<const ScheduleTree*>& nodes) {
148+
return collectDomain(root, nodes, &applyAnyFilter);
149+
}
150+
132151
} // namespace
133152

134153
isl::union_set activeDomainPoints(

0 commit comments

Comments
 (0)