Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit a54ec29

Browse files
author
Sven Verdoolaege
committed
reductionInitsUpdates: return updates as isl::union_set
Return inits and updates in the same way improves consistency. The updates in isl::union_set will also be useful for simplifying findFirstReductionDim.
1 parent 0579904 commit a54ec29

File tree

3 files changed

+18
-14
lines changed

3 files changed

+18
-14
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,13 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
211211
auto initsUpdates = reductionInitsUpdates(band->mupa_.domain(), scop());
212212
auto inits = initsUpdates.first;
213213
auto updates = initsUpdates.second;
214-
if (updates.size() != 1) {
214+
if (updates.n_set() != 1) {
215215
return false;
216216
}
217+
std::vector<isl::id> updateIds;
218+
updates.foreach_set([&updateIds](isl::set set) {
219+
updateIds.emplace_back(set.get_tuple_id());
220+
});
217221
// The reduction member needs to appear right underneath
218222
// the coincident members.
219223
auto reductionDim = findFirstReductionDim(band->mupa_, scop());
@@ -229,7 +233,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
229233
orderBefore(scop_->scheduleRoot(), tree, inits);
230234
}
231235
reductionFromParent_.emplace(tree, reductionTree);
232-
reductionBandUpdates_.emplace(reductionTree, updates);
236+
reductionBandUpdates_.emplace(reductionTree, updateIds);
233237
return true;
234238
}
235239

tc/core/polyhedral/reduction_matcher.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,37 +134,37 @@ isl::id statementId(const Scop& scop, const Halide::Internal::Stmt& stmt) {
134134

135135
} // namespace
136136

137-
std::pair<isl::union_set, std::vector<isl::id>> reductionInitsUpdates(
137+
std::pair<isl::union_set, isl::union_set> reductionInitsUpdates(
138138
isl::union_set domain,
139139
const Scop& scop) {
140140
auto initUnion = isl::union_set::empty(domain.get_space());
141-
std::vector<isl::id> update;
141+
auto update = initUnion;
142142
std::unordered_set<isl::id, isl::IslIdIslHash> init;
143143
std::vector<isl::set> nonUpdate;
144-
// First collect all the update statement identifiers,
144+
// First collect all the update statements,
145145
// the corresponding init statement and all non-update statements.
146146
domain.foreach_set([&init, &update, &nonUpdate, &scop](isl::set set) {
147147
auto setId = set.get_tuple_id();
148148
Halide::Internal::Stmt initStmt;
149149
std::vector<size_t> reductionDims;
150150
if (isReductionUpdateId(setId, scop, initStmt, reductionDims)) {
151-
update.emplace_back(setId);
151+
update = update.unite(set);
152152
init.emplace(statementId(scop, initStmt));
153153
} else {
154154
nonUpdate.emplace_back(set);
155155
}
156156
});
157157
// Then check if all the non-update statements are init statements
158158
// that correspond to the update statements found.
159-
// If not, return an empty list of update statement identifiers.
159+
// If not, return an empty list of update statements.
160160
for (auto set : nonUpdate) {
161161
if (init.count(set.get_tuple_id()) != 1) {
162-
return std::pair<isl::union_set, std::vector<isl::id>>(
163-
initUnion, std::vector<isl::id>());
162+
return std::pair<isl::union_set, isl::union_set>(
163+
initUnion, isl::union_set::empty(domain.get_space()));
164164
}
165165
initUnion = initUnion.unite(set);
166166
}
167-
return std::pair<isl::union_set, std::vector<isl::id>>(initUnion, update);
167+
return std::pair<isl::union_set, isl::union_set>(initUnion, update);
168168
}
169169

170170
int findFirstReductionDim(isl::multi_union_pw_aff islMupa, const Scop& scop) {

tc/core/polyhedral/schedule_tree_matcher.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ namespace tc {
2626
namespace polyhedral {
2727

2828
// Return the union of the reduction init statements as well as
29-
// the identifiers of all reduction update statements
29+
// the union of the reduction update statements
3030
// that appear in "domain", assuming "domain" only contains
3131
// reduction init and update statements.
32-
// If "domain" contains any other statements, then return an empty vector
33-
// of identifiers.
34-
std::pair<isl::union_set, std::vector<isl::id>> reductionInitsUpdates(
32+
// If "domain" contains any other statements, then return an empty set
33+
// of reduction update statements.
34+
std::pair<isl::union_set, isl::union_set> reductionInitsUpdates(
3535
isl::union_set domain,
3636
const Scop& scop);
3737

0 commit comments

Comments
 (0)