Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 3e2da9c

Browse files
committed
extractDomainToIds: take into account filters
The original implementation of extractDomainToThreads was assuming that thread mapping filter nodes also include constraints on the domain elements active at that position in the schedule tree. This is not the case in practice, the mapping filters are defined for statements without restricting the instance sets. Therefore, the thread mapping functions computed by extractDomainToThreads are incorrect for trees that schedule different instances of the same statement in different branches. In particular, calling "union_add" takes a sum of the two mappings for the statements whose instances appear in different branches. Intersect the domain of the thread mapping function with the set of points active at the mapping node. Introduce a check that the result of this intersection and the set of instances with known mapping is disjoint. Note that, to enable the check that all domain points were mapped, the intersection must be made with the active domain points without accounting for ancestor mapping filters. In particular, if extractDomainToIds is called on the root of a mapped tree and looks for thread ids, the mapping to blocks will be considered to filter out some domain points and the thread mapping will not cover the entire domain. Introduce "activeDomainPointsNoMapping" to avoid this problem. Unlike "activeDomainPoints", this function does not take into account extension nodes. This is currently safe because the callers of "extractDomainToIds" do not rely on it returning the mapping of the elements introduced by extension nodes. Comment them accordingly. In the future, we should reconsider activeDomainPoints interpreting MappingFilter nodes as regular filters, or mappings being filters in general, as well as "activeDomainPointsNoMapping" ignoring extension nodes.
1 parent be028d4 commit 3e2da9c

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

tc/core/polyhedral/cuda/mapped_scop.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,17 @@ class MappedScop {
190190
// to the thread identifiers, where all branches in "tree"
191191
// are assumed to have been mapped to thread identifiers.
192192
// The result lives in a space of the form block[x, ...].
193+
//
194+
// Note: this function ignores statements introduced by extension nodes.
193195
isl::multi_union_pw_aff threadMappingSchedule(
194196
const detail::ScheduleTree* tree) const;
195197

196198
// Extract a mapping from the domain elements active at "tree"
197199
// to the block identifiers, where all branches in "tree"
198200
// are assumed to have been mapped to block identifiers.
199201
// The result lives in a space of the form grid[x, ...].
202+
//
203+
// Note: this function ignores statements introduced by extension nodes.
200204
isl::multi_union_pw_aff blockMappingSchedule(
201205
const detail::ScheduleTree* tree) const;
202206

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ isl::union_set activeDomainPointsBelow(
145145
return activeDomainPointsHelper(root, ancestors);
146146
}
147147

148+
isl::union_set activeDomainPointsNoMappingNoExtension(
149+
const detail::ScheduleTree* root,
150+
const detail::ScheduleTree* tree) {
151+
auto domain = root->elemAs<detail::ScheduleTreeElemDomain>()->domain_;
152+
for (auto t : tree->ancestors(root)) {
153+
if (auto f = t->elemAs<detail::ScheduleTreeElemFilter>()) {
154+
domain = domain.intersect(f->filter_);
155+
}
156+
}
157+
return domain;
158+
}
159+
148160
vector<ScheduleTree*> collectScheduleTreesPath(
149161
std::function<ScheduleTree*(ScheduleTree*)> next,
150162
ScheduleTree* start) {
@@ -804,6 +816,10 @@ void orderAfter(ScheduleTree* root, ScheduleTree* tree, isl::union_set filter) {
804816
* to identifiers "ids", where all branches in "tree"
805817
* are assumed to have been mapped to these identifiers.
806818
* The result lives in a space of the form "tupleId"["ids"...].
819+
*
820+
* Note: this function only takes into account points that are present in the
821+
* root domain node. Those introduced by extension nodes are ignored. This
822+
* behavior can change in the future.
807823
*/
808824
isl::multi_union_pw_aff extractDomainToIds(
809825
const detail::ScheduleTree* root,
@@ -833,12 +849,17 @@ isl::multi_union_pw_aff extractDomainToIds(
833849
continue;
834850
}
835851
auto nodeToIds = isl::multi_union_pw_aff(space, list);
852+
auto active = activeDomainPointsNoMappingNoExtension(root, mapping);
853+
TC_CHECK(active.intersect(domainToIds.domain()).is_empty())
854+
<< "conflicting mappings; are the filters in the tree disjoint?";
855+
nodeToIds = nodeToIds.intersect_domain(active);
836856
domainToIds = domainToIds.union_add(nodeToIds);
837857
}
838858

839-
TC_CHECK(activeDomainPoints(root, tree).is_subset(domainToIds.domain()))
840-
<< "not all domain points of" << activeDomainPoints(root, tree)
841-
<< "were mapped to the required ids";
859+
auto active = activeDomainPointsNoMappingNoExtension(root, tree);
860+
TC_CHECK(active.is_subset(domainToIds.domain()))
861+
<< "not all domain points of\n"
862+
<< active << "\nwere mapped to the required ids";
842863

843864
return domainToIds;
844865
}

tc/core/polyhedral/schedule_transforms.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,15 @@ isl::union_set activeDomainPointsBelow(
326326
const detail::ScheduleTree* root,
327327
const detail::ScheduleTree* node);
328328

329+
// Get the set of domain points active below the given node without including
330+
// the points introduced by extension nodes and without treating mapping nodes
331+
// as filters. A point is considered active at a schedule node "tree" if it is
332+
// present in the "root" domain node and was not filtered away on the path from
333+
// "root" to "tree". The root must be a domain element.
334+
isl::union_set activeDomainPointsNoMappingNoExtension(
335+
const detail::ScheduleTree* root,
336+
const detail::ScheduleTree* tree);
337+
329338
// Extract a mapping from the domain elements active at "tree" (in a tree
330339
// rooted at "root") to identifiers "ids", where all branches in "tree" are
331340
// assumed to have been mapped to these identifiers. The result lives in a

0 commit comments

Comments
 (0)