Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 702d7f5

Browse files
Merge pull request #348 from facebookresearch/add-operators
Add % operator and propagate it from parser -> halide -> isl
2 parents 0fb3b9b + b36f466 commit 702d7f5

File tree

9 files changed

+167
-59
lines changed

9 files changed

+167
-59
lines changed

tc/core/halide2isl.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ using namespace tc::polyhedral::detail;
3434

3535
SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
3636
// const Stmt& s) {
37-
// Collect and categorize all the Variable symbols
37+
// Collect and categorize all the Halide Variable symbols as reduction
38+
// or index variables
3839
class BuildSymbolTable : public IRVisitor {
3940
using IRVisitor::visit;
4041
std::set<std::string> included;
@@ -59,7 +60,7 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
5960

6061
components.stmt.accept(&builder);
6162
// Get params from components.params which contain everything declared in
62-
// tcdef. However, the 0-D tensors are registered as both params and inputs,
63+
// TC Def. However, the 0-D tensors are registered as both params and inputs,
6364
// filter those out.
6465
for (auto kvp : components.params) {
6566
bool skip = false;
@@ -202,7 +203,7 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
202203
std::vector<isl::aff> result;
203204
// We cannot span multiple constraints if a modulo operation is involved.
204205
// x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
205-
auto lhs = makeIslAffBoundsFromExpr(space, e, false, false);
206+
auto lhs = makeIslAffBoundsFromExpr(space, op->a, false, false);
206207
CHECK_EQ(lhs.size(), 1u);
207208
if (const int64_t* b = as_const_int(op->b)) {
208209
return {lhs[0].mod(isl::val(space.get_ctx(), *b))};

tc/core/tc2halide.cc

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
6262
}
6363
}
6464

65+
// Translate the TC def input params to corresponding Halide components.
66+
// params, inputs will be populated here.
6567
void translateParam(
6668
const lang::Param& p,
6769
map<string, Parameter>* params,
6870
vector<ImageParam>* inputs) {
71+
// Check if the param has already been converted to halide components.
6972
if (params->find(p.ident().name()) != params->end()) {
7073
return;
71-
} else {
72-
lang::TensorType type = p.tensorType();
73-
int dimensions = (int)type.dims().size();
74-
ImageParam imageParam(
75-
translateScalarType(type.scalarType()), dimensions, p.ident().name());
76-
inputs->push_back(imageParam);
77-
vector<Expr> dims;
78-
for (auto d_ : type.dims()) {
79-
if (d_->kind() == lang::TK_IDENT) {
80-
auto d = lang::Ident(d_);
81-
auto it = params->find(d.name());
82-
Parameter p;
83-
if (it != params->end()) {
84-
p = it->second;
85-
} else {
86-
p = Parameter(Int(32), false, 0, d.name(), true);
87-
(*params)[d.name()] = p;
88-
}
89-
dims.push_back(Variable::make(Int(32), p.name(), p));
74+
}
75+
lang::TensorType type = p.tensorType();
76+
int dimensions = (int)type.dims().size();
77+
ImageParam imageParam(
78+
translateScalarType(type.scalarType()), dimensions, p.ident().name());
79+
inputs->push_back(imageParam);
80+
vector<Expr> dims;
81+
for (auto d_ : type.dims()) {
82+
if (d_->kind() == lang::TK_IDENT) {
83+
auto d = lang::Ident(d_);
84+
auto it = params->find(d.name());
85+
Parameter p;
86+
if (it != params->end()) {
87+
p = it->second;
9088
} else {
91-
CHECK(d_->kind() == lang::TK_CONST);
92-
int32_t value = lang::Const(d_).value();
93-
dims.push_back(Expr(value));
89+
p = Parameter(Int(32), false, 0, d.name(), true);
90+
(*params)[d.name()] = p;
9491
}
92+
dims.push_back(Variable::make(Int(32), p.name(), p));
93+
} else {
94+
CHECK(d_->kind() == lang::TK_CONST);
95+
int32_t value = lang::Const(d_).value();
96+
dims.push_back(Expr(value));
9597
}
98+
}
9699

97-
for (int i = 0; i < imageParam.dimensions(); i++) {
98-
imageParam.dim(i).set_bounds(0, dims[i]);
99-
}
100-
(*params)[imageParam.name()] = imageParam.parameter();
100+
for (int i = 0; i < imageParam.dimensions(); i++) {
101+
imageParam.dim(i).set_bounds(0, dims[i]);
101102
}
103+
(*params)[imageParam.name()] = imageParam.parameter();
102104
}
103105

104106
void translateOutput(
@@ -156,6 +158,8 @@ Expr translateExpr(
156158
return t(0) * t(1);
157159
case '/':
158160
return t(0) / t(1);
161+
case '%':
162+
return t(0) % t(1);
159163
case lang::TK_MIN:
160164
return min(t(0), t(1));
161165
case lang::TK_MAX:
@@ -488,22 +492,25 @@ Expr reductionUpdate(Expr e) {
488492
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
489493
}
490494

495+
// Translate a single TC comprehension/statement to Halide components: funcs,
496+
// bounds, reductions.
497+
//
491498
// Note that the function definitions created by translateComprehension may
492499
// contain kReductionUpdate intrinsics. These may have to be removed
493500
// in order to be able to apply internal Halide analysis passes on them.
494501
void translateComprehension(
495-
const lang::Comprehension& c,
502+
const lang::Comprehension& comprehension,
496503
const map<string, Parameter>& params,
497504
bool throwWarnings,
498505
map<string, Function>* funcs,
499506
FunctionBounds* bounds) {
500507
Function f;
501-
auto it = funcs->find(c.ident().name());
508+
auto it = funcs->find(comprehension.ident().name());
502509
if (it != funcs->end()) {
503510
f = it->second;
504511
} else {
505-
f = Function(c.ident().name());
506-
(*funcs)[c.ident().name()] = f;
512+
f = Function(comprehension.ident().name());
513+
(*funcs)[comprehension.ident().name()] = f;
507514
}
508515
// Function is the internal Halide IR type for a pipeline
509516
// stage. Func is the front-end class that wraps it. Here it's
@@ -512,7 +519,7 @@ void translateComprehension(
512519

513520
vector<Var> lhs;
514521
vector<Expr> lhs_as_exprs;
515-
for (lang::Ident id : c.indices()) {
522+
for (lang::Ident id : comprehension.indices()) {
516523
lhs.push_back(Var(id.name()));
517524
lhs_as_exprs.push_back(lhs.back());
518525
}
@@ -521,17 +528,17 @@ void translateComprehension(
521528
// in the future we may consider using Halide Let bindings when they
522529
// are supported later
523530
map<string, Expr> lets;
524-
for (auto wc : c.whereClauses()) {
531+
for (auto wc : comprehension.whereClauses()) {
525532
if (wc->kind() == lang::TK_LET) {
526533
auto let = lang::Let(wc);
527534
lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets);
528535
}
529536
}
530537

531-
Expr rhs = translateExpr(c.rhs(), params, *funcs, lets);
538+
Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets);
532539

533540
std::vector<Expr> all_exprs;
534-
for (auto wc : c.whereClauses()) {
541+
for (auto wc : comprehension.whereClauses()) {
535542
if (wc->kind() == lang::TK_EXISTS) {
536543
all_exprs.push_back(
537544
translateExpr(lang::Exists(wc).exp(), params, *funcs, lets));
@@ -555,7 +562,7 @@ void translateComprehension(
555562
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
556563
// for the reduction and then applies the reduction.
557564
bool should_zero = false;
558-
switch (c.assignment()->kind()) {
565+
switch (comprehension.assignment()->kind()) {
559566
case lang::TK_PLUS_EQ_B:
560567
should_zero = true; // fallthrough
561568
case lang::TK_PLUS_EQ:
@@ -587,12 +594,13 @@ void translateComprehension(
587594
case '=':
588595
break;
589596
default:
590-
throw lang::ErrorReport(c) << "Unimplemented reduction "
591-
<< c.assignment()->range().text() << "\n";
597+
throw lang::ErrorReport(comprehension)
598+
<< "Unimplemented reduction "
599+
<< comprehension.assignment()->range().text() << "\n";
592600
}
593601

594602
// Tag reductions as such
595-
if (c.assignment()->kind() != '=') {
603+
if (comprehension.assignment()->kind() != '=') {
596604
rhs = reductionUpdate(rhs);
597605
}
598606

@@ -632,7 +640,7 @@ void translateComprehension(
632640
Scope<Interval> solution;
633641

634642
// Put anything explicitly specified with a 'where' class in the solution
635-
for (auto constraint_ : c.whereClauses()) {
643+
for (auto constraint_ : comprehension.whereClauses()) {
636644
if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT)
637645
continue;
638646
auto constraint = lang::RangeConstraint(constraint_);
@@ -653,7 +661,8 @@ void translateComprehension(
653661

654662
// Infer the rest
655663
all_exprs.push_back(rhs);
656-
forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution);
664+
forwardBoundsInference(
665+
all_exprs, *bounds, comprehension, throwWarnings, &solution);
657666

658667
// TODO: What if subsequent updates have incompatible bounds
659668
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -664,7 +673,7 @@ void translateComprehension(
664673

665674
for (Var v : lhs) {
666675
if (!solution.contains(v.name())) {
667-
throw lang::ErrorReport(c)
676+
throw lang::ErrorReport(comprehension)
668677
<< "Free variable " << v
669678
<< " was not solved in range inference. May not be used right-hand side";
670679
}
@@ -688,7 +697,7 @@ void translateComprehension(
688697
for (size_t i = 0; i < unbound.size(); i++) {
689698
auto v = unbound[unbound.size() - 1 - i];
690699
if (!solution.contains(v->name)) {
691-
throw lang::ErrorReport(c)
700+
throw lang::ErrorReport(comprehension)
692701
<< "Free variable " << v << " is unconstrained. "
693702
<< "Use a 'where' clause to set its range.";
694703
}
@@ -736,6 +745,7 @@ void translateComprehension(
736745
stage.reorder(loop_nest);
737746
}
738747

748+
// Translate a semantically checked TC def to HalideComponents struct.
739749
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
740750
map<string, Function> funcs;
741751
HalideComponents components;
@@ -895,6 +905,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
895905
lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings);
896906
}
897907

908+
// NOTE: there is no guarantee here that the tc string has only one def. It
909+
// could have many defs. Only first def will be converted in that case.
898910
HalideComponents
899911
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) {
900912
LOG_IF(INFO, tc::FLAGS_debug_halide) << tc;

tc/core/tc2halide.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ namespace tc2halide {
2727
// of the input and output tensors. We do not explicitly enumerate the
2828
// scalar params.
2929
struct HalideComponents {
30-
lang::TreeRef
31-
def; // post-semantic analaysis tree, used for later error reporting
30+
// post-semantic analaysis tree, used for later error reporting
31+
lang::TreeRef def;
3232
Halide::Internal::Stmt stmt;
3333
std::vector<Halide::ImageParam> inputs;
3434
std::map<std::string, Halide::Internal::Parameter> params;

tc/lang/lexer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ namespace lang {
8787
_(TK_LET, "let", "") \
8888
_(TK_EXISTS, "exists", "exists")
8989

90-
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
90+
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%";
9191

9292
enum TokenKind {
9393
// we use characters to represent themselves so skip all valid characters
@@ -137,7 +137,7 @@ struct SharedParserData {
137137
{TK_AND},
138138
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
139139
{'+', '-'},
140-
{'*', '/'},
140+
{'*', '/', '%'},
141141
};
142142
std::vector<std::vector<int>> unary_ops = {
143143
{'-', '!'},

0 commit comments

Comments
 (0)