Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 86c1f8c

Browse files
authored
Merge pull request #448 from facebookresearch/pr/no_init
stop relying on reduction init statements
2 parents 85bd48f + 687f183 commit 86c1f8c

File tree

13 files changed

+210
-193
lines changed

13 files changed

+210
-193
lines changed

tc/core/halide2isl.cc

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "tc/core/halide2isl.h"
1717

1818
#include <algorithm>
19-
#include <numeric>
2019
#include <unordered_set>
2120

2221
#include "tc/core/constants.h"
@@ -499,68 +498,42 @@ std::vector<Reduction> findReductions(const Stmt& s) {
499498
}
500499
}
501500

502-
// Keep track of any reduction variable name for use in isValidReduction
501+
// Keep track of any reduction variable name for use in visit(Provide*)
503502
void visit(const Variable* op) {
504503
if (op->reduction_domain.defined()) {
505504
reductionVars.insert(op->name);
506505
}
507506
}
508507

509-
// Check that the given update node, together with the corresponding
510-
// init node form a proper reduction pair.
511-
// In particular, check that they share some outer For nodes and
512-
// that the variables of the additional For nodes surrounding
513-
// the update node are all reduction variables.
514-
bool isValidReductionUpdate(const Provide* op) {
515-
const auto& opInitVars = initVars[op->name];
516-
auto n = opInitVars.size();
517-
if (vars.size() <= n) {
518-
return false;
519-
}
520-
if (!std::equal(opInitVars.begin(), opInitVars.end(), vars.begin())) {
521-
return false;
522-
}
523-
for (auto i = vars.begin() + n; i != vars.end(); ++i) {
524-
if (reductionVars.count(*i) == 0) {
525-
return false;
526-
}
527-
}
528-
return true;
529-
}
530-
531508
// Keep track of the names of the outer For nodes.
532509
void visit(const For* op) {
533510
vars.push_back(op->name);
534511
IRVisitor::visit(op);
535512
vars.pop_back();
536513
}
537514

538-
// Check if the node is an init node, keeping track of it,
539-
// or an update node corresponding to init node that was found before,
540-
// updating the information about the reduction.
541-
// In particular, double-check that the pair are in the right
542-
// relative positions and collect the positions of the reduction
515+
// Check if the node is an update node with at least one reduction
516+
// dimension, keeping track of the information about the reduction.
517+
// In particular, collect the positions of the reduction
543518
// dimensions in the update statement domain.
544519
// Visit the children first to ensure that all relevant
545520
// reduction variables have been found first.
546521
void visit(const Provide* op) {
547522
IRVisitor::visit(op);
548-
if (isReductionInit(op)) {
549-
reductions[op->name].init = op;
550-
initVars[op->name] = vars;
551-
} else if (isReductionUpdate(op)) {
552-
if (isValidReductionUpdate(op)) {
523+
if (isReductionUpdate(op)) {
524+
std::vector<size_t> dims;
525+
auto n = vars.size();
526+
for (size_t i = 0; i < n; ++i) {
527+
if (reductionVars.count(vars[i]) != 0) {
528+
dims.emplace_back(i);
529+
}
530+
}
531+
if (dims.size() > 0) {
553532
auto& p = reductions[op->name];
554-
CHECK(p.init.defined())
555-
<< "Missing reduction init or (unsupported) multiple updates";
556533
CHECK(!p.update.defined())
557534
<< "Multiple reduction updates not yet implemented";
558535
p.update = op;
559-
auto n = initVars[op->name].size();
560-
p.dims.resize(vars.size() - n);
561-
std::iota(p.dims.begin(), p.dims.end(), n);
562-
} else {
563-
reductions.erase(op->name);
536+
p.dims = dims;
564537
}
565538
}
566539
}
@@ -570,8 +543,6 @@ std::vector<Reduction> findReductions(const Stmt& s) {
570543
std::unordered_set<std::string> reductionVars;
571544
// The names of the outer For nodes, outermost to innermost.
572545
std::vector<std::string> vars;
573-
// For each init node, the names of its outer For nodes.
574-
std::map<std::string, std::vector<std::string>> initVars;
575546
std::map<std::string, Reduction> reductions;
576547
} finder;
577548
s.accept(&finder);

tc/core/halide2isl.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,11 @@ ScheduleTreeAndAccesses makeScheduleTree(
9494
/// Enumerate all reductions in a statement, by looking for the
9595
/// ReductionInit and ReductionUpdate markers inserted during lowering
9696
/// (see tc2halide.h).
97-
/// Each reduction object stores a reference to the init and
98-
/// the update statement, although the init statement is probably
99-
/// not strictly needed, and a list of reduction dimensions
97+
/// Each reduction object stores a reference to
98+
/// the update statement, and a list of reduction dimensions
10099
/// in the domain of the update statement.
101100
struct Reduction {
102-
Halide::Internal::Stmt init, update;
101+
Halide::Internal::Stmt update;
103102
std::vector<size_t> dims;
104103
};
105104
std::vector<Reduction> findReductions(const Halide::Internal::Stmt& s);

tc/core/polyhedral/cuda/mapped_scop.cc

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,8 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
215215
// a single reduction for now.
216216
// Support for multiple reductions would require a check
217217
// that these reductions do not interfere with each other.
218-
auto initsUpdates = reductionInitsUpdates(band->mupa_.domain(), scop());
219-
auto inits = initsUpdates.first;
220-
auto updates = initsUpdates.second;
218+
auto domain = band->mupa_.domain();
219+
auto updates = reductionUpdates(domain, scop());
221220
if (updates.n_set() != 1) {
222221
return false;
223222
}
@@ -232,12 +231,19 @@ bool MappedScop::detectReductions(detail::ScheduleTree* tree) {
232231
if (!isReductionMember(member, updates, scop())) {
233232
return false;
234233
}
235-
// Order the init statements (if any) before the update statements
234+
// Order the other statements (if any) before the update statements
236235
// to ensure the band from which the reduction band has been split off
237236
// only contains update statements.
238-
// Note that this relies on the outer members being coincident.
239-
if (!inits.is_empty()) {
240-
orderBefore(scop_->scheduleRoot(), tree, inits);
237+
// Only do this if it doesn't violate any dependences.
238+
// TODO (#454): order statements before or after the reduction based on
239+
// dependences.
240+
auto other = domain.subtract(updates);
241+
if (!other.is_empty()) {
242+
auto dependences = scop_->activeDependences(tree);
243+
if (!canOrderBefore(scop_->scheduleRoot(), tree, other, dependences)) {
244+
return false;
245+
}
246+
orderBefore(scop_->scheduleRoot(), tree, other);
241247
}
242248
reductionBandUpdates_.emplace(tree, Reduction(updateIds));
243249
return true;

tc/core/polyhedral/reduction_matcher.cc

Lines changed: 7 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,10 @@ bool isSupportedReduction(Halide::Internal::Stmt stmt) {
4848
// the reduction. that is a kind of internal state dependence we want to avoid
4949
// If id is the statement identifier of an update statement
5050
// of a supported type of reduction,
51-
// then return the corresponding init statement in init and
52-
// the corresponding reduction dimensions in reductionDims.
51+
// then return the corresponding reduction dimensions in reductionDims.
5352
bool isReductionUpdateId(
5453
isl::id id,
5554
const Scop& scop,
56-
Halide::Internal::Stmt& init,
5755
std::vector<size_t>& reductionDims) {
5856
CHECK_EQ(scop.halide.statements.count(id), 1u)
5957
<< "id is not a statement in scop" << id;
@@ -63,7 +61,6 @@ bool isReductionUpdateId(
6361
}
6462
for (auto const& iup : scop.halide.reductions) {
6563
if (iup.update.same_as(provideNode)) {
66-
init = iup.init;
6764
reductionDims = iup.dims;
6865
return true;
6966
}
@@ -104,9 +101,8 @@ bool isAlmostIdentityReduction(isl::pw_aff pa, const Scop& scop) {
104101
return false;
105102
}
106103
auto stmtId = space.get_tuple_id(isl::dim_type::in);
107-
Halide::Internal::Stmt init;
108104
std::vector<size_t> reductionDims;
109-
if (!isReductionUpdateId(stmtId, scop, init, reductionDims)) {
105+
if (!isReductionUpdateId(stmtId, scop, reductionDims)) {
110106
return false;
111107
}
112108

@@ -118,52 +114,18 @@ bool isAlmostIdentityReduction(isl::pw_aff pa, const Scop& scop) {
118114
return false;
119115
}
120116

121-
/*
122-
* Return the identifier that maps to "stmt".
123-
*/
124-
isl::id statementId(const Scop& scop, const Halide::Internal::Stmt& stmt) {
125-
for (auto kvp : scop.halide.statements) {
126-
if (kvp.second.same_as(stmt)) {
127-
return kvp.first;
128-
}
129-
}
130-
CHECK(false) << "no id recorded for statement" << stmt;
131-
return isl::id();
132-
}
133-
134117
} // namespace
135118

136-
std::pair<isl::union_set, isl::union_set> reductionInitsUpdates(
137-
isl::union_set domain,
138-
const Scop& scop) {
139-
auto initUnion = isl::union_set::empty(domain.get_space());
140-
auto update = initUnion;
141-
std::unordered_set<isl::id, isl::IslIdIslHash> init;
142-
std::vector<isl::set> nonUpdate;
143-
// First collect all the update statements,
144-
// the corresponding init statement and all non-update statements.
145-
domain.foreach_set([&init, &update, &nonUpdate, &scop](isl::set set) {
119+
isl::union_set reductionUpdates(isl::union_set domain, const Scop& scop) {
120+
auto update = isl::union_set::empty(domain.get_space());
121+
domain.foreach_set([&update, &scop](isl::set set) {
146122
auto setId = set.get_tuple_id();
147-
Halide::Internal::Stmt initStmt;
148123
std::vector<size_t> reductionDims;
149-
if (isReductionUpdateId(setId, scop, initStmt, reductionDims)) {
124+
if (isReductionUpdateId(setId, scop, reductionDims)) {
150125
update = update.unite(set);
151-
init.emplace(statementId(scop, initStmt));
152-
} else {
153-
nonUpdate.emplace_back(set);
154126
}
155127
});
156-
// Then check if all the non-update statements are init statements
157-
// that correspond to the update statements found.
158-
// If not, return an empty list of update statements.
159-
for (auto set : nonUpdate) {
160-
if (init.count(set.get_tuple_id()) != 1) {
161-
return std::pair<isl::union_set, isl::union_set>(
162-
initUnion, isl::union_set::empty(domain.get_space()));
163-
}
164-
initUnion = initUnion.unite(set);
165-
}
166-
return std::pair<isl::union_set, isl::union_set>(initUnion, update);
128+
return update;
167129
}
168130

169131
bool isReductionMember(

tc/core/polyhedral/schedule_transforms.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,25 @@ ScheduleTreeUPtr gistedFilter(isl::union_set filter, ScheduleTreeUPtr child) {
731731

732732
} // namespace
733733

734+
bool canOrderBefore(
735+
ScheduleTree* root,
736+
ScheduleTree* tree,
737+
isl::union_set filter,
738+
isl::union_map dependences) {
739+
// Create an ordering schedule function filter -> 0; other -> 1.
740+
auto other = activeDomainPoints(root, tree).subtract(filter);
741+
auto ctx = root->ctx_;
742+
auto space = isl::space(ctx, 0).unnamed_set_from_params(1);
743+
auto zero = isl::multi_val::zero(space);
744+
auto one = zero.set_val(0, isl::val::one(ctx));
745+
auto order = isl::multi_union_pw_aff(filter, zero);
746+
order = order.union_add(isl::multi_union_pw_aff(other, one));
747+
748+
// Check that this ordering preserves all dependences.
749+
auto preserved = dependences.lex_lt_at(order).unite(dependences.eq_at(order));
750+
return dependences.is_subset(preserved);
751+
}
752+
734753
void orderBefore(
735754
ScheduleTree* root,
736755
ScheduleTree* tree,

tc/core/polyhedral/schedule_transforms.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ void insertExtensionLabelAfter(
233233
detail::ScheduleTree* tree,
234234
isl::id id);
235235

236+
// Is it possible to order the elements in the given filter
237+
// before the other active elements without violating
238+
// any of the given dependences?
239+
bool canOrderBefore(
240+
detail::ScheduleTree* root,
241+
detail::ScheduleTree* tree,
242+
isl::union_set filter,
243+
isl::union_map dependences);
244+
236245
// Insert a sequence to ensure that the active domain elements
237246
// in the given filter are executed before the other active domain elements.
238247
void orderBefore(

tc/core/polyhedral/schedule_tree_matcher.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,9 @@
2525
namespace tc {
2626
namespace polyhedral {
2727

28-
// Return the union of the reduction init statements as well as
29-
// the union of the reduction update statements
30-
// that appear in "domain", assuming "domain" only contains
31-
// reduction init and update statements.
32-
// If "domain" contains any other statements, then return an empty set
33-
// of reduction update statements.
34-
std::pair<isl::union_set, isl::union_set> reductionInitsUpdates(
35-
isl::union_set domain,
36-
const Scop& scop);
28+
// Return the union of the reduction update statements
29+
// that appear in "domain".
30+
isl::union_set reductionUpdates(isl::union_set domain, const Scop& scop);
3731

3832
// Does the band member with the given partial schedule correspond
3933
// to a reduction on all statements with a domain in "domain"?

tc/core/polyhedral/scop.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,16 @@ void Scop::computeAllDependences() {
390390
dependences = flowDeps.unite(falseDeps).coalesce();
391391
}
392392

393+
isl::union_map Scop::activeDependences(detail::ScheduleTree* tree) {
394+
auto prefix = prefixScheduleMupa(scheduleRoot(), tree);
395+
auto domain = activeDomainPoints(scheduleRoot(), tree);
396+
auto active = dependences;
397+
active = active.intersect_domain(domain);
398+
active = active.intersect_range(domain);
399+
active = active.eq_at(prefix);
400+
return active;
401+
}
402+
393403
std::unique_ptr<detail::ScheduleTree> Scop::computeSchedule(
394404
isl::schedule_constraints constraints,
395405
const SchedulerOptionsView& schedulerOptions) {

tc/core/polyhedral/scop.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,9 @@ struct Scop {
471471
// Do the simplest possible dependence analysis.
472472
// Compute all RAW, WAR, and WAW dependences, and save them in dependences.
473473
void computeAllDependences();
474+
// Return the set of dependences that are active
475+
// at the given position.
476+
isl::union_map activeDependences(detail::ScheduleTree* tree);
474477

475478
public:
476479
// Halide stuff

tc/core/tc2halide.cc

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -834,38 +834,6 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
834834
s = tagReduction.mutate(s);
835835
}
836836

837-
// Temporary hack: Fuse reduction initializations into the
838-
// updates. Shouldn't matter because we're going to reschedule
839-
// everything anyway, but starting at this point is more similar to
840-
// the code this replaces. Only correct because we only currently
841-
// create update definitions when there are reductions.
842-
class FuseReductions : public IRMutator2 {
843-
using IRMutator2::visit;
844-
Stmt visit(const Block* op) override {
845-
const For* first = op->first.as<For>();
846-
const For* rest = op->rest.as<For>();
847-
if (first && rest && equal(first->min, rest->min) &&
848-
equal(first->extent, rest->extent) &&
849-
first->for_type == rest->for_type &&
850-
replace_all(first->name, ".s0.", ".s1.") == rest->name) {
851-
Stmt body = rest->body;
852-
body =
853-
substitute(rest->name, Variable::make(Int(32), first->name), body);
854-
body = mutate(Block::make(first->body, body));
855-
return For::make(
856-
first->name,
857-
first->min,
858-
first->extent,
859-
first->for_type,
860-
first->device_api,
861-
body);
862-
} else {
863-
return IRMutator2::visit(op);
864-
}
865-
}
866-
} fuser;
867-
s = fuser.mutate(s);
868-
869837
// Trim ProducerConsumer annotations. TC doesn't use them.
870838
class RemoveProducerConsumer : public IRMutator2 {
871839
using IRMutator2::visit;

0 commit comments

Comments
 (0)