Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit a44e7df

Browse files
committed
Add modulo operator and propagate it from parser to halide to isl
1 parent ae93da7 commit a44e7df

File tree

8 files changed

+147
-61
lines changed

8 files changed

+147
-61
lines changed

tc/core/halide2isl.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ using namespace tc::polyhedral::detail;
3535

3636
SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
3737
// const Stmt& s) {
38-
// Collect and categorize all the Variable symbols
38+
// Collect and categorize all the Halide Variable symbols as reduction
39+
// or index variables
3940
class BuildSymbolTable : public IRVisitor {
4041
using IRVisitor::visit;
4142
std::set<std::string> included;
@@ -60,7 +61,7 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
6061

6162
components.stmt.accept(&builder);
6263
// Get params from components.params which contain everything declared in
63-
// tcdef. However, the 0-D tensors are registered as both params and inputs,
64+
// TC Def. However, the 0-D tensors are registered as both params and inputs,
6465
// filter those out.
6566
for (auto kvp : components.params) {
6667
bool skip = false;
@@ -93,7 +94,7 @@ namespace {
9394
* Convert Halide binary expression "op" into a list of isl affine functions by
9495
* converting its LHS and RHS into lists of affs and concatenating those lists.
9596
* This is intended to be used with Min/Max operations in upper/lower bound
96-
* computations, respectively. Essentially, this allows for replacements
97+
* computations, respectively. Essentially, this allows for replacements
9798
* x < min(a,min(b,c)) <=> x < a AND x < b AND x < c
9899
* x > max(a,max(b,c)) <=> x > a AND x > b AND x > c
99100
*/
@@ -197,10 +198,9 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
197198
} else if (const Div* op = e.as<Div>()) {
198199
return combineSingleAffs(space, op, &isl::aff::div);
199200
} else if (const Mod* op = e.as<Mod>()) {
200-
std::vector<isl::aff> result;
201201
// We cannot span multiple constraints if a modulo operation is involved.
202202
// x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
203-
auto lhs = makeIslAffBoundsFromExpr(space, e, false, false);
203+
auto lhs = makeIslAffBoundsFromExpr(space, op->a, false, false);
204204
CHECK_EQ(lhs.size(), 1u);
205205
if (const int64_t* b = as_const_int(op->b)) {
206206
return {lhs[0].mod(isl::val(space.get_ctx(), *b))};

tc/core/tc2halide.cc

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
6262
}
6363
}
6464

65+
// translate the TC def input params to corresponding Halide components.
66+
// params, inputs will be populated here
6567
void translateParam(
6668
const lang::Param& p,
6769
map<string, Parameter>* params,
6870
vector<ImageParam>* inputs) {
71+
// check if the param is already converted to halide components
6972
if (params->find(p.ident().name()) != params->end()) {
7073
return;
71-
} else {
72-
lang::TensorType type = p.tensorType();
73-
int dimensions = (int)type.dims().size();
74-
ImageParam imageParam(
75-
translateScalarType(type.scalarType()), dimensions, p.ident().name());
76-
inputs->push_back(imageParam);
77-
vector<Expr> dims;
78-
for (auto d_ : type.dims()) {
79-
if (d_->kind() == lang::TK_IDENT) {
80-
auto d = lang::Ident(d_);
81-
auto it = params->find(d.name());
82-
Parameter p;
83-
if (it != params->end()) {
84-
p = it->second;
85-
} else {
86-
p = Parameter(Int(32), false, 0, d.name(), true);
87-
(*params)[d.name()] = p;
88-
}
89-
dims.push_back(Variable::make(Int(32), p.name(), p));
74+
}
75+
lang::TensorType type = p.tensorType();
76+
int dimensions = (int)type.dims().size();
77+
ImageParam imageParam(
78+
translateScalarType(type.scalarType()), dimensions, p.ident().name());
79+
inputs->push_back(imageParam);
80+
vector<Expr> dims;
81+
for (auto d_ : type.dims()) {
82+
if (d_->kind() == lang::TK_IDENT) {
83+
auto d = lang::Ident(d_);
84+
auto it = params->find(d.name());
85+
Parameter p;
86+
if (it != params->end()) {
87+
p = it->second;
9088
} else {
91-
CHECK(d_->kind() == lang::TK_CONST);
92-
int32_t value = lang::Const(d_).value();
93-
dims.push_back(Expr(value));
89+
p = Parameter(Int(32), false, 0, d.name(), true);
90+
(*params)[d.name()] = p;
9491
}
92+
dims.push_back(Variable::make(Int(32), p.name(), p));
93+
} else {
94+
CHECK(d_->kind() == lang::TK_CONST);
95+
int32_t value = lang::Const(d_).value();
96+
dims.push_back(Expr(value));
9597
}
98+
}
9699

97-
for (int i = 0; i < imageParam.dimensions(); i++) {
98-
imageParam.dim(i).set_bounds(0, dims[i]);
99-
}
100-
(*params)[imageParam.name()] = imageParam.parameter();
100+
for (int i = 0; i < imageParam.dimensions(); i++) {
101+
imageParam.dim(i).set_bounds(0, dims[i]);
101102
}
103+
(*params)[imageParam.name()] = imageParam.parameter();
102104
}
103105

104106
void translateOutput(
@@ -156,6 +158,8 @@ Expr translateExpr(
156158
return t(0) * t(1);
157159
case '/':
158160
return t(0) / t(1);
161+
case '%':
162+
return t(0) % t(1);
159163
case lang::TK_MIN:
160164
return min(t(0), t(1));
161165
case lang::TK_MAX:
@@ -492,20 +496,25 @@ Expr reductionUpdate(Expr e) {
492496
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
493497
}
494498

499+
// translate a single TC comprehension/statement to Halide component.
500+
// funcs, bounds, reductions will be populated
495501
void translateComprehension(
496-
const lang::Comprehension& c,
502+
const lang::Comprehension& comprehension,
497503
const map<string, Parameter>& params,
498504
bool throwWarnings,
499505
map<string, Function>* funcs,
500506
FunctionBounds* bounds,
501507
vector<Function>* reductions) {
508+
// Function is the internal Halide IR type for a pipeline
509+
// stage. Func is the front-end class that wraps it. Here it's
510+
// convenient to use both. Why? what is not exposed in Func?
502511
Function f;
503-
auto it = funcs->find(c.ident().name());
512+
auto it = funcs->find(comprehension.ident().name());
504513
if (it != funcs->end()) {
505514
f = it->second;
506515
} else {
507-
f = Function(c.ident().name());
508-
(*funcs)[c.ident().name()] = f;
516+
f = Function(comprehension.ident().name());
517+
(*funcs)[comprehension.ident().name()] = f;
509518
}
510519
// Function is the internal Halide IR type for a pipeline
511520
// stage. Func is the front-end class that wraps it. Here it's
@@ -514,7 +523,7 @@ void translateComprehension(
514523

515524
vector<Var> lhs;
516525
vector<Expr> lhs_as_exprs;
517-
for (lang::Ident id : c.indices()) {
526+
for (lang::Ident id : comprehension.indices()) {
518527
lhs.push_back(Var(id.name()));
519528
lhs_as_exprs.push_back(lhs.back());
520529
}
@@ -523,17 +532,17 @@ void translateComprehension(
523532
// in the future we may consider using Halide Let bindings when they
524533
// are supported later
525534
map<string, Expr> lets;
526-
for (auto wc : c.whereClauses()) {
535+
for (auto wc : comprehension.whereClauses()) {
527536
if (wc->kind() == lang::TK_LET) {
528537
auto let = lang::Let(wc);
529538
lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets);
530539
}
531540
}
532541

533-
Expr rhs = translateExpr(c.rhs(), params, *funcs, lets);
542+
Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets);
534543

535544
std::vector<Expr> all_exprs;
536-
for (auto wc : c.whereClauses()) {
545+
for (auto wc : comprehension.whereClauses()) {
537546
if (wc->kind() == lang::TK_EXISTS) {
538547
all_exprs.push_back(
539548
translateExpr(lang::Exists(wc).exp(), params, *funcs, lets));
@@ -557,7 +566,7 @@ void translateComprehension(
557566
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
558567
// for the reduction and then applies the reduction.
559568
bool should_zero = false;
560-
switch (c.assignment()->kind()) {
569+
switch (comprehension.assignment()->kind()) {
561570
case lang::TK_PLUS_EQ_B:
562571
should_zero = true; // fallthrough
563572
case lang::TK_PLUS_EQ:
@@ -589,11 +598,12 @@ void translateComprehension(
589598
case '=':
590599
break;
591600
default:
592-
throw lang::ErrorReport(c) << "Unimplemented reduction "
593-
<< c.assignment()->range().text() << "\n";
601+
throw lang::ErrorReport(comprehension)
602+
<< "Unimplemented reduction "
603+
<< comprehension.assignment()->range().text() << "\n";
594604
}
595605

596-
if (c.assignment()->kind() != '=') {
606+
if (comprehension.assignment()->kind() != '=') {
597607
reductions->push_back(f);
598608
}
599609

@@ -633,7 +643,7 @@ void translateComprehension(
633643
Scope<Interval> solution;
634644

635645
// Put anything explicitly specified with a 'where' class in the solution
636-
for (auto constraint_ : c.whereClauses()) {
646+
for (auto constraint_ : comprehension.whereClauses()) {
637647
if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT)
638648
continue;
639649
auto constraint = lang::RangeConstraint(constraint_);
@@ -654,7 +664,8 @@ void translateComprehension(
654664

655665
// Infer the rest
656666
all_exprs.push_back(rhs);
657-
forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution);
667+
forwardBoundsInference(
668+
all_exprs, *bounds, comprehension, throwWarnings, &solution);
658669

659670
// TODO: What if subsequent updates have incompatible bounds
660671
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -665,7 +676,7 @@ void translateComprehension(
665676

666677
for (Var v : lhs) {
667678
if (!solution.contains(v.name())) {
668-
throw lang::ErrorReport(c)
679+
throw lang::ErrorReport(comprehension)
669680
<< "Free variable " << v
670681
<< " was not solved in range inference. May not be used right-hand side";
671682
}
@@ -689,7 +700,7 @@ void translateComprehension(
689700
for (size_t i = 0; i < unbound.size(); i++) {
690701
auto v = unbound[unbound.size() - 1 - i];
691702
if (!solution.contains(v->name)) {
692-
throw lang::ErrorReport(c)
703+
throw lang::ErrorReport(comprehension)
693704
<< "Free variable " << v << " is unconstrained. "
694705
<< "Use a 'where' clause to set its range.";
695706
}
@@ -737,6 +748,7 @@ void translateComprehension(
737748
stage.reorder(loop_nest);
738749
}
739750

751+
// translate a semantically checked TC def to Halide components struct
740752
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
741753
map<string, Function> funcs;
742754
HalideComponents components;
@@ -956,6 +968,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
956968
lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings);
957969
}
958970

971+
// NOTE: there is no guarantee here that the tc string has only one def. It
972+
// could have many defs. Only first def will be converted in that case.
959973
HalideComponents
960974
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) {
961975
LOG_IF(INFO, tc::FLAGS_debug_halide) << tc;

tc/core/tc2halide.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ namespace tc2halide {
2727
// of the input and output tensors. We do not explicitly enumerate the
2828
// scalar params.
2929
struct HalideComponents {
30-
lang::TreeRef
31-
def; // post-semantic analaysis tree, used for later error reporting
30+
// post-semantic analaysis tree, used for later error reporting
31+
lang::TreeRef def;
3232
Halide::Internal::Stmt stmt;
3333
std::vector<Halide::ImageParam> inputs;
3434
std::map<std::string, Halide::Internal::Parameter> params;

tc/lang/lexer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ namespace lang {
8787
_(TK_LET, "let", "") \
8888
_(TK_EXISTS, "exists", "exists")
8989

90-
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
90+
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%";
9191

9292
enum TokenKind {
9393
// we use characters to represent themselves so skip all valid characters
@@ -137,7 +137,7 @@ struct SharedParserData {
137137
{TK_AND},
138138
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
139139
{'+', '-'},
140-
{'*', '/'},
140+
{'*', '/', '%'},
141141
};
142142
std::vector<std::vector<int>> unary_ops = {
143143
{'-', '!'},

0 commit comments

Comments
 (0)