From 1fbcf9fe1d05c719e8444eb86c2878d325358bcd Mon Sep 17 00:00:00 2001 From: prigoyal Date: Tue, 1 May 2018 12:01:47 -0700 Subject: [PATCH 1/8] add some documentation to sema/lexer/parser and tc2halide --- tc/core/halide2isl.cc | 5 +++-- tc/core/tc2halide.cc | 8 ++++++++ tc/core/tc2halide.h | 4 ++-- tc/lang/sema.h | 20 ++++++++++++++++++++ 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index eb3c976ce..4b6e3422b 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -35,7 +35,8 @@ using namespace tc::polyhedral::detail; SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) { // const Stmt& s) { - // Collect and categorize all the Variable symbols + // Collect and categorize all the Halide Variable symbols as reduction + // or index variables class BuildSymbolTable : public IRVisitor { using IRVisitor::visit; std::set included; @@ -60,7 +61,7 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) { components.stmt.accept(&builder); // Get params from components.params which contain everything declared in - // tcdef. However, the 0-D tensors are registered as both params and inputs, + // TC Def. However, the 0-D tensors are registered as both params and inputs, // filter those out. for (auto kvp : components.params) { bool skip = false; diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 0d6c43ae3..050bb7e15 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -62,10 +62,13 @@ Type translateScalarType(int tcType) { } } +// Translate the TC def input params to corresponding Halide components. +// params, inputs will be populated here. void translateParam( const lang::Param& p, map* params, vector* inputs) { + // Check if the param has already been converted to halide components. if (params->find(p.ident().name()) != params->end()) { return; } else { @@ -492,6 +495,8 @@ Expr reductionUpdate(Expr e) { return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic); } +// Translate a single TC comprehension/statement to Halide components: funcs, +// bounds, reductions. void translateComprehension( const lang::Comprehension& c, const map& params, @@ -737,6 +742,7 @@ void translateComprehension( stage.reorder(loop_nest); } +// Translate a semantically checked TC def to HalideComponents struct. HalideComponents translateDef(const lang::Def& def, bool throwWarnings) { map funcs; HalideComponents components; @@ -956,6 +962,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) { lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings); } +// NOTE: there is no guarantee here that the tc string has only one def. It +// could have many defs. Only first def will be converted in that case. HalideComponents translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) { LOG_IF(INFO, tc::FLAGS_debug_halide) << tc; diff --git a/tc/core/tc2halide.h b/tc/core/tc2halide.h index b3422f43f..e3cd736a9 100644 --- a/tc/core/tc2halide.h +++ b/tc/core/tc2halide.h @@ -27,8 +27,8 @@ namespace tc2halide { // of the input and output tensors. We do not explicitly enumerate the // scalar params. struct HalideComponents { - lang::TreeRef - def; // post-semantic analaysis tree, used for later error reporting + // post-semantic analaysis tree, used for later error reporting + lang::TreeRef def; Halide::Internal::Stmt stmt; std::vector inputs; std::map params; diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 3b84afe75..4330ec545 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -339,6 +339,21 @@ struct Sema { throw ErrorReport(exp) << "NYI - semantic checking for " << exp; } } + // This is the entry function for semantic analysis. It is called by + // tc2halide to associate type with each node of the tree and to also make + // sure that the tree is sematically correct. For example: a variable + // may not be input two times. Parser only verifies for the syntax but does + // not check the semantics. + // + // It converts the TK_APPLY nodes to TK_ACCESS or TK_BUILT_IN + // + // The reduction variables are deduced and the objects are created for them + // and they are appended to the tree + // + // Type checking is also done by small amount of code + // + // The method 'withType' can be used to associate the type with a given node + // TreeRef checkFunction(TreeRef func_) { auto func = Def(func_); auto params_ = @@ -350,6 +365,10 @@ struct Sema { } } + // Everything has to be input or output. Keep track of the variables that + // are either input/output. We will check that the statements have variables + // from this list. If not, then throw error that temporaries are not yet + // implemented. for (auto p : func.params()) { nonTemporaries.insert(p.ident().name()); inputParameters.insert(p.ident().name()); @@ -437,6 +456,7 @@ struct Sema { return checkRangeConstraint(RangeConstraint(ref)); } } + // Semantic checking for the statements/comprehensions in a TC Def. TreeRef checkStmt(TreeRef stmt_) { auto stmt = Comprehension(stmt_); From 807148d86498129e37c84ecd52f86d530540d0e5 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:26:34 -0700 Subject: [PATCH 2/8] remove unnecessary else in if-else given early exit the 'if' condition is checked and 'return' happens if the condition is met. Using 'else' is not needed --- tc/core/tc2halide.cc | 51 ++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 050bb7e15..57b5204ea 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -71,37 +71,36 @@ void translateParam( // Check if the param has already been converted to halide components. if (params->find(p.ident().name()) != params->end()) { return; - } else { - lang::TensorType type = p.tensorType(); - int dimensions = (int)type.dims().size(); - ImageParam imageParam( - translateScalarType(type.scalarType()), dimensions, p.ident().name()); - inputs->push_back(imageParam); - vector dims; - for (auto d_ : type.dims()) { - if (d_->kind() == lang::TK_IDENT) { - auto d = lang::Ident(d_); - auto it = params->find(d.name()); - Parameter p; - if (it != params->end()) { - p = it->second; - } else { - p = Parameter(Int(32), false, 0, d.name(), true); - (*params)[d.name()] = p; - } - dims.push_back(Variable::make(Int(32), p.name(), p)); + } + lang::TensorType type = p.tensorType(); + int dimensions = (int)type.dims().size(); + ImageParam imageParam( + translateScalarType(type.scalarType()), dimensions, p.ident().name()); + inputs->push_back(imageParam); + vector dims; + for (auto d_ : type.dims()) { + if (d_->kind() == lang::TK_IDENT) { + auto d = lang::Ident(d_); + auto it = params->find(d.name()); + Parameter p; + if (it != params->end()) { + p = it->second; } else { - CHECK(d_->kind() == lang::TK_CONST); - int32_t value = lang::Const(d_).value(); - dims.push_back(Expr(value)); + p = Parameter(Int(32), false, 0, d.name(), true); + (*params)[d.name()] = p; } + dims.push_back(Variable::make(Int(32), p.name(), p)); + } else { + CHECK(d_->kind() == lang::TK_CONST); + int32_t value = lang::Const(d_).value(); + dims.push_back(Expr(value)); } + } - for (int i = 0; i < imageParam.dimensions(); i++) { - imageParam.dim(i).set_bounds(0, dims[i]); - } - (*params)[imageParam.name()] = imageParam.parameter(); + for (int i = 0; i < imageParam.dimensions(); i++) { + imageParam.dim(i).set_bounds(0, dims[i]); } + (*params)[imageParam.name()] = imageParam.parameter(); } void translateOutput( From 0398b9c77a12bce4a9a0b8e29f7df51e91511aaa Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:31:30 -0700 Subject: [PATCH 3/8] improve code readability proper variable naming and whitelining between functions --- tc/core/tc2halide.cc | 34 +++++++++++++++++---------------- tc/lang/sema.h | 45 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 57b5204ea..2bee8fa45 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -497,19 +497,19 @@ Expr reductionUpdate(Expr e) { // Translate a single TC comprehension/statement to Halide components: funcs, // bounds, reductions. void translateComprehension( - const lang::Comprehension& c, + const lang::Comprehension& comprehension, const map& params, bool throwWarnings, map* funcs, FunctionBounds* bounds, vector* reductions) { Function f; - auto it = funcs->find(c.ident().name()); + auto it = funcs->find(comprehension.ident().name()); if (it != funcs->end()) { f = it->second; } else { - f = Function(c.ident().name()); - (*funcs)[c.ident().name()] = f; + f = Function(comprehension.ident().name()); + (*funcs)[comprehension.ident().name()] = f; } // Function is the internal Halide IR type for a pipeline // stage. Func is the front-end class that wraps it. Here it's @@ -518,7 +518,7 @@ void translateComprehension( vector lhs; vector lhs_as_exprs; - for (lang::Ident id : c.indices()) { + for (lang::Ident id : comprehension.indices()) { lhs.push_back(Var(id.name())); lhs_as_exprs.push_back(lhs.back()); } @@ -527,17 +527,17 @@ void translateComprehension( // in the future we may consider using Halide Let bindings when they // are supported later map lets; - for (auto wc : c.whereClauses()) { + for (auto wc : comprehension.whereClauses()) { if (wc->kind() == lang::TK_LET) { auto let = lang::Let(wc); lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets); } } - Expr rhs = translateExpr(c.rhs(), params, *funcs, lets); + Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets); std::vector all_exprs; - for (auto wc : c.whereClauses()) { + for (auto wc : comprehension.whereClauses()) { if (wc->kind() == lang::TK_EXISTS) { all_exprs.push_back( translateExpr(lang::Exists(wc).exp(), params, *funcs, lets)); @@ -561,7 +561,7 @@ void translateComprehension( // values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity // for the reduction and then applies the reduction. bool should_zero = false; - switch (c.assignment()->kind()) { + switch (comprehension.assignment()->kind()) { case lang::TK_PLUS_EQ_B: should_zero = true; // fallthrough case lang::TK_PLUS_EQ: @@ -593,11 +593,12 @@ void translateComprehension( case '=': break; default: - throw lang::ErrorReport(c) << "Unimplemented reduction " - << c.assignment()->range().text() << "\n"; + throw lang::ErrorReport(comprehension) + << "Unimplemented reduction " + << comprehension.assignment()->range().text() << "\n"; } - if (c.assignment()->kind() != '=') { + if (comprehension.assignment()->kind() != '=') { reductions->push_back(f); } @@ -637,7 +638,7 @@ void translateComprehension( Scope solution; // Put anything explicitly specified with a 'where' class in the solution - for (auto constraint_ : c.whereClauses()) { + for (auto constraint_ : comprehension.whereClauses()) { if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT) continue; auto constraint = lang::RangeConstraint(constraint_); @@ -658,7 +659,8 @@ void translateComprehension( // Infer the rest all_exprs.push_back(rhs); - forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution); + forwardBoundsInference( + all_exprs, *bounds, comprehension, throwWarnings, &solution); // TODO: What if subsequent updates have incompatible bounds // (e.g. an in-place stencil)?. The .bound directive will use the @@ -669,7 +671,7 @@ void translateComprehension( for (Var v : lhs) { if (!solution.contains(v.name())) { - throw lang::ErrorReport(c) + throw lang::ErrorReport(comprehension) << "Free variable " << v << " was not solved in range inference. May not be used right-hand side"; } @@ -693,7 +695,7 @@ void translateComprehension( for (size_t i = 0; i < unbound.size(); i++) { auto v = unbound[unbound.size() - 1 - i]; if (!solution.contains(v->name)) { - throw lang::ErrorReport(c) + throw lang::ErrorReport(comprehension) << "Free variable " << v << " is unconstrained. " << "Use a 'where' clause to set its range."; } diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 4330ec545..dabec7330 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -166,6 +166,7 @@ struct Sema { } return expr_to_type.at(ref); } + // associate a type with this expression TreeRef withType(TreeRef expr, TreeRef type) { auto inserted = expr_to_type.emplace(expr, type).second; @@ -179,6 +180,7 @@ struct Sema { } return TensorType(typ); } + TreeRef matchAllTypes(TreeRef list, TreeRef matched_type = nullptr) { for (auto e : list->trees()) { if (!matched_type) @@ -188,6 +190,7 @@ struct Sema { } return matched_type; } + TreeRef expectIntegral(TreeRef e) { if (TypeInfo(typeOfExpr(e)).code() == TypeInfo::Float) { throw ErrorReport(e) << " expected integral type but found " @@ -195,16 +198,19 @@ struct Sema { } return e; } + void expectBool(TreeRef anchor, int token) { if (token != TK_BOOL) { throw ErrorReport(anchor) << "expected boolean but found " << kindToString(token); } } + TreeRef expectBool(TreeRef exp) { expectBool(exp, typeOfExpr(exp)->kind()); return exp; } + TreeRef lookupVarOrCreateIndex(Ident ident) { TreeRef type = lookup(ident, false); if (!type) { @@ -216,6 +222,7 @@ struct Sema { } return type; } + TreeRef checkExp(TreeRef exp, bool allow_access) { switch (exp->kind()) { case TK_APPLY: { @@ -339,6 +346,7 @@ struct Sema { throw ErrorReport(exp) << "NYI - semantic checking for " << exp; } } + // This is the entry function for semantic analysis. It is called by // tc2halide to associate type with each node of the tree and to also make // sure that the tree is sematically correct. For example: a variable @@ -352,7 +360,7 @@ struct Sema { // // Type checking is also done by small amount of code // - // The method 'withType' can be used to associate the type with a given node + // The method 'withType' is used to associate the type with a given node // TreeRef checkFunction(TreeRef func_) { auto func = Def(func_); @@ -385,21 +393,27 @@ struct Sema { Def::create(func.range(), func.name(), params_, returns_, statements_); return r; } + TreeRef indexType(TreeRef anchor) { - return c(TK_INT32, anchor->range(), {}); + return createCompound(TK_INT32, anchor->range(), {}); } + TreeRef dimType(TreeRef anchor) { return indexType(anchor); } + TreeRef floatType(TreeRef anchor) { - return c(TK_FLOAT, anchor->range(), {}); + return createCompound(TK_FLOAT, anchor->range(), {}); } + TreeRef boolType(TreeRef anchor) { - return c(TK_BOOL, anchor->range(), {}); + return createCompound(TK_BOOL, anchor->range(), {}); } + void checkDim(Ident dim) { insert(env, dim, dimType(dim), false); } + TreeRef checkTensorType(TreeRef type) { auto tt = TensorType(type); for (const auto& d : tt.dims()) { @@ -409,6 +423,7 @@ struct Sema { } return type; } + TreeRef checkParam(TreeRef param) { auto p = Param(param); TreeRef type_ = checkTensorType(p.type()); @@ -416,11 +431,13 @@ struct Sema { live_input_names.insert(p.ident().name()); return param; } + TreeRef checkReturn(TreeRef ret) { auto r = Param(ret); TreeRef real_type = lookup(env, r.ident(), true); return ret; } + TreeRef checkList(TreeRef list, std::function fn) { TC_ASSERT(list, list->kind() == TK_LIST); TreeList r; @@ -429,6 +446,7 @@ struct Sema { } return List::create(list->range(), std::move(r)); } + TreeRef checkRangeConstraint(RangeConstraint rc) { // RCs are checked _before_ the rhs of the TC, so // it is possible the index is not in the environment yet @@ -441,11 +459,13 @@ struct Sema { auto e = expectIntegral(checkExp(rc.end(), false)); return RangeConstraint::create(rc.range(), rc.ident(), s, e); } + TreeRef checkLet(Let l) { auto rhs = checkExp(l.rhs(), true); insert(let_env, l.name(), typeOfExpr(rhs), true); return Let::create(l.range(), l.name(), rhs); } + TreeRef checkWhereClause(TreeRef ref) { if (ref->kind() == TK_LET) { return checkLet(Let(ref)); @@ -456,6 +476,7 @@ struct Sema { return checkRangeConstraint(RangeConstraint(ref)); } } + // Semantic checking for the statements/comprehensions in a TC Def. TreeRef checkStmt(TreeRef stmt_) { auto stmt = Comprehension(stmt_); @@ -467,11 +488,13 @@ struct Sema { insert(index_env, index, typ, true); } - // make dimension variables for each dimension of the output tensor + // check that the input is not used for output - inputs are immutable std::string name = stmt.ident().name(); if (inputParameters.count(name) > 0) { throw ErrorReport(stmt_) << "TC inputs are immutable"; } + + // make dimension variables for each dimension of the output tensor TreeList output_indices; int n = stmt.indices().size(); for (int i = 0; i < n; ++i) { @@ -578,6 +601,7 @@ struct Sema { return result; } + static bool isUninitializedReductionOperation(TreeRef assignment) { switch (assignment->kind()) { case TK_PLUS_EQ: @@ -589,6 +613,7 @@ struct Sema { return false; } } + bool isNotInplace(TreeRef assignment) { switch (assignment->kind()) { case TK_PLUS_EQ_B: @@ -600,6 +625,7 @@ struct Sema { return false; } } + std::string dumpEnv() { std::stringstream ss; std::vector> elems(env.begin(), env.end()); @@ -618,6 +644,7 @@ struct Sema { private: using Env = std::unordered_map; + void insert(Env& the_env, Ident ident, TreeRef value, bool must_be_undefined) { std::string name = ident.name(); @@ -630,6 +657,7 @@ struct Sema { throw ErrorReport(ident) << name << " already defined"; } } + TreeRef lookup(Ident ident, bool required) { TreeRef v = lookup(index_env, ident, false); if (!v) @@ -638,6 +666,7 @@ struct Sema { v = lookup(env, ident, required); return v; } + TreeRef lookup(Env& the_env, Ident ident, bool required) { std::string name = ident.name(); auto it = the_env.find(name); @@ -647,10 +676,12 @@ struct Sema { } return it == the_env.end() ? nullptr : it->second; } - TreeRef c(int kind, const SourceRange& range, TreeList&& trees) { + + TreeRef createCompound(int kind, const SourceRange& range, TreeList&& trees) { return Compound::create(kind, range, std::move(trees)); } - TreeRef s(const std::string& s) { + + TreeRef createString(const std::string& s) { return String::create(s); } From 644de66f41d96aa73ef60eebf668b9e892234ca0 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:32:49 -0700 Subject: [PATCH 4/8] fix segfault in halide2isl for mod op change makeIslAffBoundsFromExpr(..., e, ...) to makeIslAffBoundsFromExpr(..., op->a, ...) to avoid infinite recursion leading to segfault --- tc/core/halide2isl.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tc/core/halide2isl.cc b/tc/core/halide2isl.cc index 4b6e3422b..4effe762e 100644 --- a/tc/core/halide2isl.cc +++ b/tc/core/halide2isl.cc @@ -201,7 +201,7 @@ std::vector makeIslAffBoundsFromExpr( std::vector result; // We cannot span multiple constraints if a modulo operation is involved. // x > max(a,b) % C is not equivalent to (x > a % C && x > b % C). - auto lhs = makeIslAffBoundsFromExpr(space, e, false, false); + auto lhs = makeIslAffBoundsFromExpr(space, op->a, false, false); CHECK_EQ(lhs.size(), 1u); if (const int64_t* b = as_const_int(op->b)) { return {lhs[0].mod(isl::val(space.get_ctx(), *b))}; From 1a5a9ba5d2d8bb2021bd14ef894da96b6af26929 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:35:49 -0700 Subject: [PATCH 5/8] add support for modulo operator support % operator: propagate it from parser to halide to isl and add unit tests --- tc/core/tc2halide.cc | 2 ++ tc/lang/lexer.h | 4 ++-- tc/lang/sema.h | 1 + tc/lang/test_expected/math.expected | 4 +++- test/cuda/test_corner_cases.cc | 12 ++++++++++++ test/test_lang.cc | 2 +- 6 files changed, 21 insertions(+), 4 deletions(-) diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 2bee8fa45..8e4b3144e 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -158,6 +158,8 @@ Expr translateExpr( return t(0) * t(1); case '/': return t(0) / t(1); + case '%': + return t(0) % t(1); case lang::TK_MIN: return min(t(0), t(1)); case lang::TK_MAX: diff --git a/tc/lang/lexer.h b/tc/lang/lexer.h index d9bbcd9b2..9e40a3092 100644 --- a/tc/lang/lexer.h +++ b/tc/lang/lexer.h @@ -87,7 +87,7 @@ namespace lang { _(TK_LET, "let", "") \ _(TK_EXISTS, "exists", "exists") -static const char* valid_single_char_tokens = "+-*/()[]?:,={}>', '<', TK_LE, TK_GE, TK_EQ, TK_NE}, {'+', '-'}, - {'*', '/'}, + {'*', '/', '%'}, }; std::vector> unary_ops = { {'-', '!'}, diff --git a/tc/lang/sema.h b/tc/lang/sema.h index dabec7330..8c54b57be 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -293,6 +293,7 @@ struct Sema { case '-': case '*': case '/': + case '%': case TK_MIN: case TK_MAX: { auto nexp = diff --git a/tc/lang/test_expected/math.expected b/tc/lang/test_expected/math.expected index 8b21f02b1..700d8d7d8 100644 --- a/tc/lang/test_expected/math.expected +++ b/tc/lang/test_expected/math.expected @@ -1,7 +1,9 @@ (- (+ (+ - (- (const 3 (int32))) + (% + (- (const 3 (int32))) + (const 2 (int32))) (* (const 4 (int32)) (const 5 (int32)))) diff --git a/test/cuda/test_corner_cases.cc b/test/cuda/test_corner_cases.cc index 077763c32..80c93c6f6 100644 --- a/test/cuda/test_corner_cases.cc +++ b/test/cuda/test_corner_cases.cc @@ -285,6 +285,18 @@ TEST(TestCornerCases, E23) { at::Scalar(d[0]).toFloat()); } +TEST(TestCornerCases, E24) { + auto a = I(); + auto b = I(); + auto r = I(0); + Succeed( + "def f(int32 a, int32 b) -> (c) { c(i) = int32(a % b) where i in 0:1 }", + {a, b}, + {r}); + auto e = at::Scalar(a).toInt() % at::Scalar(b).toInt(); + CHECK_EQ(at::Scalar(r[0]).toInt(), e); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); diff --git a/test/test_lang.cc b/test/test_lang.cc index 4596d6ad0..c1c2b1902 100644 --- a/test/test_lang.cc +++ b/test/test_lang.cc @@ -232,7 +232,7 @@ int main(int argc, char** argv) { ASSERT(s->tree(0)->stringValue() == "min"); } { - std::string stuff = "-3+4*5+7-a"; + std::string stuff = "-3%2+4*5+7-a"; Parser p(stuff); auto r = p.parseExp(); std::stringstream ss; From faad095cd23465910de507c42b3fd3d5ba6bddeb Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:36:40 -0700 Subject: [PATCH 6/8] sema createString cleanup dead function and not used anywhere so cleaning it up --- tc/lang/sema.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 8c54b57be..03c7a9385 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -682,10 +682,6 @@ struct Sema { return Compound::create(kind, range, std::move(trees)); } - TreeRef createString(const std::string& s) { - return String::create(s); - } - std::vector reduction_variables; // per-statement Env index_env; // per-statement Env let_env; // per-statement, used for where i = From 34bd0b1df4f97c31ac64c64fda5f8e89cd2c11ce Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:37:41 -0700 Subject: [PATCH 7/8] add more tests for % to cover breadth --- test/test_cuda_mapper.cc | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/test_cuda_mapper.cc b/test/test_cuda_mapper.cc index 2ddc73ad3..44148285d 100644 --- a/test/test_cuda_mapper.cc +++ b/test/test_cuda_mapper.cc @@ -753,6 +753,38 @@ def perforatedConvolution(float(N, C, H, W) input, float(M, C, KH, KW) weights, Prepare(tc); } +TEST_F(PolyhedralMapperTest, ModulusConstantRHS) { + string tc = R"TC( +def fun(float(N) a) -> (b) { b(i) = a(i % 3) where i in 0:N } +)TC"; + // This triggers tc2halide conversion and should not throw. + auto scop = Prepare(tc); + for (auto r : scop->reads.range().get_set_list()) { + // skip irrelevant reads, if any + if (r.get_tuple_name() != std::string("a")) { + continue; + } + // EXPECT_EQ(r.get_stride(0), ?); + } +} + +TEST_F(PolyhedralMapperTest, ModulusVariableRHS) { + string tc = R"TC( +def local_sparse_convolution(float(N, C, H, W) I, float(O, KC, KH, KW) W1) -> (O1) { + O1(n, o, h, w) +=! I(n, kc % c, h + kh, w + kw) * W1(o, kc, kh, kw) where c in 1:C +} +)TC"; + // This triggers tc2halide conversion and should not throw. + auto scop = Prepare(tc); + for (auto r : scop->reads.range().get_set_list()) { + // skip irrelevant reads, if any + if (r.get_tuple_name() != std::string("I")) { + continue; + } + EXPECT_TRUE(r.plain_is_universe()); + } +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); From 892d79598047a8550bcdf502437ba0be025a9191 Mon Sep 17 00:00:00 2001 From: prigoyal Date: Thu, 3 May 2018 07:56:27 -0700 Subject: [PATCH 8/8] Add bitwise operators and propagate them from parser -> halide -> isl --- tc/core/libraries.h | 6 ++++++ tc/core/tc2halide.cc | 12 ++++++++++++ tc/lang/lexer.h | 10 ++++++++-- tc/lang/sema.h | 6 ++++++ tc/lang/test_expected/bitwise.expected | 11 +++++++++++ test/cuda/test_corner_cases.cc | 27 ++++++++++++++++++++++++++ test/test_lang.cc | 8 ++++++++ 7 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 tc/lang/test_expected/bitwise.expected diff --git a/tc/core/libraries.h b/tc/core/libraries.h index 9d8386cd0..503799054 100644 --- a/tc/core/libraries.h +++ b/tc/core/libraries.h @@ -145,6 +145,12 @@ template inline __device__ T floord(T n, T d) { return n < 0 ? - (-n + d - 1)/d : n / d; } #define if_then_else(cond,a,b) ((cond) ? (a) : (b)) +#define shift_left(a,b) ((a) << (b)) +#define shift_right(a,b) ((a) >> (b)) +#define bitwise_and(a,b) ((a) & (b)) +#define bitwise_xor(a,b) ((a) ^ (b)) +#define bitwise_or(a,b) ((a) | (b)) +#define bitwise_not(a) (~(a)) )C"; } // namespace cpp diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc index 8e4b3144e..c6b719f75 100644 --- a/tc/core/tc2halide.cc +++ b/tc/core/tc2halide.cc @@ -190,6 +190,18 @@ Expr translateExpr( return t(0) && t(1); case lang::TK_OR: return t(0) || t(1); + case lang::TK_LS: + return t(0) << t(1); + case lang::TK_RS: + return t(0) >> t(1); + case '|': + return t(0) | t(1); + case '^': + return t(0) ^ t(1); + case '&': + return t(0) & t(1); + case '~': + return ~t(0); case lang::TK_BUILT_IN: { auto b = lang::BuiltIn(expr); vector exprs; diff --git a/tc/lang/lexer.h b/tc/lang/lexer.h index 9e40a3092..70e90441e 100644 --- a/tc/lang/lexer.h +++ b/tc/lang/lexer.h @@ -84,10 +84,12 @@ namespace lang { _(TK_NE, "neq", "!=") \ _(TK_AND, "and", "&&") \ _(TK_OR, "or", "||") \ + _(TK_LS, "ls", "<<") \ + _(TK_RS, "rs", ">>") \ _(TK_LET, "let", "") \ _(TK_EXISTS, "exists", "exists") -static const char* valid_single_char_tokens = "+-*/()[]?:,={}>', '<', TK_LE, TK_GE, TK_EQ, TK_NE}, + {TK_LS, TK_RS}, {'+', '-'}, {'*', '/', '%'}, }; std::vector> unary_ops = { - {'-', '!'}, + {'-', '!', '~'}, }; std::stringstream ss; diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 03c7a9385..d2e033e42 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -294,6 +294,12 @@ struct Sema { case '*': case '/': case '%': + case '~': + case '|': + case '^': + case '&': + case TK_LS: + case TK_RS: case TK_MIN: case TK_MAX: { auto nexp = diff --git a/tc/lang/test_expected/bitwise.expected b/tc/lang/test_expected/bitwise.expected new file mode 100644 index 000000000..aa78c9735 --- /dev/null +++ b/tc/lang/test_expected/bitwise.expected @@ -0,0 +1,11 @@ +(| + (^ + (& + (~ (const 3 (int32))) + (const 4 (int32))) + (const 5 (int32))) + (ls + (rs + (const 6 (int32)) + (const 8 (int32))) + (const 2 (int32)))) diff --git a/test/cuda/test_corner_cases.cc b/test/cuda/test_corner_cases.cc index 80c93c6f6..e6b6add95 100644 --- a/test/cuda/test_corner_cases.cc +++ b/test/cuda/test_corner_cases.cc @@ -297,6 +297,33 @@ TEST(TestCornerCases, E24) { CHECK_EQ(at::Scalar(r[0]).toInt(), e); } +TEST(TestCornerCases, E25){ +#define GEN_BITWISE(op) \ + { \ + auto a = 2 * I(); \ + auto b = 2 * I(); \ + auto r = I(0); \ + Succeed( \ + "def f(int32 a, int32 b) -> (c) { c(i) = int32(a " #op \ + " b) where i in 0:1 }", \ + {a, b}, \ + {r}); \ + auto e = at::Scalar(a).toInt() op at::Scalar(b).toInt(); \ + CHECK_EQ(e, at::Scalar(r[0]).toInt()); \ + } + + GEN_BITWISE(<<) GEN_BITWISE(>>) GEN_BITWISE(&) GEN_BITWISE(|) + GEN_BITWISE (^)} + +TEST(TestCornerCases, E26) { + auto a = I(); + auto r = I(0); + Succeed( + "def f(int32 a) -> (c) { c(i) = int32(~a) where i in 0:1 }", {a}, {r}); + auto e = ~at::Scalar(a).toInt(); + CHECK_EQ(at::Scalar(r[0]).toInt(), e); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::gflags::ParseCommandLineFlags(&argc, &argv, true); diff --git a/test/test_lang.cc b/test/test_lang.cc index c1c2b1902..c5d774422 100644 --- a/test/test_lang.cc +++ b/test/test_lang.cc @@ -244,6 +244,14 @@ int main(int argc, char** argv) { ss2 << p2.parseExp(); assertEqual("function.expected", ss2.str()); } + { + std::string bitOps = "~3&4^5|6>>8<<2"; + Parser p(bitOps); + auto r = p.parseExp(); + std::stringstream ss; + ss << r; + assertEqual("bitwise.expected", ss.str()); + } assertParseEqual("trinary.expected", "a ? 3 : b ? 3 : 4", [&](Parser& p) { return p.parseExp(); });