From 278f174135cb5f99e7ce5a32d7bfd938bfdeb400 Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Tue, 10 Apr 2018 21:14:44 +0200 Subject: [PATCH 1/6] add more Halide typedefs --- tc/core/libraries.h | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tc/core/libraries.h b/tc/core/libraries.h index 9d8386cd0..87cc0ee5f 100644 --- a/tc/core/libraries.h +++ b/tc/core/libraries.h @@ -32,8 +32,14 @@ namespace c { constexpr auto types = R"C( // Halide type handling -typedef int int32; -typedef long int64; +typedef signed char int8; +typedef unsigned char uint8; +typedef signed short int16; +typedef unsigned short uint16; +typedef signed int int32; +typedef unsigned int uint32; +typedef signed long int64; +typedef unsigned long uint64; typedef float float32; typedef double float64; )C"; @@ -81,16 +87,16 @@ float fmodf ( float x, float y ); //float frexpf ( float x, int* nptr ); float hypotf ( float x, float y ); //int ilogbf ( float x ); -//__RETURN_TYPE isfinite ( float a ); -//__RETURN_TYPE isinf ( float a ); -//__RETURN_TYPE isnan ( float a ); +//__RETURN_TYPE isfinite ( float a ); +//__RETURN_TYPE isinf ( float a ); +//__RETURN_TYPE isnan ( float a ); float j0f ( float x ); float j1f ( float x ); //float jnf ( int n, float x ); //float ldexpf ( float x, int exp ); float lgammaf ( float x ); -//long long int llrintf ( float x ); -//long long int llroundf ( float x ); +//long long int llrintf ( float x ); +//long long int llroundf ( float x ); float log10f ( float x ); float log1pf ( float x ); float log2f ( float x ); @@ -120,7 +126,7 @@ float roundf ( float x ); float rsqrtf ( float x ); //float scalblnf ( float x, long int n ); //float scalbnf ( float x, int n ); -//__RETURN_TYPE signbit ( float a ); +//__RETURN_TYPE signbit ( float a ); //void sincosf ( float x, float* sptr, float* cptr ); //void sincospif ( float x, float* sptr, float* cptr ); float sinf ( float x ); From 063597c36187ce0541d038534452b149a5e1e9b9 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 5 Apr 2018 11:16:58 -0700 Subject: [PATCH 2/6] Allow computed expressions on the left-hand-side --- tc/core/tc2halide.cc | 101 ++++++++++++++++++++--------- tc/lang/parser.h | 11 +++- tc/lang/sema.h | 13 ++-- tc/lang/tc_format.cc | 5 +- tc/lang/tree_views.h | 4 +- test/cuda/test_execution_engine.cc | 19 ++++++ test/cuda/test_tc_mapper.cc | 43 ++++++++++++ 7 files changed, 155 insertions(+), 41 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 7286c9509..9119190ff 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -216,7 +216,7 @@ Expr translateExpr( } } -vector unboundVariables(const vector& lhs, Expr rhs) { +vector unboundVariables(const vector& lhs, Expr rhs) { class FindUnboundVariables : public IRVisitor { using IRVisitor::visit; @@ -241,14 +241,19 @@ vector unboundVariables(const vector& lhs, Expr rhs) { set visited; public: - FindUnboundVariables(const vector& lhs) { - for (auto v : lhs) { - bound.push(v.name()); + FindUnboundVariables(const vector& lhs) { + for (auto e : lhs) { + if (const Variable* v = e.as()) { + bound.push(v->name); + } } } vector result; } finder(lhs); rhs.accept(&finder); + for (auto e : lhs) { + e.accept(&finder); + } return finder.result; } @@ -507,22 +512,31 @@ void translateComprehension( f = Function(c.ident().name()); (*funcs)[c.ident().name()] = f; } + + // we currently inline all of the let bindings generated in where clauses + // in the future we may consider using Halide Let bindings when they + // are supported later + map lets; + // Function is the internal Halide IR type for a pipeline // stage. Func is the front-end class that wraps it. Here it's // convenient to use both. Func func(f); - vector lhs; - vector lhs_as_exprs; - for (lang::Ident id : c.indices()) { - lhs.push_back(Var(id.name())); - lhs_as_exprs.push_back(lhs.back()); + vector lhs; + vector lhs_vars; + bool total_definition = true; + for (lang::TreeRef idx : c.indices()) { + Expr e = translateExpr(idx, params, *funcs, lets); + if (const Variable* op = e.as()) { + lhs_vars.push_back(Var(op->name)); + } else { + total_definition = false; + lhs_vars.push_back(Var()); + } + lhs.push_back(e); } - // we currently inline all of the let bindings generated in where clauses - // in the future we may consider using Halide Let bindings when they - // are supported later - map lets; for (auto wc : c.whereClauses()) { if (wc->kind() == lang::TK_LET) { auto let = lang::Let(wc); @@ -546,9 +560,8 @@ void translateComprehension( auto setupIdentity = [&](const Expr& identity, bool zero) { if (!f.has_pure_definition()) { added_implicit_initialization = true; - func(lhs) = (zero) ? identity - : undef(rhs.type()); // undef causes the original value - // to remain in input arrays + // undef causes the original value to remain in input arrays + func(lhs_vars) = (zero) ? identity : undef(rhs.type()); } }; @@ -587,6 +600,9 @@ void translateComprehension( break; case '=': + if (!total_definition) { + setupIdentity(rhs, false); + } break; default: throw lang::ErrorReport(c) << "Unimplemented reduction " @@ -618,9 +634,10 @@ void translateComprehension( for (auto& exp : all_exprs) { exp = bindParams.mutate(exp); } - - // TODO: When the LHS incorporates general expressions we'll need to - // bind params there too. + for (auto& e : lhs) { + e = bindParams.mutate(e); + all_exprs.push_back(e); + } // Do forward bounds inference -- construct an expression that says // this expression never reads out of bounds on its inputs, and @@ -660,19 +677,34 @@ void translateComprehension( // (e.g. an in-place stencil)?. The .bound directive will use the // bounds of the last stage for all stages. - // Does a tensor have a single bound, or can its bounds shrink over - // time? Solve for a single bound for now. + // Set the bounds to be the union of the boxes written to by every + // comprehension touching the tensor. + for (size_t i = 0; i < lhs.size(); i++) { + Expr e = lhs[i]; + if (const Variable* v = e.as()) { + if (!solution.contains(v->name)) { + throw lang::ErrorReport(c) + << "Free variable " << v + << " was not solved in range inference. May not be used right-hand side"; + } + } - for (Var v : lhs) { - if (!solution.contains(v.name())) { - throw lang::ErrorReport(c) - << "Free variable " << v - << " was not solved in range inference. May not be used right-hand side"; + Interval in = bounds_of_expr_in_scope(e, solution); + if (!in.is_bounded()) { + throw lang::ErrorReport(c.indices()[i]) + << "Left-hand side expression is unbounded"; } - // TODO: We're enforcing a single bound across all comprehensions - // for now. We should really check later ones are equal to earlier - // ones instead of just clobbering. - (*bounds)[f][v.name()] = solution.get(v.name()); + in.min = cast(in.min); + in.max = cast(in.max); + + map& b = (*bounds)[f]; + string dim_name = f.dimensions() ? f.args()[i] : lhs_vars[i].name(); + auto old = b.find(dim_name); + if (old != b.end()) { + // Take the union with any existing bounds + in.include(old->second); + } + b[dim_name] = in; } // Free variables that appear on the rhs but not the lhs are @@ -703,6 +735,9 @@ void translateComprehension( for (auto v : unbound) { Expr rv = Variable::make(Int(32), v->name, domain); rhs = substitute(v->name, rv, rhs); + for (Expr& e : lhs) { + e = substitute(v->name, rv, e); + } } rdom = RDom(domain); } @@ -718,9 +753,12 @@ void translateComprehension( } } while (!lhs.empty()) { - loop_nest.push_back(lhs.back()); + if (const Variable* v = lhs.back().as()) { + loop_nest.push_back(Var(v->name)); + } lhs.pop_back(); } + stage.reorder(loop_nest); if (added_implicit_initialization) { // Also reorder reduction initializations to the TC convention @@ -734,7 +772,6 @@ void translateComprehension( } func.compute_root(); - stage.reorder(loop_nest); } HalideComponents translateDef(const lang::Def& def, bool throwWarnings) { diff --git a/tc/lang/parser.h b/tc/lang/parser.h index 4083771f7..ceaffd32f 100644 --- a/tc/lang/parser.h +++ b/tc/lang/parser.h @@ -151,6 +151,15 @@ struct Parser { TreeRef parseExpList() { return parseList('(', ',', ')', [&](int i) { return parseExp(); }); } + TreeRef parseOptionalExpList() { + TreeRef list = nullptr; + if (L.cur().kind == '(') { + list = parseExpList(); + } else { + list = List::create(L.cur().range, {}); + } + return list; + } TreeRef parseIdentList() { return parseList('(', ',', ')', [&](int i) { return parseIdent(); }); } @@ -226,7 +235,7 @@ struct Parser { } TreeRef parseStmt() { auto ident = parseIdent(); - TreeRef list = parseOptionalIdentList(); + TreeRef list = parseOptionalExpList(); auto assign = parseAssignment(); auto rhs = parseExp(); TreeRef equivalent_statement = parseEquivalent(); diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 406a82711..5cbcb647e 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -442,9 +442,11 @@ struct Sema { // register index variables (non-reductions) for (const auto& index : stmt.indices()) { - std::string idx = index.name(); - auto typ = indexType(index); - insert(index_env, index, typ, true); + if (index->kind() == TK_IDENT) { + std::string idx = Ident(index).name(); + auto typ = indexType(index); + insert(index_env, Ident(index), typ, true); + } } // make dimension variables for each dimension of the output tensor @@ -465,6 +467,9 @@ struct Sema { auto where_clauses_ = stmt.whereClauses().map( [&](TreeRef rc) { return checkWhereClause(rc); }); + auto indices_ = + stmt.indices().map([&](TreeRef idx) { return checkExp(idx, true); }); + TreeRef rhs_ = checkExp(stmt.rhs(), true); TreeRef scalar_type = typeOfExpr(rhs_); @@ -525,7 +530,7 @@ struct Sema { TreeRef result = Comprehension::create( stmt.range(), stmt.ident(), - stmt.indices(), + indices_, stmt.assignment(), rhs_, where_clauses_, diff --git a/tc/lang/tc_format.cc b/tc/lang/tc_format.cc index 8f1fbe8f1..55457d55a 100644 --- a/tc/lang/tc_format.cc +++ b/tc/lang/tc_format.cc @@ -60,8 +60,9 @@ std::ostream& operator<<(std::ostream& s, const Param& p) { } std::ostream& operator<<(std::ostream& s, const Comprehension& comp) { - s << comp.ident() << "(" << comp.indices() << ") " - << kindToToken(comp.assignment()->kind()) << " "; + s << comp.ident() << "("; + showList(s, comp.indices(), showExpr); + s << ") " << kindToToken(comp.assignment()->kind()) << " "; showExpr(s, comp.rhs()); if (!comp.whereClauses().empty()) throw std::runtime_error("Printing of where clauses is not supported yet"); diff --git a/tc/lang/tree_views.h b/tc/lang/tree_views.h index 1e26b8437..099b4c458 100644 --- a/tc/lang/tree_views.h +++ b/tc/lang/tree_views.h @@ -386,8 +386,8 @@ struct Comprehension : public TreeView { Ident ident() const { return Ident(subtree(0)); } - ListView indices() const { - return ListView(subtree(1)); + ListView indices() const { + return ListView(subtree(1)); } // kind == '=', TK_PLUS_EQ, TK_PLUS_EQ_B, etc. TreeRef assignment() const { diff --git a/test/cuda/test_execution_engine.cc b/test/cuda/test_execution_engine.cc index cd508ae8c..3ca4558eb 100644 --- a/test/cuda/test_execution_engine.cc +++ b/test/cuda/test_execution_engine.cc @@ -145,6 +145,25 @@ def concat(float(M, N) A, float(M, N) B) -> (O1) { outputs); } +TEST_F(ATenCompilationUnitTest, Concat2) { + at::Tensor a = at::CUDA(at::kFloat).rand({32, 16}); + at::Tensor b = at::CUDA(at::kFloat).rand({32, 16}); + std::vector inputs = {a, b}; + std::vector outputs; + + Check( + R"( +def concat(float(M, N) A, float(M, N) B) -> (O1) { + O1(n, 0, m) = A(m, n) + O1(n, 1, m) = B(m, n) +} + )", + "concat", + tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), + inputs, + outputs); +} + TEST_F(ATenCompilationUnitTest, Indexing) { at::Tensor a = at::CUDA(at::kFloat).rand({3, 4}); at::Tensor b = at::CUDA(at::kInt).ones({2}); diff --git a/test/cuda/test_tc_mapper.cc b/test/cuda/test_tc_mapper.cc index d005e4db3..6a81e9055 100644 --- a/test/cuda/test_tc_mapper.cc +++ b/test/cuda/test_tc_mapper.cc @@ -352,6 +352,49 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) { checkFun); } +TEST_F(TcCudaMapperTest, Histogram) { + const int N = 17, M = 82; + at::Tensor I = + at::CUDA(at::kFloat).rand({N, M}).mul_(256).floor_().toType(at::kByte); + std::vector inputs = {I}; + std::vector outputs; + + static constexpr auto TC = R"TC( +def fun(uint8(N, M) I) -> (O) { + O(I(i, j)) +=! 1 +} +)TC"; + + auto checkFun = [=](const std::vector& inputs, + std::vector& outputs) { + at::Tensor I = inputs[0].toBackend(at::kCPU); + at::Tensor O = outputs[0].toBackend(at::kCPU); + auto IAccessor = I.accessor(); + auto OAccessor = O.accessor(); + int sum = 0; + for (int i = 0; i < 256; i++) { + sum += OAccessor[i]; + } + CHECK_EQ(sum, N * M); + + for (int i = 0; i < N; i++) { + for (int j = 0; j < M; j++) { + OAccessor[IAccessor[i][j]]--; + } + } + + for (int i = 0; i < 256; i++) { + CHECK_EQ(OAccessor[i], 0); + } + }; + Check( + TC, + "fun", + tc::CudaMappingOptions::makeNaiveCudaMappingOptions(), + inputs, + checkFun); +} + /////////////////////////////////////////////////////////////////////////////// // SpatialBatchNormalization /////////////////////////////////////////////////////////////////////////////// From 27c37af2b50fd1cf1a3aacaa07787723c7b82cba Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Tue, 10 Apr 2018 16:53:34 +0200 Subject: [PATCH 3/6] Scop: separate may and must writes Originally, TC did not support indirection on the LHS. From the polyhedral representation point of view, all writes were thus "must" writes, that is the tensor elements were necessarily overwritten. With indirection, it is impossible to decide statically which elements will be overwritten and which ones won't. Therefore, we need to separately consider "may" writes, i.e. the elements that may or may not be written depending on some dynamic values, and "must" writes. Introduce may/must write separation at the Scop level. Treat all writes as may writes, which is safe may lead to inefficient schedules. --- tc/core/polyhedral/memory_promotion.cc | 2 +- tc/core/polyhedral/scop.cc | 8 +++++--- tc/core/polyhedral/scop.h | 14 +++++++++++--- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/tc/core/polyhedral/memory_promotion.cc b/tc/core/polyhedral/memory_promotion.cc index f12a60b89..5896dbace 100644 --- a/tc/core/polyhedral/memory_promotion.cc +++ b/tc/core/polyhedral/memory_promotion.cc @@ -348,7 +348,7 @@ TensorGroups TensorReferenceGroup::accessedBySubtree( auto schedule = partialSchedule(scop.scheduleRoot(), tree); addSingletonReferenceGroups( - tensorGroups, scop.writes, domain, schedule, AccessType::Write); + tensorGroups, scop.mayWrites, domain, schedule, AccessType::Write); addSingletonReferenceGroups( tensorGroups, scop.reads, domain, schedule, AccessType::Read); diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index 54a968556..09b377261 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -61,7 +61,8 @@ ScopUPtr Scop::makeScop( auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt); scop->scheduleTreeUPtr = std::move(tree.tree); scop->reads = tree.reads; - scop->writes = tree.writes; + scop->mayWrites = tree.writes; + scop->mustWrites = isl::union_map::empty(scop->mayWrites.get_space()); scop->halide.statements = std::move(tree.statements); scop->halide.accesses = std::move(tree.accesses); scop->halide.reductions = halide2isl::findReductions(components.stmt); @@ -109,7 +110,8 @@ const isl::union_set Scop::domain() const { std::ostream& operator<<(std::ostream& os, const Scop& s) { os << "domain: " << s.domain() << "\n"; os << "reads: " << s.reads << "\n"; - os << "writes: " << s.writes << "\n"; + os << "mayWrites: " << s.mayWrites << "\n"; + os << "mustWrites: " << s.mustWrites << "\n"; os << "schedule: " << *s.scheduleRoot() << "\n"; os << "idx: { "; for (auto i : s.halide.idx) { @@ -373,7 +375,7 @@ isl::schedule_constraints makeScheduleConstraints( auto schedule = toIslSchedule(scop.scheduleRoot()); auto firstChildNode = scop.scheduleRoot()->child({0}); auto reads = scop.reads.domain_factor_domain(); - auto writes = scop.writes.domain_factor_domain(); + auto writes = scop.mayWrites.domain_factor_domain(); // RAW auto flowDeps = computeDependences(writes, reads, schedule); diff --git a/tc/core/polyhedral/scop.h b/tc/core/polyhedral/scop.h index bc4953b1b..918d0ee4f 100644 --- a/tc/core/polyhedral/scop.h +++ b/tc/core/polyhedral/scop.h @@ -66,7 +66,8 @@ struct Scop { res->globalParameterContext = scop.globalParameterContext; res->halide = scop.halide; res->reads = scop.reads; - res->writes = scop.writes; + res->mayWrites = scop.mayWrites; + res->mustWrites = scop.mustWrites; res->scheduleTreeUPtr = detail::ScheduleTree::makeScheduleTree(*scop.scheduleTreeUPtr); res->treeSyncUpdateMap = scop.treeSyncUpdateMap; @@ -115,7 +116,8 @@ struct Scop { void specializeToContext() { domain() = domain().intersect_params(globalParameterContext); reads = reads.intersect_params(globalParameterContext); - writes = writes.intersect_params(globalParameterContext); + mayWrites = mayWrites.intersect_params(globalParameterContext); + mustWrites = mustWrites.intersect_params(globalParameterContext); } // Returns a set that specializes the named scop's subset of @@ -442,8 +444,14 @@ struct Scop { // This globalParameterContext lives in a parameter space. isl::set globalParameterContext; // TODO: not too happy about this name + // Access relations. + // Elements in mayWrite may or may not be written by the execution, depending + // on some dynamic condition. Those in mustWrites are always written. + // Thefore, mayWrites do not participate in transitively-covered dependence + // removal. isl::union_map reads; - isl::union_map writes; + isl::union_map mayWrites; + isl::union_map mustWrites; private: // By analogy with generalized functions, a ScheduleTree is a (piecewise From 16ff65af7c521e875466f7dd45251378206df3cf Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Tue, 10 Apr 2018 17:26:34 +0200 Subject: [PATCH 4/6] Scop: use must-sources and kills in dependence analysis Previous commit introduced must writes in Scop. Use them in the dependence analysis as must sources for flow dependences and as kills for flow and false dependences. --- tc/core/polyhedral/scop.cc | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index 09b377261..f460ba90a 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -353,19 +353,29 @@ namespace { using namespace tc::polyhedral; +// Compute the dependence using the given may/must sources, sinks and kills. +// Any of the inputs may be an empty (but non-null) union map. +// Dependence analysis removes the cases transitively covered by a must source +// or a kill. isl::union_map computeDependences( - isl::union_map sources, + isl::union_map maySources, + isl::union_map mustSources, isl::union_map sinks, + isl::union_map kills, isl::schedule schedule) { auto uai = isl::union_access_info(sinks); - uai = uai.set_may_source(sources); + uai = uai.set_may_source(maySources); + uai = uai.set_must_source(mustSources); + uai = uai.set_kill(kills); uai = uai.set_schedule(schedule); auto flow = uai.compute_flow(); return flow.get_may_dependence(); } -// Do the simplest possible dependence analysis. -// Live-range reordering needs tagged access relations to be available. +// Set up schedule constraints by performing the dependence analysis using +// access relations from "scop". Set up callbacks in the constraints depending +// on "scheduleOptions". +// // The domain of the constraints is intersected with "restrictDomain" if it is // provided. isl::schedule_constraints makeScheduleConstraints( @@ -375,12 +385,16 @@ isl::schedule_constraints makeScheduleConstraints( auto schedule = toIslSchedule(scop.scheduleRoot()); auto firstChildNode = scop.scheduleRoot()->child({0}); auto reads = scop.reads.domain_factor_domain(); - auto writes = scop.mayWrites.domain_factor_domain(); + auto mayWrites = scop.mayWrites.domain_factor_domain(); + auto mustWrites = scop.mustWrites.domain_factor_domain(); + auto empty = isl::union_map::empty(mustWrites.get_space()); // RAW - auto flowDeps = computeDependences(writes, reads, schedule); + auto flowDeps = + computeDependences(mayWrites, mustWrites, reads, mustWrites, schedule); // WAR and WAW - auto falseDeps = computeDependences(writes.unite(reads), writes, schedule); + auto falseDeps = computeDependences( + mayWrites.unite(reads), empty, mayWrites, mustWrites, schedule); auto allDeps = flowDeps.unite(falseDeps).coalesce(); From e9d1dc413b830ec42c1a3e4d345ed8adff64857e Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Tue, 10 Apr 2018 18:32:13 +0200 Subject: [PATCH 5/6] extract may/must writes from Halide IR Previous commits introduced may/must writes in Scop and dependence analysis. Extract those from Halide IR. Change extractAccess to return a flag indicating whether the affine access relation is constructed is exact or not. Exact relations correspond to must writes since we statically know which tensor elements are written. Inexact relations overapproximate non-affine accesses and should be treated as may writes, assuming the tensor elements are not necessarily written. --- tc/core/halide2isl.cc | 85 +++++++++++++++++++++++++++++--------- tc/core/halide2isl.h | 2 +- tc/core/polyhedral/scop.cc | 4 +- test/test_core.cc | 54 ++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 23 deletions(-) diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index f5e1ed753..e2ae351aa 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include "tc/core/constants.h" @@ -238,7 +239,20 @@ isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) { return context; } -isl::map extractAccess( +// Extract a tagged affine access relation from Halide IR. +// The relation is tagged with a unique identifier, i.e. it lives in the space +// [D[...] -> __tc_ref_#[]] -> A[] +// where # is a unique sequential number, D is the statement identifier +// extracted from "domain" and A is the tensor identifier constructed from +// "tensor". "accesses" map is updated to keep track of the Halide IR nodes in +// which a particular reference # appeared. +// Returns the access relation and a flag indicating whether this relation is +// exact or not. The relation is overapproximated (that is, not exact) if it +// represents a non-affine access, for example, an access with indirection such +// as O(Index(i)) = 42. In such overapproximated access relation, dimensions +// that correspond to affine subscripts are still exact while those that +// correspond to non-affine subscripts are not constrained. +std::pair extractAccess( isl::set domain, const IRNode* op, const std::string& tensor, @@ -267,6 +281,7 @@ isl::map extractAccess( isl::map map = isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace)); + bool exact = true; for (size_t i = 0; i < args.size(); i++) { // Then add one equality constraint per dimension to encode the // point in the allocation actually read/written for each point in @@ -278,15 +293,17 @@ isl::map extractAccess( isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i); // ... equals the coordinate accessed as a function of the domain. auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]); - if (!domainPoint.is_null()) { + if (!domainPoint) { + exact = false; + } else { map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint)); } } - return map; + return std::make_pair(map, exact); } -std::pair +std::tuple extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) { class FindAccesses : public IRGraphVisitor { using IRGraphVisitor::visit; @@ -294,31 +311,46 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) { void visit(const Call* op) override { IRGraphVisitor::visit(op); if (op->call_type == Call::Halide || op->call_type == Call::Image) { - reads = reads.unite( - extractAccess(domain, op, op->name, op->args, accesses)); + // Read relations can be safely overapproximated. + isl::map read; + std::tie(read, std::ignore) = + extractAccess(domain, op, op->name, op->args, accesses); + reads = reads.unite(read); } } void visit(const Provide* op) override { IRGraphVisitor::visit(op); - writes = - writes.unite(extractAccess(domain, op, op->name, op->args, accesses)); + + // If the write access relation is not exact, we consider that any + // element _may_ be written by the statement. If it is exact, then we + // can guarantee that all the elements specified by the relation _must_ + // be written and any previously stored value will be killed. + isl::map write; + bool exact; + std::tie(write, exact) = + extractAccess(domain, op, op->name, op->args, accesses); + if (exact) { + mustWrites = mustWrites.unite(write); + } + mayWrites = mayWrites.unite(write); } const isl::set& domain; AccessMap* accesses; public: - isl::union_map reads, writes; + isl::union_map reads, mayWrites, mustWrites; FindAccesses(const isl::set& domain, AccessMap* accesses) : domain(domain), accesses(accesses), reads(isl::union_map::empty(domain.get_space())), - writes(isl::union_map::empty(domain.get_space())) {} + mayWrites(isl::union_map::empty(domain.get_space())), + mustWrites(isl::union_map::empty(domain.get_space())) {} } finder(domain, accesses); s.accept(&finder); - return {finder.reads, finder.writes}; + return std::make_tuple(finder.reads, finder.mayWrites, finder.mustWrites); } /* @@ -343,7 +375,8 @@ isl::schedule makeScheduleTreeHelper( isl::set set, std::vector& outer, isl::union_map* reads, - isl::union_map* writes, + isl::union_map* mayWrites, + isl::union_map* mustWrites, AccessMap* accesses, StatementMap* statements, IteratorMap* iterators) { @@ -389,7 +422,8 @@ isl::schedule makeScheduleTreeHelper( set, outerNext, reads, - writes, + mayWrites, + mustWrites, accesses, statements, iterators); @@ -422,7 +456,15 @@ isl::schedule makeScheduleTreeHelper( std::vector schedules; for (Stmt s : stmts) { schedules.push_back(makeScheduleTreeHelper( - s, set, outer, reads, writes, accesses, statements, iterators)); + s, + set, + outer, + reads, + mayWrites, + mustWrites, + accesses, + statements, + iterators)); } schedule = schedules[0].sequence(schedules[1]); @@ -437,23 +479,25 @@ isl::schedule makeScheduleTreeHelper( isl::set domain = set.set_tuple_id(id); schedule = isl::schedule::from_domain(domain); - isl::union_map newReads, newWrites; - std::tie(newReads, newWrites) = + isl::union_map newReads, newMayWrites, newMustWrites; + std::tie(newReads, newMayWrites, newMustWrites) = halide2isl::extractAccesses(domain, op, accesses); *reads = reads->unite(newReads); - *writes = writes->unite(newWrites); + *mayWrites = mayWrites->unite(newMayWrites); + *mustWrites = mustWrites->unite(newMustWrites); } else { LOG(FATAL) << "Unhandled Halide stmt: " << s; } return schedule; -}; +} ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) { ScheduleTreeAndAccesses result; - result.writes = result.reads = isl::union_map::empty(paramSpace); + result.mayWrites = result.mustWrites = result.reads = + isl::union_map::empty(paramSpace); // Walk the IR building a schedule tree std::vector outer; @@ -462,7 +506,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) { isl::set::universe(paramSpace), outer, &result.reads, - &result.writes, + &result.mayWrites, + &result.mustWrites, &result.accesses, &result.statements, &result.iterators); diff --git a/tc/core/halide2isl.h b/tc/core/halide2isl.h index 74ab23f88..bd771c01f 100644 --- a/tc/core/halide2isl.h +++ b/tc/core/halide2isl.h @@ -70,7 +70,7 @@ struct ScheduleTreeAndAccesses { /// Union maps describing the reads and writes done. Uses the ids in /// the schedule tree to denote the containing Stmt, and tags each /// access with a unique reference id of the form __tc_ref_N. - isl::union_map reads, writes; + isl::union_map reads, mayWrites, mustWrites; /// The correspondence between from Call and Provide nodes and the /// reference ids in the reads and writes maps. diff --git a/tc/core/polyhedral/scop.cc b/tc/core/polyhedral/scop.cc index f460ba90a..0c2335ca8 100644 --- a/tc/core/polyhedral/scop.cc +++ b/tc/core/polyhedral/scop.cc @@ -61,8 +61,8 @@ ScopUPtr Scop::makeScop( auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt); scop->scheduleTreeUPtr = std::move(tree.tree); scop->reads = tree.reads; - scop->mayWrites = tree.writes; - scop->mustWrites = isl::union_map::empty(scop->mayWrites.get_space()); + scop->mayWrites = tree.mayWrites; + scop->mustWrites = tree.mustWrites; scop->halide.statements = std::move(tree.statements); scop->halide.accesses = std::move(tree.accesses); scop->halide.reductions = halide2isl::findReductions(components.stmt); diff --git a/test/test_core.cc b/test/test_core.cc index 022ef1020..471eff5eb 100644 --- a/test/test_core.cc +++ b/test/test_core.cc @@ -165,6 +165,10 @@ struct TC2Isl : public ::testing::Test { auto scheduleHalide = polyhedral::detail::fromIslSchedule( polyhedral::detail::toIslSchedule(scop->scheduleRoot()).reset_user()); } + + std::unique_ptr MakeScop(const std::string& tc) { + return polyhedral::Scop::makeScop(isl::with_exceptions::globalIslCtx(), tc); + } }; TEST_F(TC2Isl, Copy1D) { @@ -313,6 +317,56 @@ def fun(float(M, N) I) -> (O1, O2, O3) { Check(tc, {123, 13}); } +// FIXME: range inference seems unaware of indirections on the LHS +TEST_F(TC2Isl, DISABLED_MayWritesOnly) { + string tc = R"TC( +def scatter(int32(N) A, int32(M) B) -> (O) { + O(A(i)) = B(i) +} +)TC"; + auto scop = MakeScop(tc); + CHECK(scop->mustWrites.is_empty()) + << "expected empty must-writes for scatter, got\n" + << scop->mustWrites; + CHECK(!scop->mayWrites.is_empty()) + << "expected non-empty may-writes for scatter, got\n" + << scop->mayWrites; +} + +TEST_F(TC2Isl, AllMustWrites) { + string tc = R"TC( +def gather(int32(N) A, int32(N) B) -> (O) { + O(i) = A(B(i)) where i in 0:N +} +)TC"; + auto scop = MakeScop(tc); + CHECK(!scop->mustWrites.is_empty()) + << "expected non-empty must-writes for gather, got\n" + << scop->mustWrites; + CHECK_EQ(scop->mustWrites, scop->mayWrites); +} + +// FIXME: range inference seems unaware of indirections on the LHS +TEST_F(TC2Isl, DISABLED_MustWritesSubsetMayWrites) { + string tc = R"TC( +def scatter_gather(int32(N) A, int32(N) B) -> (O1,O2) { + O1(i) = A(B(i)) where i in 0:N + O2(A(i)) = B(i) +} +)TC"; + auto scop = MakeScop(tc); + CHECK(!scop->mustWrites.is_empty()) << "expected non-empty must-writes, got\n" + << scop->mustWrites; + CHECK(!scop->mayWrites.is_empty()) << "expected non-empty may-writes, got\n" + << scop->mustWrites; + CHECK(scop->mustWrites.is_subset(scop->mayWrites)) + << scop->mustWrites << " is expected to be a subsetset of " + << scop->mayWrites; + CHECK(!scop->mayWrites.subtract(scop->mustWrites).is_empty()) + << scop->mustWrites << "is expected to be a strict subset of " + << scop->mayWrites; +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); From d9eab4b9c2527bcb3bddbf9e741cf541bac4f232 Mon Sep 17 00:00:00 2001 From: Oleksandr Zinenko Date: Wed, 25 Apr 2018 13:57:29 +0200 Subject: [PATCH 6/6] teach sema to extract iteration variables from LHS subtrees Sema needs the list of iteration variables on the Comprehension LHS to differentiate between reduction and non-reduction variables, the former appearing only on the RHS. Original implementation assumes Comprehension LHS is a tensor whose indices are Idents and ignores more complex constructs. With indirection support, comprehensions like O(A(i)) = B(i) are possible but i is interpreted as a reduction dimension by Sema. Traverse indices of the LHS Tensor in Comprehension recursively, inspecting subtrees of Access and Apply trees and collecting all Idents. --- tc/lang/sema.h | 33 ++++++++++++++++++++++++++------- test/test_core.cc | 13 +++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 5cbcb647e..34a9e7732 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -437,17 +437,35 @@ struct Sema { return checkRangeConstraint(RangeConstraint(ref)); } } + + private: + // Traverse the list of trees, recursively descending into arguments of APPLY + // and ACCESS subtrees and into all subtrees of different types (mostly + // expressions), and collect names and types of IDENT subtrees in + // "index_env". Expects to be called on the indices of the LHS tensor. + template + void registerLHSIndices(const Collection& treeRefs) { + for (const auto& treeRef : treeRefs) { + if (treeRef->kind() == TK_IDENT) { + std::string idx = Ident(treeRef).name(); + auto typ = indexType(treeRef); + insert(index_env, Ident(treeRef), typ, true); + } else if (treeRef->kind() == TK_APPLY) { + registerLHSIndices(Apply(treeRef).arguments()); + } else if (treeRef->kind() == TK_ACCESS) { + registerLHSIndices(Access(treeRef).arguments()); + } else { + registerLHSIndices(treeRef->trees()); + } + } + } + + public: TreeRef checkStmt(TreeRef stmt_) { auto stmt = Comprehension(stmt_); // register index variables (non-reductions) - for (const auto& index : stmt.indices()) { - if (index->kind() == TK_IDENT) { - std::string idx = Ident(index).name(); - auto typ = indexType(index); - insert(index_env, Ident(index), typ, true); - } - } + registerLHSIndices(stmt.indices()); // make dimension variables for each dimension of the output tensor std::string name = stmt.ident().name(); @@ -464,6 +482,7 @@ struct Sema { // where clauses are checked _before_ the rhs because they // introduce let bindings that are in scope for the rhs + // auto where_clauses_ = stmt.whereClauses().map( [&](TreeRef rc) { return checkWhereClause(rc); }); diff --git a/test/test_core.cc b/test/test_core.cc index 471eff5eb..2896e090b 100644 --- a/test/test_core.cc +++ b/test/test_core.cc @@ -346,6 +346,19 @@ def gather(int32(N) A, int32(N) B) -> (O) { CHECK_EQ(scop->mustWrites, scop->mayWrites); } +TEST_F(TC2Isl, Computed) { + string tc = R"TC( +def gather(int32(N) A, int32(N) B) -> (O) { + O(i - 2) = A(i) + B(i) +} +)TC"; + auto scop = MakeScop(tc); + CHECK(!scop->mustWrites.is_empty()) + << "expected non-empty must-writes for gather, got\n" + << scop->mustWrites; + CHECK_EQ(scop->mustWrites, scop->mayWrites); +} + // FIXME: range inference seems unaware of indirections on the LHS TEST_F(TC2Isl, DISABLED_MustWritesSubsetMayWrites) { string tc = R"TC(