@@ -102,24 +102,34 @@ isl::union_map partialSchedule(
102
102
}
103
103
104
104
namespace {
105
+ /*
106
+ * If "node" is any filter, then intersect "domain" with that filter.
107
+ */
108
+ isl::union_set applyAnyFilter (isl::union_set domain, const ScheduleTree* node) {
109
+ if (auto filterElem = node->elemAsBase <ScheduleTreeElemFilter>()) {
110
+ return domain.intersect (filterElem->filter_ );
111
+ }
112
+ return domain;
113
+ }
114
+
105
115
// Get the set of domain elements that are active below
106
- // the given branch of nodes.
116
+ // the given branch of nodes, filtered using "filter" .
107
117
//
108
- // Domain elements are introduced by the root domain node. Filter nodes
109
- // disable the points that do not intersect with the filter. Extension nodes
118
+ // Domain elements are introduced by the root domain node. Some nodes
119
+ // refine this set of elements based on " filter" . Extension nodes
110
120
// are considered to introduce additional domain points.
111
- isl::union_set activeDomainPointsHelper (
121
+ isl::union_set collectDomain (
112
122
const ScheduleTree* root,
113
- const vector<const ScheduleTree*>& nodes) {
123
+ const vector<const ScheduleTree*>& nodes,
124
+ isl::union_set (*filter)(isl::union_set domain, const ScheduleTree* node)) {
114
125
auto domainElem = root->elemAs <ScheduleTreeElemDomain>();
115
126
TC_CHECK (domainElem) << " root must be a Domain node" << *root;
116
127
117
128
auto domain = domainElem->domain_ ;
118
129
119
130
for (auto anc : nodes) {
120
- if (auto filterElem = anc->elemAsBase <ScheduleTreeElemFilter>()) {
121
- domain = domain.intersect (filterElem->filter_ );
122
- } else if (auto extensionElem = anc->elemAs <ScheduleTreeElemExtension>()) {
131
+ domain = filter (domain, anc);
132
+ if (auto extensionElem = anc->elemAs <ScheduleTreeElemExtension>()) {
123
133
auto parentSchedule = prefixSchedule (root, anc);
124
134
auto extension = extensionElem->extension_ ;
125
135
TC_CHECK (parentSchedule) << " missing root domain node" ;
@@ -129,6 +139,15 @@ isl::union_set activeDomainPointsHelper(
129
139
}
130
140
return domain;
131
141
}
142
+
143
+ // Get the set of domain elements that are active below
144
+ // the given branch of nodes.
145
+ isl::union_set activeDomainPointsHelper (
146
+ const ScheduleTree* root,
147
+ const vector<const ScheduleTree*>& nodes) {
148
+ return collectDomain (root, nodes, &applyAnyFilter);
149
+ }
150
+
132
151
} // namespace
133
152
134
153
isl::union_set activeDomainPoints (
0 commit comments