Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1ab9fbd

Browse files
author
Sven Verdoolaege
committed
[RFC] use templated isl types
Templated isl types require the user to specify the domain and range universes of isl objects, allowing the compiler to check whether it makes sense to combine pairs of objects. This RFC only converts isPromotableToRegistersBelow and some related functions to illustrate the effect. The isPromotableToRegistersBelow was already applying operations correctly, so the code itself did not require any changes. However, one variable was reused to store different types of intermediate result and this one had to be split up into several variables because they now have different types.
1 parent 98a402a commit 1ab9fbd

File tree

7 files changed

+23
-24
lines changed

7 files changed

+23
-24
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ bool MappedScop::detectReductions(ScheduleTree* tree) {
288288
// a single reduction for now.
289289
// Support for multiple reductions would require a check
290290
// that these reductions do not interfere with each other.
291-
auto domain = isl::UnionSet<Statement>(band->mupa_.domain());
291+
auto domain = band->mupa_.domain();
292292
auto updates = reductionUpdates(domain, scop());
293293
if (updates.n_set() != 1) {
294294
return false;
@@ -560,7 +560,7 @@ Scop::SyncLevel MappedScop::findBestSync(
560560
auto activePoints2 = activeDomainPointsBelow(stRoot, st2);
561561

562562
// The dependences between the two schedule trees
563-
isl::union_map dependences = scop_->dependences;
563+
auto dependences = scop_->dependences;
564564
dependences = dependences.intersect_domain(activePoints1);
565565
dependences = dependences.intersect_range(activePoints2);
566566
if (dependences.is_empty()) {

tc/core/polyhedral/cuda/memory_promotion_heuristic.cc

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ bool accessSubscriptsAreUnrolledLoops(
333333
}
334334
}
335335

336-
auto space =
337-
subdomain.get_space().add_unnamed_tuple_ui<Unrolled>(unrolledDims.size());
336+
auto space = subdomain.get_space().add_unnamed_tuple_ui<Unrolled>(
337+
unrolledDims.size());
338338
auto unrolledDimsMupa = isl::multi_union_pw_aff(space, unrolledDims);
339339

340340
// It is possible that no loops are unrolled, in which case
@@ -343,7 +343,7 @@ bool accessSubscriptsAreUnrolledLoops(
343343
unrolledDimsMupa =
344344
unrolledDimsMupa.intersect_domain(group.originalAccesses().domain());
345345

346-
auto accesses = group.originalAccesses();
346+
isl::union_map accesses = group.originalAccesses();
347347
auto schedule = outerSchedule.flat_range_product(unrolledDimsMupa);
348348
accesses = accesses.apply_domain(isl::union_map::from(schedule));
349349

@@ -656,16 +656,15 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
656656
auto blockSchedule = mscop.blockMappingSchedule(mscop.schedule());
657657

658658
// Pure affine schedule without (mapping) filters.
659-
isl::multi_union_pw_aff partialSchedMupa =
660-
partialScheduleMupa<Scope>(root, scope);
659+
auto partialSchedMupa = partialScheduleMupa<Scope>(root, scope);
661660
// Schedule with block mapping filter.
662661
auto partialSched =
663662
isl::union_map::from(partialSchedMupa).intersect_domain(blockMapping);
664663
// The following promotion validity and profitability checks need to be
665664
// performed with respect to the block mapping, so append the block schedule.
666665
// If the partial schedule contains it already, it will just end up with
667666
// identical dimensions without affecting the result of the checks.
668-
partialSchedMupa = partialSchedMupa.flat_range_product(blockSchedule);
667+
auto partialSchedBlockMupa = partialSchedMupa.range_product(blockSchedule);
669668

670669
for (auto& tensorGroups : groupMap) {
671670
auto tensorId = tensorGroups.first;
@@ -679,11 +678,11 @@ void promoteToRegistersBelow(MappedScop& mscop, detail::ScheduleTree* scope) {
679678
continue;
680679
}
681680
if (!isPromotableToRegistersBelow(
682-
*group, root, scope, partialSchedMupa, threadSchedule)) {
681+
*group, root, scope, partialSchedBlockMupa, threadSchedule)) {
683682
continue;
684683
}
685684
// Check reuse within threads.
686-
auto schedule = partialSchedMupa.flat_range_product(threadSchedule);
685+
auto schedule = partialSchedBlockMupa.range_product(threadSchedule);
687686
if (!hasReuseWithin(*group, schedule)) {
688687
continue;
689688
}

tc/core/polyhedral/memory_promotion.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,9 @@ class TensorReferenceGroup {
137137
// range spaces.
138138
isl::union_map originalWrites() const;
139139
isl::union_map originalReads() const;
140-
isl::union_map originalAccesses() const {
141-
return originalWrites().unite(originalReads());
140+
isl::UnionMap<Statement, Tensor> originalAccesses() const {
141+
auto accesses = originalWrites().unite(originalReads());
142+
return isl::UnionMap<Statement, Tensor>(accesses);
142143
}
143144

144145
// Rectangular overapproximation of the set of tensor elements accessed below

tc/core/polyhedral/schedule_utils-inl.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,10 @@ inline isl::MultiUnionPwAff<Statement, Schedule> partialScheduleMupa(
129129
const detail::ScheduleTree* tree) {
130130
using namespace polyhedral::detail;
131131

132-
isl::multi_union_pw_aff prefix = prefixScheduleMupa<Schedule>(root, tree);
132+
auto prefix = prefixScheduleMupa<Schedule>(root, tree);
133133
auto band = tree->as<detail::ScheduleTreeBand>();
134-
auto partial = band ? prefix.flat_range_product(band->mupa_) : prefix;
135-
return isl::MultiUnionPwAff<Statement, Schedule>(partial);
134+
return band ? prefix.template flat_range_product<Schedule>(band->mupa_)
135+
: prefix;
136136
}
137137

138138
/*
@@ -171,7 +171,7 @@ isl::MultiUnionPwAff<Statement, MappingType> extractDomainToIds(
171171
continue;
172172
}
173173
auto nodeToIds = isl::MultiUnionPwAff<Statement, MappingType>(space, list);
174-
auto active = isl::UnionSet<Statement>(activeDomainPoints(root, mapping));
174+
auto active = activeDomainPoints(root, mapping);
175175
TC_CHECK(active.intersect(domainToIds.domain()).is_empty())
176176
<< "conflicting mappings; are the filters in the tree disjoint?";
177177
nodeToIds = nodeToIds.intersect_domain(active);

tc/core/polyhedral/schedule_utils.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ isl::UnionSet<Statement> collectDomain(
7979
domain = domain.unite(parentSchedule.range().apply(extension));
8080
}
8181
}
82-
return domain;
82+
return isl::UnionSet<Statement>(domain);
8383
}
8484

8585
// Get the set of domain elements that are active below
8686
// the given branch of nodes.
87-
isl::union_set activeDomainPointsHelper(
87+
isl::UnionSet<Statement> activeDomainPointsHelper(
8888
const ScheduleTree* root,
8989
const vector<const ScheduleTree*>& nodes) {
9090
return collectDomain(root, nodes, &applyFilter);
@@ -101,16 +101,15 @@ isl::UnionSet<Statement> prefixMappingFilter(
101101
isl::UnionSet<Statement> activeDomainPoints(
102102
const ScheduleTree* root,
103103
const ScheduleTree* node) {
104-
return isl::UnionSet<Statement>(
105-
activeDomainPointsHelper(root, node->ancestors(root)));
104+
return activeDomainPointsHelper(root, node->ancestors(root));
106105
}
107106

108107
isl::UnionSet<Statement> activeDomainPointsBelow(
109108
const ScheduleTree* root,
110109
const ScheduleTree* node) {
111110
auto ancestors = node->ancestors(root);
112111
ancestors.emplace_back(node);
113-
return isl::UnionSet<Statement>(activeDomainPointsHelper(root, ancestors));
112+
return activeDomainPointsHelper(root, ancestors);
114113
}
115114

116115
vector<ScheduleTree*> collectScheduleTreesPath(

tc/core/polyhedral/scop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ ScopUPtr Scop::makeScop(
8585
return makeScop(ctx, tc2halide::translate(ctx, treeRef, compilerOptions));
8686
}
8787

88-
isl::union_set& Scop::domainRef() {
88+
isl::UnionSet<Statement>& Scop::domainRef() {
8989
auto dom = scheduleRoot()->as<ScheduleTreeDomain>();
9090
TC_CHECK(dom) << "root is not a domain in: " << *scheduleRoot();
9191
// TODO: activate this when the invariant has a chance of working (i.e. we
@@ -99,7 +99,7 @@ isl::union_set& Scop::domainRef() {
9999
}
100100

101101
const isl::UnionSet<Statement> Scop::domain() const {
102-
return isl::UnionSet<Statement>(const_cast<Scop*>(this)->domainRef());
102+
return const_cast<Scop*>(this)->domainRef();
103103
}
104104

105105
std::ostream& operator<<(std::ostream& os, const Scop& s) {

tc/core/polyhedral/scop.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ struct Scop {
509509
// By analogy with generalized functions, the domain is the "support" part
510510
// of the ScheduleTree "function".
511511
private:
512-
isl::union_set& domainRef();
512+
isl::UnionSet<Statement>& domainRef();
513513

514514
public:
515515
const isl::UnionSet<Statement> domain() const;

0 commit comments

Comments
 (0)