Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit e221236

Browse files
authored
Merge pull request #266 from facebookresearch/pr/reduction
further simplify reduction handling
2 parents 0579904 + d10d66e commit e221236

File tree

3 files changed

+35
-37
lines changed

3 files changed

+35
-37
lines changed

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,18 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
211211
auto initsUpdates = reductionInitsUpdates(band->mupa_.domain(), scop());
212212
auto inits = initsUpdates.first;
213213
auto updates = initsUpdates.second;
214-
if (updates.size() != 1) {
214+
if (updates.n_set() != 1) {
215215
return false;
216216
}
217+
std::vector<isl::id> updateIds;
218+
updates.foreach_set([&updateIds](isl::set set) {
219+
updateIds.emplace_back(set.get_tuple_id());
220+
});
217221
// The reduction member needs to appear right underneath
218222
// the coincident members.
219-
auto reductionDim = findFirstReductionDim(band->mupa_, scop());
220-
if (reductionDim != nCoincident) {
223+
auto reductionDim = nCoincident;
224+
auto member = band->mupa_.get_union_pw_aff(reductionDim);
225+
if (!isReductionMember(member, updates, scop())) {
221226
return false;
222227
}
223228
auto reductionTree = bandSplitOut(scop_->scheduleRoot(), tree, reductionDim);
@@ -229,7 +234,7 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
229234
orderBefore(scop_->scheduleRoot(), tree, inits);
230235
}
231236
reductionFromParent_.emplace(tree, reductionTree);
232-
reductionBandUpdates_.emplace(reductionTree, updates);
237+
reductionBandUpdates_.emplace(reductionTree, updateIds);
233238
return true;
234239
}
235240

tc/core/polyhedral/reduction_matcher.cc

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -134,56 +134,47 @@ isl::id statementId(const Scop& scop, const Halide::Internal::Stmt& stmt) {
134134

135135
} // namespace
136136

137-
std::pair<isl::union_set, std::vector<isl::id>> reductionInitsUpdates(
137+
std::pair<isl::union_set, isl::union_set> reductionInitsUpdates(
138138
isl::union_set domain,
139139
const Scop& scop) {
140140
auto initUnion = isl::union_set::empty(domain.get_space());
141-
std::vector<isl::id> update;
141+
auto update = initUnion;
142142
std::unordered_set<isl::id, isl::IslIdIslHash> init;
143143
std::vector<isl::set> nonUpdate;
144-
// First collect all the update statement identifiers,
144+
// First collect all the update statements,
145145
// the corresponding init statement and all non-update statements.
146146
domain.foreach_set([&init, &update, &nonUpdate, &scop](isl::set set) {
147147
auto setId = set.get_tuple_id();
148148
Halide::Internal::Stmt initStmt;
149149
std::vector<size_t> reductionDims;
150150
if (isReductionUpdateId(setId, scop, initStmt, reductionDims)) {
151-
update.emplace_back(setId);
151+
update = update.unite(set);
152152
init.emplace(statementId(scop, initStmt));
153153
} else {
154154
nonUpdate.emplace_back(set);
155155
}
156156
});
157157
// Then check if all the non-update statements are init statements
158158
// that correspond to the update statements found.
159-
// If not, return an empty list of update statement identifiers.
159+
// If not, return an empty list of update statements.
160160
for (auto set : nonUpdate) {
161161
if (init.count(set.get_tuple_id()) != 1) {
162-
return std::pair<isl::union_set, std::vector<isl::id>>(
163-
initUnion, std::vector<isl::id>());
162+
return std::pair<isl::union_set, isl::union_set>(
163+
initUnion, isl::union_set::empty(domain.get_space()));
164164
}
165165
initUnion = initUnion.unite(set);
166166
}
167-
return std::pair<isl::union_set, std::vector<isl::id>>(initUnion, update);
167+
return std::pair<isl::union_set, isl::union_set>(initUnion, update);
168168
}
169169

170-
int findFirstReductionDim(isl::multi_union_pw_aff islMupa, const Scop& scop) {
171-
auto mupa = isl::MUPA(islMupa);
172-
int reductionDim = -1;
173-
int currentDim = 0;
174-
for (auto const& upa : mupa) {
175-
for (auto const& pa : upa) {
176-
if (isAlmostIdentityReduction(pa.pa, scop)) {
177-
reductionDim = currentDim;
178-
break;
179-
}
180-
}
181-
if (reductionDim != -1) {
182-
break;
183-
}
184-
++currentDim;
185-
}
186-
return reductionDim;
170+
bool isReductionMember(
171+
isl::union_pw_aff member,
172+
isl::union_set domain,
173+
const Scop& scop) {
174+
return domain.every_set([member, &scop](isl::set set) {
175+
auto pa = member.extract_on_domain(set.get_space());
176+
return isAlmostIdentityReduction(pa, scop);
177+
});
187178
}
188179

189180
} // namespace polyhedral

tc/core/polyhedral/schedule_tree_matcher.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,21 @@ namespace tc {
2626
namespace polyhedral {
2727

2828
// Return the union of the reduction init statements as well as
29-
// the identifiers of all reduction update statements
29+
// the union of the reduction update statements
3030
// that appear in "domain", assuming "domain" only contains
3131
// reduction init and update statements.
32-
// If "domain" contains any other statements, then return an empty vector
33-
// of identifiers.
34-
std::pair<isl::union_set, std::vector<isl::id>> reductionInitsUpdates(
32+
// If "domain" contains any other statements, then return an empty set
33+
// of reduction update statements.
34+
std::pair<isl::union_set, isl::union_set> reductionInitsUpdates(
3535
isl::union_set domain,
3636
const Scop& scop);
3737

38-
// Find the first band member that corresponds to a reduction.
39-
// TODO: heuristic to choose the "best" band member in presence of multiple
40-
// reductions.
41-
int findFirstReductionDim(isl::multi_union_pw_aff islMupa, const Scop& scop);
38+
// Does the band member with the given partial schedule correspond
39+
// to a reduction on all statements with a domain in "domain"?
40+
bool isReductionMember(
41+
isl::union_pw_aff member,
42+
isl::union_set domain,
43+
const Scop& scop);
4244

4345
} // namespace polyhedral
4446
} // namespace tc

0 commit comments

Comments
 (0)