Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 3fd7669

Browse files
authored
Merge pull request #28 from facebookresearch/pr/let_bindings
Implement where clause variables and exists
2 parents 96b4e4c + 9a03d57 commit 3fd7669

File tree

11 files changed

+400
-195
lines changed

11 files changed

+400
-195
lines changed

docs/source/inference.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,20 @@ Strided indexing with dynamic stride
271271
The value of :code:`S(0)` is not fixed until runtime, so we can't resolve the
272272
size of :code:`A` or the range of the loop. This case throws a compile-time
273273
error. A :code:`where` clause that defines the range of :code:`i` is required.
274+
275+
Constant fill using an exists clause
276+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
277+
278+
.. code::
279+
280+
def constant_fill(float(N) A, float c) -> B {
281+
B(i) = c where exists A(i)
282+
}
283+
284+
An :code:`exists` clause allows you to add additional expressions to the range
285+
inference process without having the expressions affect the actual computation.
286+
In this example, it allows you to say that :code:`B(i)` should have the same size as
287+
:code:`A(i)`, but be filled with a constant value :code:`c`. That is, you should infer the
288+
range of :code:`B(i)` to exist at all the places where :code:`A(i)` exists.
289+
It is equivalent to writing the expression :code:`true ? c : A(i)`, but with
290+
clearer intentions.

include/tc/core/libraries.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ constexpr auto boundsAsTemplate = R"C(
140140
template<typename T> inline __device__ T floord(T n, T d) {
141141
return n < 0 ? - (-n + d - 1)/d : n / d;
142142
}
143+
#define if_then_else(cond,a,b) (cond) ? (a) : (b);
143144
)C";
144145
} // namespace cpp
145146

include/tc/lang/lexer.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,17 @@ namespace lang {
7777
_(TK_UINT64, "uint64", "uint64") \
7878
_(TK_BOOL, "bool", "bool") \
7979
_(TK_CAST, "cast", "") \
80-
_(TK_IN, "in", "in")
80+
_(TK_IN, "in", "in") \
81+
_(TK_GE, "ge", ">=") \
82+
_(TK_LE, "le", "<=") \
83+
_(TK_EQ, "eq", "==") \
84+
_(TK_NE, "neq", "!=") \
85+
_(TK_AND, "and", "&&") \
86+
_(TK_OR, "or", "||") \
87+
_(TK_LET, "let", "") \
88+
_(TK_EXISTS, "exists", "exists")
8189

82-
static const char* valid_single_char_tokens = "+-*/()[]?:,={}>";
90+
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
8391

8492
enum TokenKind {
8593
// we use characters to represent themselves so skip all valid characters
@@ -121,11 +129,14 @@ struct SharedParserData {
121129
// listed in increasing order of precedence
122130
std::vector<std::vector<int>> binary_ops = {
123131
{'?'},
132+
{TK_OR},
133+
{TK_AND},
134+
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
124135
{'+', '-'},
125136
{'*', '/'},
126137
};
127138
std::vector<std::vector<int>> unary_ops = {
128-
{'-'},
139+
{'-', '!'},
129140
};
130141

131142
std::stringstream ss;
@@ -371,9 +382,20 @@ struct Lexer {
371382
next();
372383
return true;
373384
}
385+
Token lookahead() {
386+
if (!lookahead_) {
387+
lookahead_.reset(new Token(lex()));
388+
}
389+
return *lookahead_;
390+
}
374391
Token next() {
375392
auto r = cur_;
376-
cur_ = lex();
393+
if (lookahead_) {
394+
cur_ = *lookahead_;
395+
lookahead_.reset();
396+
} else {
397+
cur_ = lex();
398+
}
377399
return r;
378400
}
379401
void reportError(const std::string& what, const Token& t);
@@ -407,6 +429,7 @@ struct Lexer {
407429
}
408430
size_t pos;
409431
Token cur_;
432+
std::unique_ptr<Token> lookahead_;
410433
SharedParserData& shared;
411434
};
412435
} // namespace lang

include/tc/lang/parser.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,24 @@ struct Parser {
149149
auto r = parseExp();
150150
return RangeConstraint::create(id->range(), id, l, r);
151151
}
152+
TreeRef parseLetBinding() {
153+
auto ident = parseIdent();
154+
L.expect('=');
155+
auto exp = parseExp();
156+
return Let::create(ident->range(), ident, exp);
157+
}
158+
TreeRef parseWhereClause() {
159+
auto lookahead = L.lookahead();
160+
if (lookahead.kind == '=') {
161+
return parseLetBinding();
162+
} else if (lookahead.kind == TK_IN) {
163+
return parseRangeConstraint();
164+
} else {
165+
L.expect(TK_EXISTS);
166+
auto exp = parseExp();
167+
return Exists::create(exp->range(), {exp});
168+
}
169+
}
152170
TreeRef parseParam() {
153171
if (L.cur().kind == TK_IDENT) {
154172
auto ident = parseIdent();
@@ -159,10 +177,9 @@ struct Parser {
159177
auto ident = parseIdent();
160178
return Param::create(typ->range(), ident, typ);
161179
}
162-
TreeRef parseRangeConstraints() {
180+
TreeRef parseWhereClauses() {
163181
if (L.nextIf(TK_WHERE)) {
164-
return parseNonEmptyList(
165-
',', [&](int i) { return parseRangeConstraint(); });
182+
return parseNonEmptyList(',', [&](int i) { return parseWhereClause(); });
166183
}
167184
return List::create(L.cur().range, {});
168185
}
@@ -200,7 +217,7 @@ struct Parser {
200217
auto assign = parseAssignment();
201218
auto rhs = parseExp();
202219
TreeRef equivalent_statement = parseEquivalent();
203-
TreeRef range_statements = parseRangeConstraints();
220+
TreeRef range_statements = parseWhereClauses();
204221
TreeRef empty_reduction_variables = c(TK_LIST, ident->range(), {});
205222
return Comprehension::create(
206223
ident->range(),

include/tc/lang/sema.h

Lines changed: 103 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,27 @@ struct Sema {
194194
}
195195
return e;
196196
}
197+
void expectBool(TreeRef anchor, int token) {
198+
if (token != TK_BOOL) {
199+
throw ErrorReport(anchor)
200+
<< "expected boolean but found " << kindToString(token);
201+
}
202+
}
203+
TreeRef expectBool(TreeRef exp) {
204+
expectBool(exp, typeOfExpr(exp)->kind());
205+
return exp;
206+
}
207+
TreeRef lookupVarOrCreateIndex(Ident ident) {
208+
TreeRef type = lookup(ident, false);
209+
if (!type) {
210+
// variable exp is not defined, so a reduction variable is created
211+
// a reduction variable index i
212+
type = indexType(ident);
213+
insert(index_env, ident, type, true);
214+
reduction_variables.push_back(ident);
215+
}
216+
return type;
217+
}
197218
TreeRef checkExp(TreeRef exp, bool allow_access) {
198219
switch (exp->kind()) {
199220
case TK_APPLY: {
@@ -205,6 +226,7 @@ struct Sema {
205226
throw ErrorReport(exp)
206227
<< "tensor accesses cannot be used in this context";
207228
}
229+
208230
// also handle built-in functions log, exp, etc.
209231
auto ident = a.name();
210232
if (builtin_functions.count(ident.name()) > 0) {
@@ -239,14 +261,7 @@ struct Sema {
239261
} break;
240262
case TK_IDENT: {
241263
auto ident = Ident(exp);
242-
TreeRef type = lookup(ident, false);
243-
if (!type) {
244-
// variable exp is not defined, so a reduction variable is created
245-
// a reduction variable index i
246-
type = indexType(exp);
247-
insert(index_env, ident, type, true);
248-
reduction_variables.push_back(exp);
249-
}
264+
auto type = lookupVarOrCreateIndex(ident);
250265
if (type->kind() == TK_TENSOR_TYPE) {
251266
auto tt = TensorType(type);
252267
if (tt.dims().size() != 0) {
@@ -276,6 +291,35 @@ struct Sema {
276291
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
277292
return withType(nexp, matchAllTypes(nexp));
278293
} break;
294+
case TK_EQ:
295+
case TK_NE:
296+
case TK_GE:
297+
case TK_LE:
298+
case '<':
299+
case '>': {
300+
auto nexp =
301+
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
302+
// make sure the types match but the return type
303+
// is always bool
304+
matchAllTypes(nexp);
305+
return withType(nexp, boolType(exp));
306+
} break;
307+
case TK_AND:
308+
case TK_OR:
309+
case '!': {
310+
auto nexp =
311+
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
312+
expectBool(exp, matchAllTypes(nexp)->kind());
313+
return withType(nexp, boolType(exp));
314+
} break;
315+
case '?': {
316+
auto nexp =
317+
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
318+
expectBool(nexp->tree(0));
319+
auto rtype =
320+
match_types(typeOfExpr(nexp->tree(1)), typeOfExpr(nexp->tree(2)));
321+
return withType(nexp, rtype);
322+
}
279323
case TK_CONST: {
280324
auto c = Const(exp);
281325
return withType(exp, c.type());
@@ -322,7 +366,10 @@ struct Sema {
322366
TreeRef floatType(TreeRef anchor) {
323367
return c(TK_FLOAT, anchor->range(), {});
324368
}
325-
void checkDim(const Ident& dim) {
369+
TreeRef boolType(TreeRef anchor) {
370+
return c(TK_BOOL, anchor->range(), {});
371+
}
372+
void checkDim(Ident dim) {
326373
insert(env, dim, dimType(dim), false);
327374
}
328375
TreeRef checkTensorType(TreeRef type) {
@@ -354,6 +401,33 @@ struct Sema {
354401
}
355402
return List::create(list->range(), std::move(r));
356403
}
404+
TreeRef checkRangeConstraint(RangeConstraint rc) {
405+
// RCs are checked _before_ the rhs of the TC, so
406+
// it is possible the index is not in the environment yet
407+
// calling lookupOrCreate ensures it exists
408+
lookupVarOrCreateIndex(rc.ident());
409+
// calling looking directly in the index_env ensures that
410+
// we are actually constraining an index and not some other variable
411+
lookup(index_env, rc.ident(), true);
412+
auto s = expectIntegral(checkExp(rc.start(), false));
413+
auto e = expectIntegral(checkExp(rc.end(), false));
414+
return RangeConstraint::create(rc.range(), rc.ident(), s, e);
415+
}
416+
TreeRef checkLet(Let l) {
417+
auto rhs = checkExp(l.rhs(), true);
418+
insert(let_env, l.name(), typeOfExpr(rhs), true);
419+
return Let::create(l.range(), l.name(), rhs);
420+
}
421+
TreeRef checkWhereClause(TreeRef ref) {
422+
if (ref->kind() == TK_LET) {
423+
return checkLet(Let(ref));
424+
} else if (ref->kind() == TK_EXISTS) {
425+
auto exp = checkExp(Exists(ref).exp(), true);
426+
return Exists::create(ref->range(), exp);
427+
} else {
428+
return checkRangeConstraint(RangeConstraint(ref));
429+
}
430+
}
357431
TreeRef checkStmt(TreeRef stmt_) {
358432
auto stmt = Comprehension(stmt_);
359433

@@ -374,6 +448,11 @@ struct Sema {
374448
output_indices.push_back(new_var);
375449
}
376450

451+
// where clauses are checked _before_ the rhs because they
452+
// introduce let bindings that are in scope for the rhs
453+
auto where_clauses_ = stmt.whereClauses().map(
454+
[&](TreeRef rc) { return checkWhereClause(rc); });
455+
377456
TreeRef rhs_ = checkExp(stmt.rhs(), true);
378457
TreeRef scalar_type = typeOfExpr(rhs_);
379458

@@ -408,20 +487,11 @@ struct Sema {
408487
// if we redefined an input, it is no longer valid for range expressions
409488
live_input_names.erase(stmt.ident().name());
410489

411-
auto range_constraints =
412-
stmt.rangeConstraints().map([&](const RangeConstraint& rc) {
413-
lookup(index_env, rc.ident(), true);
414-
auto s = expectIntegral(checkExp(rc.start(), false));
415-
auto e = expectIntegral(checkExp(rc.end(), false));
416-
return RangeConstraint::create(rc.range(), rc.ident(), s, e);
417-
});
418-
419-
auto equivalent_statement_ =
420-
stmt.equivalent().map([&](const Equivalent& eq) {
421-
auto indices_ = eq.accesses().map(
422-
[&](TreeRef index) { return checkExp(index, true); });
423-
return Equivalent::create(eq.range(), eq.name(), indices_);
424-
});
490+
auto equivalent_statement_ = stmt.equivalent().map([&](Equivalent eq) {
491+
auto indices_ = eq.accesses().map(
492+
[&](TreeRef index) { return checkExp(index, true); });
493+
return Equivalent::create(eq.range(), eq.name(), indices_);
494+
});
425495

426496
TreeRef assignment = stmt.assignment();
427497
// For semantic consistency we allow overwriting reductions like +=!
@@ -446,13 +516,16 @@ struct Sema {
446516
stmt.indices(),
447517
stmt.assignment(),
448518
rhs_,
449-
range_constraints,
519+
where_clauses_,
450520
equivalent_statement_,
451521
reduction_variable_list);
522+
// clear the per-statement environments to get ready for the next statement
452523
index_env.clear();
524+
let_env.clear();
525+
453526
return result;
454527
}
455-
bool isNotInplace(const TreeRef& assignment) {
528+
bool isNotInplace(TreeRef assignment) {
456529
switch (assignment->kind()) {
457530
case TK_PLUS_EQ_B:
458531
case TK_TIMES_EQ_B:
@@ -493,13 +566,15 @@ struct Sema {
493566
throw ErrorReport(ident) << name << " already defined";
494567
}
495568
}
496-
TreeRef lookup(const Ident& ident, bool required) {
569+
TreeRef lookup(Ident ident, bool required) {
497570
TreeRef v = lookup(index_env, ident, false);
571+
if (!v)
572+
v = lookup(let_env, ident, false);
498573
if (!v)
499574
v = lookup(env, ident, required);
500575
return v;
501576
}
502-
TreeRef lookup(Env& the_env, const Ident& ident, bool required) {
577+
TreeRef lookup(Env& the_env, Ident ident, bool required) {
503578
std::string name = ident.name();
504579
auto it = the_env.find(name);
505580
if (required && it == the_env.end()) {
@@ -517,6 +592,7 @@ struct Sema {
517592

518593
std::vector<TreeRef> reduction_variables; // per-statement
519594
Env index_env; // per-statement
595+
Env let_env; // per-statement, used for where i = <exp>
520596

521597
Env env; // name -> type
522598
Env annotated_output_types; // name -> type, for all annotated returns types

0 commit comments

Comments
 (0)