Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit df50976

Browse files
committed
Implement where clause variables and exists
This adds two things to where clauses: Exists clauses are looked at in range inference but are not actually evaluated. This allows you make things that are similarly sized to another tensor (and will work well with the pad() operator once we finish it): a(i) = 0 where b(i) exists Let clauses allow you to declare temporary variables whose scope is only for the local comprehensions: a(i) = foo + foo*foo where foo = b(i) This commit also includes a simplification of TreeView objects that I back ported from pytorch that makes adding new tree views less verbose.
1 parent ff2986f commit df50976

File tree

8 files changed

+248
-170
lines changed

8 files changed

+248
-170
lines changed

include/tc/lang/lexer.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ namespace lang {
8383
_(TK_EQ, "eq", "==") \
8484
_(TK_NE, "neq", "!=") \
8585
_(TK_AND, "and", "&&") \
86-
_(TK_OR, "or", "||")
86+
_(TK_OR, "or", "||") \
87+
_(TK_LET, "let", "") \
88+
_(TK_EXISTS, "exists", "exists")
8789

8890
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
8991

@@ -380,9 +382,20 @@ struct Lexer {
380382
next();
381383
return true;
382384
}
385+
Token lookahead() {
386+
if (!lookahead_) {
387+
lookahead_.reset(new Token(lex()));
388+
}
389+
return *lookahead_;
390+
}
383391
Token next() {
384392
auto r = cur_;
385-
cur_ = lex();
393+
if (lookahead_) {
394+
cur_ = *lookahead_;
395+
lookahead_.reset();
396+
} else {
397+
cur_ = lex();
398+
}
386399
return r;
387400
}
388401
void reportError(const std::string& what, const Token& t);
@@ -416,6 +429,7 @@ struct Lexer {
416429
}
417430
size_t pos;
418431
Token cur_;
432+
std::unique_ptr<Token> lookahead_;
419433
SharedParserData& shared;
420434
};
421435
} // namespace lang

include/tc/lang/parser.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,24 @@ struct Parser {
149149
auto r = parseExp();
150150
return RangeConstraint::create(id->range(), id, l, r);
151151
}
152+
TreeRef parseLetBinding() {
153+
auto ident = parseIdent();
154+
L.expect('=');
155+
auto exp = parseExp();
156+
return Let::create(ident->range(), ident, exp);
157+
}
158+
TreeRef parseWhereClause() {
159+
auto lookahead = L.lookahead();
160+
if (lookahead.kind == '=') {
161+
return parseLetBinding();
162+
} else if (lookahead.kind == TK_IN) {
163+
return parseRangeConstraint();
164+
} else {
165+
auto exp = parseExp();
166+
L.expect(TK_EXISTS);
167+
return Exists::create(exp->range(), {exp});
168+
}
169+
}
152170
TreeRef parseParam() {
153171
if (L.cur().kind == TK_IDENT) {
154172
auto ident = parseIdent();
@@ -159,10 +177,9 @@ struct Parser {
159177
auto ident = parseIdent();
160178
return Param::create(typ->range(), ident, typ);
161179
}
162-
TreeRef parseRangeConstraints() {
180+
TreeRef parseWhereClauses() {
163181
if (L.nextIf(TK_WHERE)) {
164-
return parseNonEmptyList(
165-
',', [&](int i) { return parseRangeConstraint(); });
182+
return parseNonEmptyList(',', [&](int i) { return parseWhereClause(); });
166183
}
167184
return List::create(L.cur().range, {});
168185
}
@@ -200,7 +217,7 @@ struct Parser {
200217
auto assign = parseAssignment();
201218
auto rhs = parseExp();
202219
TreeRef equivalent_statement = parseEquivalent();
203-
TreeRef range_statements = parseRangeConstraints();
220+
TreeRef range_statements = parseWhereClauses();
204221
TreeRef empty_reduction_variables = c(TK_LIST, ident->range(), {});
205222
return Comprehension::create(
206223
ident->range(),

include/tc/lang/sema.h

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,17 @@ struct Sema {
204204
expectBool(exp, typeOfExpr(exp)->kind());
205205
return exp;
206206
}
207+
TreeRef lookupVarOrCreateIndex(Ident ident) {
208+
TreeRef type = lookup(ident, false);
209+
if (!type) {
210+
// variable exp is not defined, so a reduction variable is created
211+
// a reduction variable index i
212+
type = indexType(ident);
213+
insert(index_env, ident, type, true);
214+
reduction_variables.push_back(ident);
215+
}
216+
return type;
217+
}
207218
TreeRef checkExp(TreeRef exp, bool allow_access) {
208219
switch (exp->kind()) {
209220
case TK_APPLY: {
@@ -250,14 +261,7 @@ struct Sema {
250261
} break;
251262
case TK_IDENT: {
252263
auto ident = Ident(exp);
253-
TreeRef type = lookup(ident, false);
254-
if (!type) {
255-
// variable exp is not defined, so a reduction variable is created
256-
// a reduction variable index i
257-
type = indexType(exp);
258-
insert(index_env, ident, type, true);
259-
reduction_variables.push_back(exp);
260-
}
264+
auto type = lookupVarOrCreateIndex(ident);
261265
if (type->kind() == TK_TENSOR_TYPE) {
262266
auto tt = TensorType(type);
263267
if (tt.dims().size() != 0) {
@@ -397,6 +401,33 @@ struct Sema {
397401
}
398402
return List::create(list->range(), std::move(r));
399403
}
404+
TreeRef checkRangeConstraint(RangeConstraint rc) {
405+
// RCs are checked _before_ the rhs of the TC, so
406+
// it is possible the index is not in the environment yet
407+
// calling lookupOrCreate ensures it exists
408+
lookupVarOrCreateIndex(rc.ident());
409+
// calling looking directly in the index_env ensures that
410+
// we are actually constraining an index and not some other variable
411+
lookup(index_env, rc.ident(), true);
412+
auto s = expectIntegral(checkExp(rc.start(), false));
413+
auto e = expectIntegral(checkExp(rc.end(), false));
414+
return RangeConstraint::create(rc.range(), rc.ident(), s, e);
415+
}
416+
TreeRef checkLet(Let l) {
417+
auto rhs = checkExp(l.rhs(), true);
418+
insert(let_env, l.name(), typeOfExpr(rhs), true);
419+
return Let::create(l.range(), l.name(), rhs);
420+
}
421+
TreeRef checkWhereClause(TreeRef ref) {
422+
if (ref->kind() == TK_LET) {
423+
return checkLet(Let(ref));
424+
} else if (ref->kind() == TK_EXISTS) {
425+
auto exp = checkExp(Exists(ref).exp(), true);
426+
return Exists::create(ref->range(), exp);
427+
} else {
428+
return checkRangeConstraint(RangeConstraint(ref));
429+
}
430+
}
400431
TreeRef checkStmt(TreeRef stmt_) {
401432
auto stmt = Comprehension(stmt_);
402433

@@ -417,6 +448,11 @@ struct Sema {
417448
output_indices.push_back(new_var);
418449
}
419450

451+
// where clauses are checked _before_ the rhs because they
452+
// introduce let bindings that are in scope for the rhs
453+
auto where_clauses_ = stmt.whereClauses().map(
454+
[&](const TreeRef& rc) { return checkWhereClause(rc); });
455+
420456
TreeRef rhs_ = checkExp(stmt.rhs(), true);
421457
TreeRef scalar_type = typeOfExpr(rhs_);
422458

@@ -451,14 +487,6 @@ struct Sema {
451487
// if we redefined an input, it is no longer valid for range expressions
452488
live_input_names.erase(stmt.ident().name());
453489

454-
auto range_constraints =
455-
stmt.rangeConstraints().map([&](const RangeConstraint& rc) {
456-
lookup(index_env, rc.ident(), true);
457-
auto s = expectIntegral(checkExp(rc.start(), false));
458-
auto e = expectIntegral(checkExp(rc.end(), false));
459-
return RangeConstraint::create(rc.range(), rc.ident(), s, e);
460-
});
461-
462490
auto equivalent_statement_ =
463491
stmt.equivalent().map([&](const Equivalent& eq) {
464492
auto indices_ = eq.accesses().map(
@@ -489,10 +517,13 @@ struct Sema {
489517
stmt.indices(),
490518
stmt.assignment(),
491519
rhs_,
492-
range_constraints,
520+
where_clauses_,
493521
equivalent_statement_,
494522
reduction_variable_list);
523+
// clear the per-statement environments to get ready for the next statement
495524
index_env.clear();
525+
let_env.clear();
526+
496527
return result;
497528
}
498529
bool isNotInplace(const TreeRef& assignment) {
@@ -538,6 +569,8 @@ struct Sema {
538569
}
539570
TreeRef lookup(const Ident& ident, bool required) {
540571
TreeRef v = lookup(index_env, ident, false);
572+
if (!v)
573+
v = lookup(let_env, ident, false);
541574
if (!v)
542575
v = lookup(env, ident, required);
543576
return v;
@@ -560,6 +593,7 @@ struct Sema {
560593

561594
std::vector<TreeRef> reduction_variables; // per-statement
562595
Env index_env; // per-statement
596+
Env let_env; // per-statement, used for where i = <exp>
563597

564598
Env env; // name -> type
565599
Env annotated_output_types; // name -> type, for all annotated returns types

include/tc/lang/tree.h

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,34 +75,19 @@ struct Tree : std::enable_shared_from_this<Tree> {
7575
virtual TreeRef map(std::function<TreeRef(TreeRef)> fn) {
7676
return shared_from_this();
7777
}
78-
template <typename... Args>
79-
void match(int k, Args&... args) {
80-
matchD(k, "unknown", 0, args...);
78+
void expect(int k) {
79+
expect(k, trees().size());
8180
}
82-
template <typename... Args>
83-
void matchD(int k, const char* filename, int lineno, Args&... args) {
84-
if (kind() != k) {
81+
void expect(int k, int numsubtrees) {
82+
if (kind() != k || trees().size() != numsubtrees) {
8583
std::stringstream ss;
86-
ss << filename << ":" << lineno << ": expecting kind '" << kindToString(k)
87-
<< "' but found '" << kind() << "'\n";
84+
ss << "expected kind '" << kindToString(k) << "' with " << numsubtrees
85+
<< " subtrees but found '" << kindToString(kind()) << "' with "
86+
<< trees().size() << " subtrees.\n";
8887
range().highlight(ss);
8988
throw std::runtime_error(ss.str());
9089
}
91-
std::initializer_list<TreeRef*> vars = {&args...};
92-
if (vars.size() > trees().size()) {
93-
std::stringstream ss;
94-
ss << filename << ":" << lineno << ": trying to match " << vars.size()
95-
<< " variables against " << trees().size() << " values in list.\n";
96-
range().highlight(ss);
97-
throw std::runtime_error(ss.str());
98-
}
99-
size_t i = 0;
100-
for (TreeRef* v : vars) {
101-
*v = trees()[i++];
102-
}
10390
}
104-
105-
private:
10691
int kind_;
10792
};
10893

0 commit comments

Comments
 (0)