Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit b7359f4

Browse files
Adding experimental support for "for" loops in the language
This commit adds a `for t in 0:T { ... }` syntax to the TC language. Because it introduces an extension to the grammar and type changes, modifications need to propagate all the way to `tc2halide`. Support for loops is only implemented in the parser and semantic checker. Emitting proper HalideIR and ScheduleTree will be implemented in follow-up commits. The commit should be reviewed in this order: 1. add support to the lexer 2. add the `For` compound type in `tree_views.h` 3. support `parseStmt` in the parser (in particular see how we `parseRangeConstraint` first and reuse its index) 4. perform semantic checks that guarantee only a single nested `For` looFor` is allowed (in parsee how we `checkRangeConstraint` after checking all comprehensions) 5. update `tc2halide` 6. add new tests This commit would crash if trying to emit HalideIR with `for` loops but in order to compile we still needs to update `tc2halide`.
1 parent 28dcc4f commit b7359f4

File tree

10 files changed

+445
-41
lines changed

10 files changed

+445
-41
lines changed

tc/core/tc2halide.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,11 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
765765
}
766766
for (auto c : def.statements()) {
767767
translateComprehension(
768-
c, components.params, throwWarnings, &funcs, &bounds);
768+
lang::Comprehension(c),
769+
components.params,
770+
throwWarnings,
771+
&funcs,
772+
&bounds);
769773
}
770774
vector<Function> outputs;
771775
for (auto p : def.returns()) {

tc/lang/lexer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ namespace lang {
4040
_(TK_BOOL_VALUE, "bool_value", "") \
4141
_(TK_MIN, "min", "min") \
4242
_(TK_MAX, "max", "max") \
43+
_(TK_FOR, "for", "for") \
4344
_(TK_WHERE, "where", "where") \
4445
_(TK_DEF, "def", "def") \
4546
_(TK_ARROW, "arrow", "->") \

tc/lang/parser.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,27 @@ struct Parser {
225225
}
226226
}
227227
TreeRef parseStmt() {
228+
if (L.cur().kind == TK_FOR) {
229+
auto r = L.cur().range;
230+
L.expect(TK_FOR);
231+
// parseRangeConstraint and reuse its ident allows us to write:
232+
// "for t in A:B { ... }"
233+
// instead of "for t where t in A:B { ... }" and
234+
// ~~~~~~~~~~~~~~~~~~~~~~
235+
// WhereClause of 1 of 3 types
236+
// instead of "for t t in A:B { ... }" and
237+
// ~~~~~~~~~~~~~~~~
238+
// RangeConstraints with index duplication
239+
auto rangeConstraint = parseRangeConstraint();
240+
auto index = RangeConstraint(rangeConstraint).ident();
241+
L.expect('{');
242+
TreeList stmts;
243+
while (!L.nextIf('}')) {
244+
stmts.push_back(parseStmt());
245+
}
246+
auto stmts_list = List::create(r, std::move(stmts));
247+
return For::create(r, index, rangeConstraint, stmts_list);
248+
}
228249
auto ident = parseIdent();
229250
TreeRef list = parseOptionalIdentList();
230251
auto assign = parseAssignment();

tc/lang/sema.h

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -493,55 +493,82 @@ struct Sema {
493493

494494
// Semantic checking for the statements/comprehensions in a TC Def.
495495
TreeRef checkStmt(TreeRef stmt_) {
496-
auto stmt = Comprehension(stmt_);
496+
if (stmt_->kind() == TK_COMPREHENSION) {
497+
return checkComprehension(Comprehension(stmt_));
498+
}
499+
return checkFor(For(stmt_));
500+
}
497501

502+
TreeRef checkFor(For f) {
503+
if (lookup(f.index(), false)) {
504+
throw ErrorReport(f) << "For loop index already defined";
505+
}
506+
TreeList stmts;
507+
for (auto s : f.statements()) {
508+
if (s->kind() != TK_COMPREHENSION) {
509+
throw ErrorReport(s) << "Nested \"for\" loops NYI";
510+
}
511+
stmts.push_back(checkComprehension(Comprehension(s)));
512+
}
513+
// Check the range constraint after all statements
514+
// This way we don't need extra state to track indices coming from loops
515+
// that may have already been defined.
516+
checkRangeConstraint(f.rangeConstraint());
517+
return For::create(
518+
f.range(),
519+
f.index(),
520+
f.rangeConstraint(),
521+
List::create(f.range(), std::move(stmts)));
522+
}
523+
524+
TreeRef checkComprehension(Comprehension comp) {
498525
// register index variables (non-reductions)
499-
for (const auto& index : stmt.indices()) {
526+
for (const auto& index : comp.indices()) {
500527
std::string idx = index.name();
501528
auto typ = indexType(index);
502529
insert(index_env, index, typ, true);
503530
}
504531

505532
// check that the input is not used for output - inputs are immutable
506-
std::string name = stmt.ident().name();
533+
std::string name = comp.ident().name();
507534
if (inputParameters.count(name) > 0) {
508-
throw ErrorReport(stmt_) << "TC inputs are immutable";
535+
throw ErrorReport(comp) << "TC inputs are immutable";
509536
}
510537

511538
// make dimension variables for each dimension of the output tensor
512539
TreeList output_indices;
513-
int n = stmt.indices().size();
540+
int n = comp.indices().size();
514541
for (int i = 0; i < n; ++i) {
515542
auto new_var =
516-
Ident::create(stmt.range(), name + "." + std::to_string(i));
543+
Ident::create(comp.range(), name + "." + std::to_string(i));
517544
output_indices.push_back(new_var);
518545
}
519546

520547
// where clauses are checked _before_ the rhs because they
521548
// introduce let bindings that are in scope for the rhs
522-
auto where_clauses_ = stmt.whereClauses().map(
549+
auto where_clauses_ = comp.whereClauses().map(
523550
[&](TreeRef rc) { return checkWhereClause(rc); });
524551

525-
TreeRef rhs_ = checkExp(stmt.rhs(), true);
552+
TreeRef rhs_ = checkExp(comp.rhs(), true);
526553
TreeRef scalar_type = typeOfExpr(rhs_);
527554

528555
// if this statement will be returned and it is annotated in the return list
529556
// with a type (e.g. float(A,B)) then force the tensor to be that type
530557
// and check that the number of dimensions are consistent
531-
auto output_annotation = annotated_output_types.find(stmt.ident().name());
558+
auto output_annotation = annotated_output_types.find(comp.ident().name());
532559
if (output_annotation != annotated_output_types.end()) {
533560
auto tt = TensorType(output_annotation->second);
534561
auto matched_type = match_types(scalar_type, tt.scalarTypeTree());
535562
if (tt.scalarTypeTree()->kind() != matched_type->kind()) {
536-
throw ErrorReport(stmt)
563+
throw ErrorReport(comp)
537564
<< " attempting to assign type "
538565
<< kindToString(scalar_type->kind()) << " to narrower type "
539566
<< kindToString(tt.scalarTypeTree()->kind())
540567
<< " without an explicit cast";
541568
}
542-
if (tt.dims().size() != stmt.indices().size()) {
543-
throw ErrorReport(stmt)
544-
<< " tensor defined with " << stmt.indices().size()
569+
if (tt.dims().size() != comp.indices().size()) {
570+
throw ErrorReport(comp)
571+
<< " tensor defined with " << comp.indices().size()
545572
<< " dimensions but declared as an output with " << tt.dims().size()
546573
<< " dimensions.";
547574
}
@@ -550,33 +577,33 @@ struct Sema {
550577
// After checking rhs and before creating lhs, we check if it is a reduction
551578
// without initialization (i.e., reduction operator without "!" suffix, and
552579
// lhs not defined previously).
553-
if (isUninitializedReductionOperation(stmt.assignment()) &&
554-
nullptr == lookup(stmt.ident(), false)) {
555-
ErrorReport err(stmt);
556-
std::string tk = kindToToken(stmt.assignment()->kind());
557-
err << "Reduction without initialization. If " << stmt.ident().name()
580+
if (isUninitializedReductionOperation(comp.assignment()) &&
581+
nullptr == lookup(comp.ident(), false)) {
582+
ErrorReport err(comp);
583+
std::string tk = kindToToken(comp.assignment()->kind());
584+
err << "Reduction without initialization. If " << comp.ident().name()
558585
<< " is not pre-initialized before calling the TC function,"
559586
<< " consider using the !-suffixed reduction operator " << tk
560587
<< "! instead of " << tk;
561588
warn(err);
562589
}
563590

564591
auto type = TensorType::create(
565-
stmt.range(),
592+
comp.range(),
566593
scalar_type,
567-
List::create(stmt.range(), std::move(output_indices)));
568-
insert(env, stmt.ident(), type, false);
594+
List::create(comp.range(), std::move(output_indices)));
595+
insert(env, comp.ident(), type, false);
569596

570597
// if we redefined an input, it is no longer valid for range expressions
571-
live_input_names.erase(stmt.ident().name());
598+
live_input_names.erase(comp.ident().name());
572599

573-
auto equivalent_statement_ = stmt.equivalent().map([&](Equivalent eq) {
600+
auto equivalent_statement_ = comp.equivalent().map([&](Equivalent eq) {
574601
auto indices_ = eq.accesses().map(
575602
[&](TreeRef index) { return checkExp(index, true); });
576603
return Equivalent::create(eq.range(), eq.name(), indices_);
577604
});
578605

579-
TreeRef assignment = stmt.assignment();
606+
TreeRef assignment = comp.assignment();
580607
// For semantic consistency we allow overwriting reductions like +=!
581608
// to be used in the language when there are no actual reduction dimensions.
582609
// Later compile stages assume that there is at least one reduction
@@ -586,26 +613,26 @@ struct Sema {
586613
assignment = Compound::create('=', assignment->range(), {});
587614
}
588615

589-
if (reduction_variables.size() > 0 && stmt.assignment()->kind() == '=') {
590-
throw ErrorReport(stmt) << "this statement includes reduction variable '"
616+
if (reduction_variables.size() > 0 && comp.assignment()->kind() == '=') {
617+
throw ErrorReport(comp) << "this statement includes reduction variable '"
591618
<< Ident(reduction_variables.back()).name()
592619
<< "' but does not specify a reduction.";
593620
}
594621
TreeRef reduction_variable_list =
595-
List::create(stmt.ident().range(), std::move(reduction_variables));
622+
List::create(comp.ident().range(), std::move(reduction_variables));
596623
TreeRef result = Comprehension::create(
597-
stmt.range(),
598-
stmt.ident(),
599-
stmt.indices(),
600-
stmt.assignment(),
624+
comp.range(),
625+
comp.ident(),
626+
comp.indices(),
627+
comp.assignment(),
601628
rhs_,
602629
where_clauses_,
603630
equivalent_statement_,
604631
reduction_variable_list);
605632

606-
if (nonTemporaries.count(stmt.ident().name()) == 0) {
607-
throw ErrorReport(stmt)
608-
<< stmt.ident().name()
633+
if (nonTemporaries.count(comp.ident().name()) == 0) {
634+
throw ErrorReport(comp)
635+
<< comp.ident().name()
609636
<< " is not listed as an input or output to this function. Temporaries tensors are not yet implemented";
610637
}
611638

tc/lang/tc_format.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace lang {
2121
namespace {
2222

2323
void showExpr(std::ostream& s, const TreeRef& expr);
24+
void showStmt(std::ostream& s, const TreeRef& stmt);
2425

2526
template <typename T>
2627
void show(std::ostream& s, T x) {
@@ -59,6 +60,16 @@ std::ostream& operator<<(std::ostream& s, const Param& p) {
5960
return s << p.ident();
6061
}
6162

63+
std::ostream& operator<<(std::ostream& s, const For& f) {
64+
s << "for " << f.index() << " in " << f.range().start() << ":"
65+
<< f.range().end() << " {";
66+
for (const TreeRef& stmt : f.statements()) {
67+
showStmt(s, stmt);
68+
}
69+
s << "}";
70+
return s;
71+
}
72+
6273
std::ostream& operator<<(std::ostream& s, const Comprehension& comp) {
6374
s << comp.ident() << "(" << comp.indices() << ") "
6475
<< kindToToken(comp.assignment()->kind()) << " ";
@@ -71,6 +82,21 @@ std::ostream& operator<<(std::ostream& s, const Comprehension& comp) {
7182
return s;
7283
}
7384

85+
void showStmt(std::ostream& s, const TreeRef& stmt) {
86+
switch (stmt->kind()) {
87+
case TK_FOR:
88+
s << " " << For(stmt) << "\n";
89+
break;
90+
case TK_COMPREHENSION:
91+
s << " " << Comprehension(stmt) << "\n";
92+
break;
93+
default:
94+
std::stringstream ss;
95+
ss << "Incorrect statement kind: " << stmt->kind();
96+
throw std::runtime_error(ss.str());
97+
}
98+
}
99+
74100
void showExpr(std::ostream& s, const TreeRef& expr) {
75101
switch (expr->kind()) {
76102
case TK_IDENT: {
@@ -174,8 +200,8 @@ void tcFormat(std::ostream& s, TreeRef _def) {
174200
Def def{_def};
175201
s << "def " << def.name() << "(" << def.params() << ")"
176202
<< " -> (" << def.returns() << ") {\n";
177-
for (const Comprehension& c : def.statements()) {
178-
s << " " << c << "\n";
203+
for (const TreeRef& stmt : def.statements()) {
204+
showStmt(s, stmt);
179205
}
180206
s << "}";
181207
}

tc/lang/test_expected/for1.expected

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
(def
2+
(ident fun)
3+
(list
4+
(param
5+
(ident X)
6+
(tensor_type
7+
(float)
8+
(list (ident M) (ident N))))
9+
(param
10+
(ident Meta)
11+
(tensor_type
12+
(float)
13+
(list (ident T)))))
14+
(list
15+
(param (ident R1) (inferred))
16+
(param (ident R2) (inferred)))
17+
(list
18+
(for
19+
(ident t)
20+
(range_constraint
21+
(ident t)
22+
(const 0 (int32))
23+
(ident T))
24+
(list
25+
(comprehension
26+
(ident R1)
27+
(list
28+
(ident t)
29+
(ident m)
30+
(ident n))
31+
(=)
32+
(?
33+
(eq
34+
(ident t)
35+
(const 0 (int32)))
36+
(access
37+
(ident X)
38+
(list
39+
(ident m)
40+
(ident n)))
41+
(const 0 (float)))
42+
(list)
43+
(option)
44+
(list))
45+
(comprehension
46+
(ident R2)
47+
(list
48+
(ident t)
49+
(ident m)
50+
(ident n))
51+
(=)
52+
(access
53+
(ident R1)
54+
(list
55+
(ident t)
56+
(ident m)
57+
(ident n)))
58+
(list)
59+
(option)
60+
(list))))))
61+
M: (int32)
62+
Meta: (tensor_type (float) (list (ident T)))
63+
N: (int32)
64+
R1: (tensor_type
65+
(float)
66+
(list
67+
(ident R1.0)
68+
(ident R1.1)
69+
(ident R1.2)))
70+
R2: (tensor_type
71+
(float)
72+
(list
73+
(ident R2.0)
74+
(ident R2.1)
75+
(ident R2.2)))
76+
T: (int32)
77+
X: (tensor_type
78+
(float)
79+
(list (ident M) (ident N)))

0 commit comments

Comments
 (0)