Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit b81795e

Browse files
author
Sven Verdoolaege
committed
makeScheduleTree: also extract polyhedral representation of reductions
Reductions are currently represented by lists of reduction dimensions in the domains of the reduction update statements. The detection of reductions in the schedule tree, based on this information, may work out in practice, but it is difficult to reason about and is technically incorrect. This detection will be replaced by one that is based on a polyhedral representation. This commit introduces this polyhedral representation. In particular, each individual reduction (a set of updates to a given tensor element) is represented by an element in a reduction space, along with a mapping from the update statement instances that contribute to that reduction. The obvious choice for such a reduction space would be the space of the tensor involved in the reduction. However, there may in theory be multiple reductions on the same tensor element separated by some other statement, even though this is currently impossible inside TC. Still, since the reductions are extracted from Halide, it's best to only use information available in the Halide representation and to not assume an identity matching between reductions and tensor elements. It is not immediately obvious what would be the best representation for reductions. This commit uses an isl::union_map representation, but an isl::union_pw_multi_aff would also work. This commit only extracts a polyhedral representation of reductions, temporarily resulting in a duplicate representation of reductions. After all uses of the AST representation of reductions have been replaced by uses of the polyhedral representation, the AST representation will be removed.
1 parent 381dc6c commit b81795e

File tree

4 files changed

+119
-2
lines changed

4 files changed

+119
-2
lines changed

tc/core/constants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace polyhedral {
2222
// General constants to avoid hardcoding
2323
//
2424
constexpr auto kStatementLabel = "S_";
25+
constexpr auto kReductionLabel = "R_";
2526

2627
constexpr auto kAstNodeIdPrefix = "__node_";
2728

tc/core/halide2isl.cc

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,82 @@ bool isReductionUpdate(const Provide* op) {
342342
}
343343
}
344344

345+
/* Construct a multi-dimensional affine function mapping
346+
* the given iteration domain
347+
* to the outer loop iterators that do not appear in "skip".
348+
* "id" is used as the identifier of the target space.
349+
* For each of these outer loop iterators, an affine function
350+
* is first constructed in terms of the parameter space
351+
* active at the point where the iteration domain was created and
352+
* then converted into an expression on that iteration domain
353+
* by reinterpreting the parameters as input dimensions.
354+
*/
355+
static isl::multi_aff mapToOther(
356+
const IterationDomain& iterationDomain,
357+
std::unordered_set<std::string> skip,
358+
isl::id id) {
359+
auto ctx = iterationDomain.tuple.get_ctx();
360+
auto list = isl::aff_list(ctx, 0);
361+
for (auto id : iterationDomain.tuple.get_id_list()) {
362+
if (skip.count(id.get_name()) == 1) {
363+
continue;
364+
}
365+
auto aff = isl::aff::param_on_domain_space(iterationDomain.paramSpace, id);
366+
aff = aff.unbind_params_insert_domain(iterationDomain.tuple);
367+
list = list.add(aff);
368+
}
369+
auto domainSpace = iterationDomain.tuple.get_space();
370+
auto space = domainSpace.params().named_set_from_params_id(id, list.size());
371+
space = domainSpace.product(space).unwrap();
372+
return isl::multi_aff(space, list);
373+
}
374+
375+
/*
376+
* If "op" performs a reduction, then return a mapping from
377+
* the statement instances to the individual reductions.
378+
* Otherwise, return an empty isl::union_map.
379+
*
380+
* "op" is considered to be a reduction if it has been marked
381+
* as performing a reduction and if more than one statement instance
382+
* is involved in the individual reductions.
383+
*
384+
* The space of the reduction has a name of the form R_<op->name>_<index>.
385+
* Each reduction is indexed by the outer loop variables
386+
* that are not marked as reduction variables.
387+
* Since the loop variables that iterate over output tensor elements
388+
* are never marked as reduction variables, this means in particular
389+
* that all statement instances that belong to the same reduction
390+
* write to the same tensor element.
391+
*/
392+
isl::union_map extractReduction(
393+
const IterationDomain& iterationDomain,
394+
const Provide* op,
395+
size_t index) {
396+
class FindReductionVars : public IRVisitor {
397+
void visit(const Variable* op) {
398+
if (op->reduction_domain.defined()) {
399+
reductionVars.insert(op->name);
400+
}
401+
}
402+
403+
public:
404+
// The variables that are known to be reduction variables.
405+
std::unordered_set<std::string> reductionVars;
406+
} finder;
407+
408+
if (!isReductionUpdate(op)) {
409+
return isl::union_map::empty(iterationDomain.tuple.get_space().params());
410+
}
411+
op->accept(&finder);
412+
if (finder.reductionVars.size() == 0) {
413+
return isl::union_map::empty(iterationDomain.tuple.get_space().params());
414+
}
415+
auto ctx = iterationDomain.tuple.get_ctx();
416+
isl::id id(ctx, kReductionLabel + op->name + "_" + std::to_string(index));
417+
auto reduction = mapToOther(iterationDomain, finder.reductionVars, id);
418+
return isl::union_map(isl::map(reduction));
419+
}
420+
345421
/*
346422
* Take a parametric expression "f" and convert it into an expression
347423
* on the iteration domains in "domain" by reinterpreting the parameters
@@ -369,7 +445,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
369445
* from outermost to innermost.
370446
* Return the schedule corresponding to the subtree at "s".
371447
*
372-
* "body" collects the accesses found along the way.
448+
* "body" collects the accesses and reductions found along the way.
373449
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
374450
* (for the writes) to the corresponding tag in the access relations.
375451
* "statements" collects the mapping from instance set tuple identifiers
@@ -460,9 +536,13 @@ isl::schedule makeScheduleTreeHelper(
460536
isl::union_map newReads, newWrites;
461537
std::tie(newReads, newWrites) =
462538
extractAccesses(iterationDomain, op, accesses);
539+
// A tensor may be involved in multiple reductions.
540+
// Use the statement index to differentiate between them.
541+
auto newReduction = extractReduction(iterationDomain, op, stmtIndex);
463542

464543
body->reads = body->reads.unite(newReads);
465544
body->writes = body->writes.unite(newWrites);
545+
body->reductions = body->reductions.unite(newReduction);
466546

467547
} else {
468548
LOG(FATAL) << "Unhandled Halide stmt: " << s;

tc/core/polyhedral/body.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace polyhedral {
2626
std::ostream& operator<<(std::ostream& os, const Body& body) {
2727
os << "reads: " << body.reads << "\n";
2828
os << "writes: " << body.writes << "\n";
29+
os << "reductions: " << body.reductions << "\n";
2930

3031
return os;
3132
}

tc/core/polyhedral/body.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,54 @@ namespace polyhedral {
2626
struct Body {
2727
Body() = default;
2828
Body(isl::space paramSpace) {
29-
writes = reads = isl::union_map::empty(paramSpace);
29+
reductions = writes = reads = isl::union_map::empty(paramSpace);
3030
}
3131

3232
// Specialize to the given context.
3333
void specialize(isl::set context) {
3434
reads = reads.intersect_params(context);
3535
writes = writes.intersect_params(context);
36+
reductions = reductions.intersect_params(context);
3637
}
3738

3839
// Union maps describing the reads and writes done. Uses the ids in
3940
// the schedule tree to denote the containing Stmt, and tags each
4041
// access with a unique reference id of the form __tc_ref_N.
4142
isl::union_map reads, writes;
43+
44+
// A function on reduction update statement instances that partitions them
45+
// into individual reductions, where each reduction consists of
46+
// associative updates to the same tensor element.
47+
// Since each reduction involves a single tensor element,
48+
// the partition of statement instances based
49+
// on the reductions forms a refinement of the partition based
50+
// on the element modified by the statement.
51+
// That is, if W is the write access relation and R is the reduction function,
52+
// then, in iscc notation, (W.W^-1) >= (R.R^-1).
53+
// In theory, it is possible for the inclusion to be strict, i.e.,
54+
// for (W.W^-1) > (R.R^-1) to hold. For example, for a statement
55+
//
56+
// A += T(i)
57+
//
58+
// with T of size 4, the write access relation is
59+
//
60+
// { S[i] -> A[] : 0 <= i < 4 }
61+
//
62+
// and the reduction relation could in theory be something like
63+
//
64+
// { S[i] -> R[i] : 0 <= i < 4 }
65+
//
66+
// or even
67+
//
68+
// { S[i] -> R1[i] : 0 <= i < 2; S[i] -> R2[i - 2] : 3 <= i < 4 }
69+
//
70+
// In practice, the reduction map is usually equal to
71+
// the write access relation on reduction update statements,
72+
// with different target spaces.
73+
// That is, in the example above, it would just be
74+
//
75+
// { S[i] -> R[] : 0 <= i < 4 }
76+
isl::union_map reductions;
4277
};
4378

4479
std::ostream& operator<<(std::ostream& os, const Body& body);

0 commit comments

Comments
 (0)