Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit d384ed3

Browse files
Merge pull request #520 from facebookresearch/pr/reduction-detection
use polyhedral reduction description to detect reductions
2 parents 83f59f3 + 8a1476f commit d384ed3

20 files changed

+301
-214
lines changed

tc/core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_library(
2121
halide2isl.cc
2222
halide_utils.cc
2323

24+
polyhedral/body.cc
2425
polyhedral/codegen.cc
2526
polyhedral/memory_promotion.cc
2627
polyhedral/reduction_matcher.cc

tc/core/constants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace polyhedral {
2222
// General constants to avoid hardcoding
2323
//
2424
constexpr auto kStatementLabel = "S_";
25+
constexpr auto kReductionLabel = "R_";
2526

2627
constexpr auto kAstNodeIdPrefix = "__node_";
2728

tc/core/halide2isl.cc

Lines changed: 101 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "tc/core/check.h"
2222
#include "tc/core/constants.h"
23+
#include "tc/core/polyhedral/body.h"
2324
#include "tc/core/polyhedral/schedule_isl_conversion.h"
2425
#include "tc/core/polyhedral/schedule_transforms.h"
2526
#include "tc/core/polyhedral/schedule_tree.h"
@@ -333,6 +334,90 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
333334
return {finder.reads, finder.writes};
334335
}
335336

337+
bool isReductionUpdate(const Provide* op) {
338+
if (const Call* call = op->values[0].as<Call>()) {
339+
return call->is_intrinsic(tc2halide::kReductionUpdate);
340+
} else {
341+
return false;
342+
}
343+
}
344+
345+
/* Construct a multi-dimensional affine function mapping
346+
* the given iteration domain
347+
* to the outer loop iterators that do not appear in "skip".
348+
* "id" is used as the identifier of the target space.
349+
* For each of these outer loop iterators, an affine function
350+
* is first constructed in terms of the parameter space
351+
* active at the point where the iteration domain was created and
352+
* then converted into an expression on that iteration domain
353+
* by reinterpreting the parameters as input dimensions.
354+
*/
355+
static isl::multi_aff mapToOther(
356+
const IterationDomain& iterationDomain,
357+
std::unordered_set<std::string> skip,
358+
isl::id id) {
359+
auto ctx = iterationDomain.tuple.get_ctx();
360+
auto list = isl::aff_list(ctx, 0);
361+
for (auto id : iterationDomain.tuple.get_id_list()) {
362+
if (skip.count(id.get_name()) == 1) {
363+
continue;
364+
}
365+
auto aff = isl::aff::param_on_domain_space(iterationDomain.paramSpace, id);
366+
aff = aff.unbind_params_insert_domain(iterationDomain.tuple);
367+
list = list.add(aff);
368+
}
369+
auto domainSpace = iterationDomain.tuple.get_space();
370+
auto space = domainSpace.params().named_set_from_params_id(id, list.size());
371+
space = domainSpace.product(space).unwrap();
372+
return isl::multi_aff(space, list);
373+
}
374+
375+
/*
376+
* If "op" performs a reduction, then return a mapping from
377+
* the statement instances to the individual reductions.
378+
* Otherwise, return an empty isl::union_map.
379+
*
380+
* "op" is considered to be a reduction if it has been marked
381+
* as performing a reduction and if more than one statement instance
382+
* is involved in the individual reductions.
383+
*
384+
* The space of the reduction has a name of the form R_<op->name>_<index>.
385+
* Each reduction is indexed by the outer loop variables
386+
* that are not marked as reduction variables.
387+
* Since the loop variables that iterate over output tensor elements
388+
* are never marked as reduction variables, this means in particular
389+
* that all statement instances that belong to the same reduction
390+
* write to the same tensor element.
391+
*/
392+
isl::union_map extractReduction(
393+
const IterationDomain& iterationDomain,
394+
const Provide* op,
395+
size_t index) {
396+
class FindReductionVars : public IRVisitor {
397+
void visit(const Variable* op) {
398+
if (op->reduction_domain.defined()) {
399+
reductionVars.insert(op->name);
400+
}
401+
}
402+
403+
public:
404+
// The variables that are known to be reduction variables.
405+
std::unordered_set<std::string> reductionVars;
406+
} finder;
407+
408+
if (!isReductionUpdate(op)) {
409+
return isl::union_map::empty(iterationDomain.tuple.get_space().params());
410+
}
411+
op->accept(&finder);
412+
if (finder.reductionVars.size() == 0) {
413+
return isl::union_map::empty(iterationDomain.tuple.get_space().params());
414+
}
415+
auto ctx = iterationDomain.tuple.get_ctx();
416+
isl::id id(ctx, kReductionLabel + op->name + "_" + std::to_string(index));
417+
auto reduction = mapToOther(iterationDomain, finder.reductionVars, id);
418+
return isl::union_map(isl::map(reduction));
419+
}
420+
336421
/*
337422
* Take a parametric expression "f" and convert it into an expression
338423
* on the iteration domains in "domain" by reinterpreting the parameters
@@ -360,7 +445,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
360445
* from outermost to innermost.
361446
* Return the schedule corresponding to the subtree at "s".
362447
*
363-
* "reads" and "writes" collect the accesses found along the way.
448+
* "body" collects the accesses and reductions found along the way.
364449
* "accesses" collects the mapping from Call (for the reads) and Provide nodes
365450
* (for the writes) to the corresponding tag in the access relations.
366451
* "statements" collects the mapping from instance set tuple identifiers
@@ -372,8 +457,7 @@ isl::schedule makeScheduleTreeHelper(
372457
const Stmt& s,
373458
isl::set set,
374459
isl::id_list outer,
375-
isl::union_map* reads,
376-
isl::union_map* writes,
460+
Body* body,
377461
AccessMap* accesses,
378462
StatementMap* statements,
379463
IterationDomainMap* domains) {
@@ -406,19 +490,19 @@ isl::schedule makeScheduleTreeHelper(
406490

407491
// Recursively descend.
408492
auto outerNext = outer.add(isl::id(set.get_ctx(), op->name));
409-
auto body = makeScheduleTreeHelper(
410-
op->body, set, outerNext, reads, writes, accesses, statements, domains);
493+
auto bodySchedule = makeScheduleTreeHelper(
494+
op->body, set, outerNext, body, accesses, statements, domains);
411495

412496
// Create an affine function that defines an ordering for all
413497
// the statements in the body of this loop over the values of
414498
// this loop. Start from a parametric expression equal
415499
// to the current loop iterator and then convert it to
416500
// a function on the statements in the domain of the body schedule.
417501
auto aff = isl::aff::param_on_domain_space(space, id);
418-
auto domain = body.get_domain();
502+
auto domain = bodySchedule.get_domain();
419503
auto mupa = isl::multi_union_pw_aff(onDomains(aff, domain, *domains));
420504

421-
schedule = body.insert_partial_schedule(mupa);
505+
schedule = bodySchedule.insert_partial_schedule(mupa);
422506
} else if (auto op = s.as<Halide::Internal::Block>()) {
423507
std::vector<Stmt> stmts;
424508
stmts.push_back(op->first);
@@ -429,7 +513,7 @@ isl::schedule makeScheduleTreeHelper(
429513
std::vector<isl::schedule> schedules;
430514
for (Stmt stmt : stmts) {
431515
schedules.push_back(makeScheduleTreeHelper(
432-
stmt, set, outer, reads, writes, accesses, statements, domains));
516+
stmt, set, outer, body, accesses, statements, domains));
433517
}
434518
schedule = schedules[0].sequence(schedules[1]);
435519

@@ -452,9 +536,13 @@ isl::schedule makeScheduleTreeHelper(
452536
isl::union_map newReads, newWrites;
453537
std::tie(newReads, newWrites) =
454538
extractAccesses(iterationDomain, op, accesses);
539+
// A tensor may be involved in multiple reductions.
540+
// Use the statement index to differentiate between them.
541+
auto newReduction = extractReduction(iterationDomain, op, stmtIndex);
455542

456-
*reads = reads->unite(newReads);
457-
*writes = writes->unite(newWrites);
543+
body->reads = body->reads.unite(newReads);
544+
body->writes = body->writes.unite(newWrites);
545+
body->reductions = body->reductions.unite(newReduction);
458546

459547
} else {
460548
LOG(FATAL) << "Unhandled Halide stmt: " << s;
@@ -465,87 +553,24 @@ isl::schedule makeScheduleTreeHelper(
465553
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
466554
ScheduleTreeAndAccesses result;
467555

468-
result.writes = result.reads = isl::union_map::empty(paramSpace);
556+
Body body(paramSpace);
469557

470558
// Walk the IR building a schedule tree
471559
isl::id_list outer(paramSpace.get_ctx(), 0);
472560
auto schedule = makeScheduleTreeHelper(
473561
s,
474562
isl::set::universe(paramSpace),
475563
outer,
476-
&result.reads,
477-
&result.writes,
564+
&body,
478565
&result.accesses,
479566
&result.statements,
480567
&result.domains);
481568

569+
result.body = body;
482570
result.tree = fromIslSchedule(schedule);
483571

484572
return result;
485573
}
486574

487-
std::vector<Reduction> findReductions(const Stmt& s) {
488-
class FindReductions : public IRVisitor {
489-
using IRVisitor::visit;
490-
491-
bool isReductionUpdate(const Provide* op) {
492-
if (const Call* call = op->values[0].as<Call>()) {
493-
return call->is_intrinsic(tc2halide::kReductionUpdate);
494-
} else {
495-
return false;
496-
}
497-
}
498-
499-
// Keep track of any reduction variable name for use in visit(Provide*)
500-
void visit(const Variable* op) {
501-
if (op->reduction_domain.defined()) {
502-
reductionVars.insert(op->name);
503-
}
504-
}
505-
506-
// Keep track of the names of the outer For nodes.
507-
void visit(const For* op) {
508-
vars.push_back(op->name);
509-
IRVisitor::visit(op);
510-
vars.pop_back();
511-
}
512-
513-
// Check if the node is an update node with at least one reduction
514-
// dimension, keeping track of the information about the reduction.
515-
// In particular, collect the positions of the reduction
516-
// dimensions in the update statement domain.
517-
// Visit the children first to ensure that all relevant
518-
// reduction variables have been found first.
519-
void visit(const Provide* op) {
520-
IRVisitor::visit(op);
521-
if (isReductionUpdate(op)) {
522-
std::vector<size_t> dims;
523-
auto n = vars.size();
524-
for (size_t i = 0; i < n; ++i) {
525-
if (reductionVars.count(vars[i]) != 0) {
526-
dims.emplace_back(i);
527-
}
528-
}
529-
if (dims.size() > 0) {
530-
Reduction p;
531-
p.update = op;
532-
p.dims = dims;
533-
reductions.emplace_back(p);
534-
}
535-
}
536-
}
537-
538-
public:
539-
// The variables that are known to be reduction variables.
540-
std::unordered_set<std::string> reductionVars;
541-
// The names of the outer For nodes, outermost to innermost.
542-
std::vector<std::string> vars;
543-
std::vector<Reduction> reductions;
544-
} finder;
545-
s.accept(&finder);
546-
547-
return finder.reductions;
548-
}
549-
550575
} // namespace halide2isl
551576
} // namespace tc

tc/core/halide2isl.h

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <Halide.h>
2424

25+
#include "tc/core/polyhedral/body.h"
2526
#include "tc/core/polyhedral/schedule_tree.h"
2627
#include "tc/core/tc2halide.h"
2728
#include "tc/external/isl.h"
@@ -82,10 +83,8 @@ struct ScheduleTreeAndAccesses {
8283
/// for each leaf node is captured below.
8384
tc::polyhedral::ScheduleTreeUPtr tree;
8485

85-
/// Union maps describing the reads and writes done. Uses the ids in
86-
/// the schedule tree to denote the containing Stmt, and tags each
87-
/// access with a unique reference id of the form __tc_ref_N.
88-
isl::union_map reads, writes;
86+
/// Information extracted from the bodies of the statements.
87+
tc::polyhedral::Body body;
8988

9089
/// The correspondence between from Call and Provide nodes and the
9190
/// reference ids in the reads and writes maps.
@@ -106,17 +105,5 @@ ScheduleTreeAndAccesses makeScheduleTree(
106105
isl::space paramSpace,
107106
const Halide::Internal::Stmt& s);
108107

109-
/// Enumerate all reductions in a statement, by looking for the
110-
/// ReductionUpdate markers inserted during lowering
111-
/// (see tc2halide.h).
112-
/// Each reduction object stores a reference to
113-
/// the update statement, and a list of reduction dimensions
114-
/// in the domain of the update statement.
115-
struct Reduction {
116-
Halide::Internal::Stmt update;
117-
std::vector<size_t> dims;
118-
};
119-
std::vector<Reduction> findReductions(const Halide::Internal::Stmt& s);
120-
121108
} // namespace halide2isl
122109
} // namespace tc

tc/core/polyhedral/body.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/**
2+
* Copyright (c) 2018, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "tc/core/polyhedral/body.h"
18+
19+
#include <iostream>
20+
21+
#include "tc/external/isl.h"
22+
23+
namespace tc {
24+
namespace polyhedral {
25+
26+
std::ostream& operator<<(std::ostream& os, const Body& body) {
27+
os << "reads: " << body.reads << "\n";
28+
os << "writes: " << body.writes << "\n";
29+
os << "reductions: " << body.reductions << "\n";
30+
31+
return os;
32+
}
33+
34+
} // namespace polyhedral
35+
} // namespace tc

0 commit comments

Comments
 (0)