Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 33a3bb6

Browse files
Sven Verdoolaegeftynse
authored andcommitted
halide2isl: do not require init statement for reduction detection
Since the previous commit, the init statement is no longer used by the TC mapper, so there is no longer any need to keep track of it. This in turn means that there is also no longer any need to require the presence of an init statement. The new code does require that there is at least one reduction dimension in order to prune out trivial reductions. For example, in the TC below, both statements would otherwise be treated as reduction updates. def fun(float(N,K) A, float(K,M) B, float(N,M) C) -> (O) { O(i,j) +=! A(i,k) * B(k,j) O(i,j) = O(i,j) + C(i,j) } Note that the reduction is still identified by the tensor name, meaning that multiple reductions on the same tensor are still not allowed. Identifying the reduction by the update statement is left for future work.
1 parent 31dad22 commit 33a3bb6

File tree

2 files changed

+17
-47
lines changed

2 files changed

+17
-47
lines changed

tc/core/halide2isl.cc

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "tc/core/halide2isl.h"
1717

1818
#include <algorithm>
19-
#include <numeric>
2019
#include <unordered_set>
2120

2221
#include "tc/core/constants.h"
@@ -499,68 +498,42 @@ std::vector<Reduction> findReductions(const Stmt& s) {
499498
}
500499
}
501500

502-
// Keep track of any reduction variable name for use in isValidReduction
501+
// Keep track of any reduction variable name for use in visit(Provide*)
503502
void visit(const Variable* op) {
504503
if (op->reduction_domain.defined()) {
505504
reductionVars.insert(op->name);
506505
}
507506
}
508507

509-
// Check that the given update node, together with the corresponding
510-
// init node form a proper reduction pair.
511-
// In particular, check that they share some outer For nodes and
512-
// that the variables of the additional For nodes surrounding
513-
// the update node are all reduction variables.
514-
bool isValidReductionUpdate(const Provide* op) {
515-
const auto& opInitVars = initVars[op->name];
516-
auto n = opInitVars.size();
517-
if (vars.size() <= n) {
518-
return false;
519-
}
520-
if (!std::equal(opInitVars.begin(), opInitVars.end(), vars.begin())) {
521-
return false;
522-
}
523-
for (auto i = vars.begin() + n; i != vars.end(); ++i) {
524-
if (reductionVars.count(*i) == 0) {
525-
return false;
526-
}
527-
}
528-
return true;
529-
}
530-
531508
// Keep track of the names of the outer For nodes.
532509
void visit(const For* op) {
533510
vars.push_back(op->name);
534511
IRVisitor::visit(op);
535512
vars.pop_back();
536513
}
537514

538-
// Check if the node is an init node, keeping track of it,
539-
// or an update node corresponding to init node that was found before,
540-
// updating the information about the reduction.
541-
// In particular, double-check that the pair are in the right
542-
// relative positions and collect the positions of the reduction
515+
// Check if the node is an update node with at least one reduction
516+
// dimension, keeping track of the information about the reduction.
517+
// In particular, collect the positions of the reduction
543518
// dimensions in the update statement domain.
544519
// Visit the children first to ensure that all relevant
545520
// reduction variables have been found first.
546521
void visit(const Provide* op) {
547522
IRVisitor::visit(op);
548-
if (isReductionInit(op)) {
549-
reductions[op->name].init = op;
550-
initVars[op->name] = vars;
551-
} else if (isReductionUpdate(op)) {
552-
if (isValidReductionUpdate(op)) {
523+
if (isReductionUpdate(op)) {
524+
std::vector<size_t> dims;
525+
auto n = vars.size();
526+
for (size_t i = 0; i < n; ++i) {
527+
if (reductionVars.count(vars[i]) != 0) {
528+
dims.emplace_back(i);
529+
}
530+
}
531+
if (dims.size() > 0) {
553532
auto& p = reductions[op->name];
554-
CHECK(p.init.defined())
555-
<< "Missing reduction init or (unsupported) multiple updates";
556533
CHECK(!p.update.defined())
557534
<< "Multiple reduction updates not yet implemented";
558535
p.update = op;
559-
auto n = initVars[op->name].size();
560-
p.dims.resize(vars.size() - n);
561-
std::iota(p.dims.begin(), p.dims.end(), n);
562-
} else {
563-
reductions.erase(op->name);
536+
p.dims = dims;
564537
}
565538
}
566539
}
@@ -570,8 +543,6 @@ std::vector<Reduction> findReductions(const Stmt& s) {
570543
std::unordered_set<std::string> reductionVars;
571544
// The names of the outer For nodes, outermost to innermost.
572545
std::vector<std::string> vars;
573-
// For each init node, the names of its outer For nodes.
574-
std::map<std::string, std::vector<std::string>> initVars;
575546
std::map<std::string, Reduction> reductions;
576547
} finder;
577548
s.accept(&finder);

tc/core/halide2isl.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,11 @@ ScheduleTreeAndAccesses makeScheduleTree(
9494
/// Enumerate all reductions in a statement, by looking for the
9595
/// ReductionInit and ReductionUpdate markers inserted during lowering
9696
/// (see tc2halide.h).
97-
/// Each reduction object stores a reference to the init and
98-
/// the update statement, although the init statement is probably
99-
/// not strictly needed, and a list of reduction dimensions
97+
/// Each reduction object stores a reference to
98+
/// the update statement, and a list of reduction dimensions
10099
/// in the domain of the update statement.
101100
struct Reduction {
102-
Halide::Internal::Stmt init, update;
101+
Halide::Internal::Stmt update;
103102
std::vector<size_t> dims;
104103
};
105104
std::vector<Reduction> findReductions(const Halide::Internal::Stmt& s);

0 commit comments

Comments
 (0)