Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 8d0781c

Browse files
prigoyalnicolasvasilache
authored andcommitted
improve code readability
proper variable naming and whitelining between functions
1 parent 71539d1 commit 8d0781c

File tree

2 files changed

+56
-23
lines changed

2 files changed

+56
-23
lines changed

tc/core/tc2halide.cc

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -497,18 +497,18 @@ Expr reductionUpdate(Expr e) {
497497
// contain kReductionUpdate intrinsics. These may have to be removed
498498
// in order to be able to apply internal Halide analysis passes on them.
499499
void translateComprehension(
500-
const lang::Comprehension& c,
500+
const lang::Comprehension& comprehension,
501501
const map<string, Parameter>& params,
502502
bool throwWarnings,
503503
map<string, Function>* funcs,
504504
FunctionBounds* bounds) {
505505
Function f;
506-
auto it = funcs->find(c.ident().name());
506+
auto it = funcs->find(comprehension.ident().name());
507507
if (it != funcs->end()) {
508508
f = it->second;
509509
} else {
510-
f = Function(c.ident().name());
511-
(*funcs)[c.ident().name()] = f;
510+
f = Function(comprehension.ident().name());
511+
(*funcs)[comprehension.ident().name()] = f;
512512
}
513513
// Function is the internal Halide IR type for a pipeline
514514
// stage. Func is the front-end class that wraps it. Here it's
@@ -517,7 +517,7 @@ void translateComprehension(
517517

518518
vector<Var> lhs;
519519
vector<Expr> lhs_as_exprs;
520-
for (lang::Ident id : c.indices()) {
520+
for (lang::Ident id : comprehension.indices()) {
521521
lhs.push_back(Var(id.name()));
522522
lhs_as_exprs.push_back(lhs.back());
523523
}
@@ -526,17 +526,17 @@ void translateComprehension(
526526
// in the future we may consider using Halide Let bindings when they
527527
// are supported later
528528
map<string, Expr> lets;
529-
for (auto wc : c.whereClauses()) {
529+
for (auto wc : comprehension.whereClauses()) {
530530
if (wc->kind() == lang::TK_LET) {
531531
auto let = lang::Let(wc);
532532
lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets);
533533
}
534534
}
535535

536-
Expr rhs = translateExpr(c.rhs(), params, *funcs, lets);
536+
Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets);
537537

538538
std::vector<Expr> all_exprs;
539-
for (auto wc : c.whereClauses()) {
539+
for (auto wc : comprehension.whereClauses()) {
540540
if (wc->kind() == lang::TK_EXISTS) {
541541
all_exprs.push_back(
542542
translateExpr(lang::Exists(wc).exp(), params, *funcs, lets));
@@ -560,7 +560,7 @@ void translateComprehension(
560560
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
561561
// for the reduction and then applies the reduction.
562562
bool should_zero = false;
563-
switch (c.assignment()->kind()) {
563+
switch (comprehension.assignment()->kind()) {
564564
case lang::TK_PLUS_EQ_B:
565565
should_zero = true; // fallthrough
566566
case lang::TK_PLUS_EQ:
@@ -592,12 +592,13 @@ void translateComprehension(
592592
case '=':
593593
break;
594594
default:
595-
throw lang::ErrorReport(c) << "Unimplemented reduction "
596-
<< c.assignment()->range().text() << "\n";
595+
throw lang::ErrorReport(comprehension)
596+
<< "Unimplemented reduction "
597+
<< comprehension.assignment()->range().text() << "\n";
597598
}
598599

599600
// Tag reductions as such
600-
if (c.assignment()->kind() != '=') {
601+
if (comprehension.assignment()->kind() != '=') {
601602
rhs = reductionUpdate(rhs);
602603
}
603604

@@ -637,7 +638,7 @@ void translateComprehension(
637638
Scope<Interval> solution;
638639

639640
// Put anything explicitly specified with a 'where' class in the solution
640-
for (auto constraint_ : c.whereClauses()) {
641+
for (auto constraint_ : comprehension.whereClauses()) {
641642
if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT)
642643
continue;
643644
auto constraint = lang::RangeConstraint(constraint_);
@@ -658,7 +659,8 @@ void translateComprehension(
658659

659660
// Infer the rest
660661
all_exprs.push_back(rhs);
661-
forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution);
662+
forwardBoundsInference(
663+
all_exprs, *bounds, comprehension, throwWarnings, &solution);
662664

663665
// TODO: What if subsequent updates have incompatible bounds
664666
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -669,7 +671,7 @@ void translateComprehension(
669671

670672
for (Var v : lhs) {
671673
if (!solution.contains(v.name())) {
672-
throw lang::ErrorReport(c)
674+
throw lang::ErrorReport(comprehension)
673675
<< "Free variable " << v
674676
<< " was not solved in range inference. May not be used right-hand side";
675677
}
@@ -693,7 +695,7 @@ void translateComprehension(
693695
for (size_t i = 0; i < unbound.size(); i++) {
694696
auto v = unbound[unbound.size() - 1 - i];
695697
if (!solution.contains(v->name)) {
696-
throw lang::ErrorReport(c)
698+
throw lang::ErrorReport(comprehension)
697699
<< "Free variable " << v << " is unconstrained. "
698700
<< "Use a 'where' clause to set its range.";
699701
}

tc/lang/sema.h

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ struct Sema {
166166
}
167167
return expr_to_type.at(ref);
168168
}
169+
169170
// associate a type with this expression
170171
TreeRef withType(TreeRef expr, TreeRef type) {
171172
auto inserted = expr_to_type.emplace(expr, type).second;
@@ -179,6 +180,7 @@ struct Sema {
179180
}
180181
return TensorType(typ);
181182
}
183+
182184
TreeRef matchAllTypes(TreeRef list, TreeRef matched_type = nullptr) {
183185
for (auto e : list->trees()) {
184186
if (!matched_type)
@@ -188,23 +190,27 @@ struct Sema {
188190
}
189191
return matched_type;
190192
}
193+
191194
TreeRef expectIntegral(TreeRef e) {
192195
if (TypeInfo(typeOfExpr(e)).code() == TypeInfo::Float) {
193196
throw ErrorReport(e) << " expected integral type but found "
194197
<< kindToString(typeOfExpr(e)->kind());
195198
}
196199
return e;
197200
}
201+
198202
void expectBool(TreeRef anchor, int token) {
199203
if (token != TK_BOOL) {
200204
throw ErrorReport(anchor)
201205
<< "expected boolean but found " << kindToString(token);
202206
}
203207
}
208+
204209
TreeRef expectBool(TreeRef exp) {
205210
expectBool(exp, typeOfExpr(exp)->kind());
206211
return exp;
207212
}
213+
208214
TreeRef lookupVarOrCreateIndex(Ident ident) {
209215
TreeRef type = lookup(ident, false);
210216
if (!type) {
@@ -216,6 +222,7 @@ struct Sema {
216222
}
217223
return type;
218224
}
225+
219226
TreeRef checkExp(TreeRef exp, bool allow_access) {
220227
switch (exp->kind()) {
221228
case TK_APPLY: {
@@ -339,6 +346,7 @@ struct Sema {
339346
throw ErrorReport(exp) << "NYI - semantic checking for " << exp;
340347
}
341348
}
349+
342350
// This is the entry function for semantic analysis. It is called by
343351
// tc2halide to associate type with each node of the tree and to also make
344352
// sure that the tree is sematically correct. For example: a variable
@@ -352,7 +360,7 @@ struct Sema {
352360
//
353361
// Type checking is also done by small amount of code
354362
//
355-
// The method 'withType' can be used to associate the type with a given node
363+
// The method 'withType' is used to associate the type with a given node
356364
//
357365
TreeRef checkFunction(TreeRef func_) {
358366
auto func = Def(func_);
@@ -385,21 +393,27 @@ struct Sema {
385393
Def::create(func.range(), func.name(), params_, returns_, statements_);
386394
return r;
387395
}
396+
388397
TreeRef indexType(TreeRef anchor) {
389-
return c(TK_INT32, anchor->range(), {});
398+
return createCompound(TK_INT32, anchor->range(), {});
390399
}
400+
391401
TreeRef dimType(TreeRef anchor) {
392402
return indexType(anchor);
393403
}
404+
394405
TreeRef floatType(TreeRef anchor) {
395-
return c(TK_FLOAT, anchor->range(), {});
406+
return createCompound(TK_FLOAT, anchor->range(), {});
396407
}
408+
397409
TreeRef boolType(TreeRef anchor) {
398-
return c(TK_BOOL, anchor->range(), {});
410+
return createCompound(TK_BOOL, anchor->range(), {});
399411
}
412+
400413
void checkDim(Ident dim) {
401414
insert(env, dim, dimType(dim), false);
402415
}
416+
403417
TreeRef checkTensorType(TreeRef type) {
404418
auto tt = TensorType(type);
405419
for (const auto& d : tt.dims()) {
@@ -409,18 +423,21 @@ struct Sema {
409423
}
410424
return type;
411425
}
426+
412427
TreeRef checkParam(TreeRef param) {
413428
auto p = Param(param);
414429
TreeRef type_ = checkTensorType(p.type());
415430
insert(env, p.ident(), type_, true);
416431
live_input_names.insert(p.ident().name());
417432
return param;
418433
}
434+
419435
TreeRef checkReturn(TreeRef ret) {
420436
auto r = Param(ret);
421437
TreeRef real_type = lookup(env, r.ident(), true);
422438
return ret;
423439
}
440+
424441
TreeRef checkList(TreeRef list, std::function<TreeRef(TreeRef)> fn) {
425442
TC_ASSERT(list, list->kind() == TK_LIST);
426443
TreeList r;
@@ -429,6 +446,7 @@ struct Sema {
429446
}
430447
return List::create(list->range(), std::move(r));
431448
}
449+
432450
TreeRef checkRangeConstraint(RangeConstraint rc) {
433451
// RCs are checked _before_ the rhs of the TC, so
434452
// it is possible the index is not in the environment yet
@@ -441,11 +459,13 @@ struct Sema {
441459
auto e = expectIntegral(checkExp(rc.end(), false));
442460
return RangeConstraint::create(rc.range(), rc.ident(), s, e);
443461
}
462+
444463
TreeRef checkLet(Let l) {
445464
auto rhs = checkExp(l.rhs(), true);
446465
insert(let_env, l.name(), typeOfExpr(rhs), true);
447466
return Let::create(l.range(), l.name(), rhs);
448467
}
468+
449469
TreeRef checkWhereClause(TreeRef ref) {
450470
if (ref->kind() == TK_LET) {
451471
return checkLet(Let(ref));
@@ -456,6 +476,7 @@ struct Sema {
456476
return checkRangeConstraint(RangeConstraint(ref));
457477
}
458478
}
479+
459480
// Semantic checking for the statements/comprehensions in a TC Def.
460481
TreeRef checkStmt(TreeRef stmt_) {
461482
auto stmt = Comprehension(stmt_);
@@ -467,11 +488,13 @@ struct Sema {
467488
insert(index_env, index, typ, true);
468489
}
469490

470-
// make dimension variables for each dimension of the output tensor
491+
// check that the input is not used for output - inputs are immutable
471492
std::string name = stmt.ident().name();
472493
if (inputParameters.count(name) > 0) {
473494
throw ErrorReport(stmt_) << "TC inputs are immutable";
474495
}
496+
497+
// make dimension variables for each dimension of the output tensor
475498
TreeList output_indices;
476499
int n = stmt.indices().size();
477500
for (int i = 0; i < n; ++i) {
@@ -578,6 +601,7 @@ struct Sema {
578601

579602
return result;
580603
}
604+
581605
static bool isUninitializedReductionOperation(TreeRef assignment) {
582606
switch (assignment->kind()) {
583607
case TK_PLUS_EQ:
@@ -589,6 +613,7 @@ struct Sema {
589613
return false;
590614
}
591615
}
616+
592617
bool isNotInplace(TreeRef assignment) {
593618
switch (assignment->kind()) {
594619
case TK_PLUS_EQ_B:
@@ -600,6 +625,7 @@ struct Sema {
600625
return false;
601626
}
602627
}
628+
603629
std::string dumpEnv() {
604630
std::stringstream ss;
605631
std::vector<std::pair<std::string, TreeRef>> elems(env.begin(), env.end());
@@ -618,6 +644,7 @@ struct Sema {
618644

619645
private:
620646
using Env = std::unordered_map<std::string, TreeRef>;
647+
621648
void
622649
insert(Env& the_env, Ident ident, TreeRef value, bool must_be_undefined) {
623650
std::string name = ident.name();
@@ -630,6 +657,7 @@ struct Sema {
630657
throw ErrorReport(ident) << name << " already defined";
631658
}
632659
}
660+
633661
TreeRef lookup(Ident ident, bool required) {
634662
TreeRef v = lookup(index_env, ident, false);
635663
if (!v)
@@ -638,6 +666,7 @@ struct Sema {
638666
v = lookup(env, ident, required);
639667
return v;
640668
}
669+
641670
TreeRef lookup(Env& the_env, Ident ident, bool required) {
642671
std::string name = ident.name();
643672
auto it = the_env.find(name);
@@ -647,10 +676,12 @@ struct Sema {
647676
}
648677
return it == the_env.end() ? nullptr : it->second;
649678
}
650-
TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
679+
680+
TreeRef createCompound(int kind, const SourceRange& range, TreeList&& trees) {
651681
return Compound::create(kind, range, std::move(trees));
652682
}
653-
TreeRef s(const std::string& s) {
683+
684+
TreeRef createString(const std::string& s) {
654685
return String::create(s);
655686
}
656687

0 commit comments

Comments
 (0)