@@ -194,6 +194,27 @@ struct Sema {
194
194
}
195
195
return e;
196
196
}
197
+ void expectBool (TreeRef anchor, int token) {
198
+ if (token != TK_BOOL) {
199
+ throw ErrorReport (anchor)
200
+ << " expected boolean but found " << kindToString (token);
201
+ }
202
+ }
203
+ TreeRef expectBool (TreeRef exp) {
204
+ expectBool (exp, typeOfExpr (exp)->kind ());
205
+ return exp;
206
+ }
207
+ TreeRef lookupVarOrCreateIndex (Ident ident) {
208
+ TreeRef type = lookup (ident, false );
209
+ if (!type) {
210
+ // variable exp is not defined, so a reduction variable is created
211
+ // a reduction variable index i
212
+ type = indexType (ident);
213
+ insert (index_env, ident, type, true );
214
+ reduction_variables.push_back (ident);
215
+ }
216
+ return type;
217
+ }
197
218
TreeRef checkExp (TreeRef exp, bool allow_access) {
198
219
switch (exp->kind ()) {
199
220
case TK_APPLY: {
@@ -205,6 +226,7 @@ struct Sema {
205
226
throw ErrorReport (exp)
206
227
<< " tensor accesses cannot be used in this context" ;
207
228
}
229
+
208
230
// also handle built-in functions log, exp, etc.
209
231
auto ident = a.name ();
210
232
if (builtin_functions.count (ident.name ()) > 0 ) {
@@ -239,14 +261,7 @@ struct Sema {
239
261
} break ;
240
262
case TK_IDENT: {
241
263
auto ident = Ident (exp);
242
- TreeRef type = lookup (ident, false );
243
- if (!type) {
244
- // variable exp is not defined, so a reduction variable is created
245
- // a reduction variable index i
246
- type = indexType (exp);
247
- insert (index_env, ident, type, true );
248
- reduction_variables.push_back (exp);
249
- }
264
+ auto type = lookupVarOrCreateIndex (ident);
250
265
if (type->kind () == TK_TENSOR_TYPE) {
251
266
auto tt = TensorType (type);
252
267
if (tt.dims ().size () != 0 ) {
@@ -276,6 +291,35 @@ struct Sema {
276
291
exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
277
292
return withType (nexp, matchAllTypes (nexp));
278
293
} break ;
294
+ case TK_EQ:
295
+ case TK_NE:
296
+ case TK_GE:
297
+ case TK_LE:
298
+ case ' <' :
299
+ case ' >' : {
300
+ auto nexp =
301
+ exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
302
+ // make sure the types match but the return type
303
+ // is always bool
304
+ matchAllTypes (nexp);
305
+ return withType (nexp, boolType (exp));
306
+ } break ;
307
+ case TK_AND:
308
+ case TK_OR:
309
+ case ' !' : {
310
+ auto nexp =
311
+ exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
312
+ expectBool (exp, matchAllTypes (nexp)->kind ());
313
+ return withType (nexp, boolType (exp));
314
+ } break ;
315
+ case ' ?' : {
316
+ auto nexp =
317
+ exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
318
+ expectBool (nexp->tree (0 ));
319
+ auto rtype =
320
+ match_types (typeOfExpr (nexp->tree (1 )), typeOfExpr (nexp->tree (2 )));
321
+ return withType (nexp, rtype);
322
+ }
279
323
case TK_CONST: {
280
324
auto c = Const (exp);
281
325
return withType (exp, c.type ());
@@ -322,7 +366,10 @@ struct Sema {
322
366
TreeRef floatType (TreeRef anchor) {
323
367
return c (TK_FLOAT, anchor->range (), {});
324
368
}
325
- void checkDim (const Ident& dim) {
369
+ TreeRef boolType (TreeRef anchor) {
370
+ return c (TK_BOOL, anchor->range (), {});
371
+ }
372
+ void checkDim (Ident dim) {
326
373
insert (env, dim, dimType (dim), false );
327
374
}
328
375
TreeRef checkTensorType (TreeRef type) {
@@ -354,6 +401,33 @@ struct Sema {
354
401
}
355
402
return List::create (list->range (), std::move (r));
356
403
}
404
+ TreeRef checkRangeConstraint (RangeConstraint rc) {
405
+ // RCs are checked _before_ the rhs of the TC, so
406
+ // it is possible the index is not in the environment yet
407
+ // calling lookupOrCreate ensures it exists
408
+ lookupVarOrCreateIndex (rc.ident ());
409
+ // calling looking directly in the index_env ensures that
410
+ // we are actually constraining an index and not some other variable
411
+ lookup (index_env, rc.ident (), true );
412
+ auto s = expectIntegral (checkExp (rc.start (), false ));
413
+ auto e = expectIntegral (checkExp (rc.end (), false ));
414
+ return RangeConstraint::create (rc.range (), rc.ident (), s, e);
415
+ }
416
+ TreeRef checkLet (Let l) {
417
+ auto rhs = checkExp (l.rhs (), true );
418
+ insert (let_env, l.name (), typeOfExpr (rhs), true );
419
+ return Let::create (l.range (), l.name (), rhs);
420
+ }
421
+ TreeRef checkWhereClause (TreeRef ref) {
422
+ if (ref->kind () == TK_LET) {
423
+ return checkLet (Let (ref));
424
+ } else if (ref->kind () == TK_EXISTS) {
425
+ auto exp = checkExp (Exists (ref).exp (), true );
426
+ return Exists::create (ref->range (), exp);
427
+ } else {
428
+ return checkRangeConstraint (RangeConstraint (ref));
429
+ }
430
+ }
357
431
TreeRef checkStmt (TreeRef stmt_) {
358
432
auto stmt = Comprehension (stmt_);
359
433
@@ -374,6 +448,11 @@ struct Sema {
374
448
output_indices.push_back (new_var);
375
449
}
376
450
451
+ // where clauses are checked _before_ the rhs because they
452
+ // introduce let bindings that are in scope for the rhs
453
+ auto where_clauses_ = stmt.whereClauses ().map (
454
+ [&](TreeRef rc) { return checkWhereClause (rc); });
455
+
377
456
TreeRef rhs_ = checkExp (stmt.rhs (), true );
378
457
TreeRef scalar_type = typeOfExpr (rhs_);
379
458
@@ -408,20 +487,11 @@ struct Sema {
408
487
// if we redefined an input, it is no longer valid for range expressions
409
488
live_input_names.erase (stmt.ident ().name ());
410
489
411
- auto range_constraints =
412
- stmt.rangeConstraints ().map ([&](const RangeConstraint& rc) {
413
- lookup (index_env, rc.ident (), true );
414
- auto s = expectIntegral (checkExp (rc.start (), false ));
415
- auto e = expectIntegral (checkExp (rc.end (), false ));
416
- return RangeConstraint::create (rc.range (), rc.ident (), s, e);
417
- });
418
-
419
- auto equivalent_statement_ =
420
- stmt.equivalent ().map ([&](const Equivalent& eq) {
421
- auto indices_ = eq.accesses ().map (
422
- [&](TreeRef index) { return checkExp (index, true ); });
423
- return Equivalent::create (eq.range (), eq.name (), indices_);
424
- });
490
+ auto equivalent_statement_ = stmt.equivalent ().map ([&](Equivalent eq) {
491
+ auto indices_ = eq.accesses ().map (
492
+ [&](TreeRef index) { return checkExp (index, true ); });
493
+ return Equivalent::create (eq.range (), eq.name (), indices_);
494
+ });
425
495
426
496
TreeRef assignment = stmt.assignment ();
427
497
// For semantic consistency we allow overwriting reductions like +=!
@@ -446,13 +516,16 @@ struct Sema {
446
516
stmt.indices (),
447
517
stmt.assignment (),
448
518
rhs_,
449
- range_constraints ,
519
+ where_clauses_ ,
450
520
equivalent_statement_,
451
521
reduction_variable_list);
522
+ // clear the per-statement environments to get ready for the next statement
452
523
index_env.clear ();
524
+ let_env.clear ();
525
+
453
526
return result;
454
527
}
455
- bool isNotInplace (const TreeRef& assignment) {
528
+ bool isNotInplace (TreeRef assignment) {
456
529
switch (assignment->kind ()) {
457
530
case TK_PLUS_EQ_B:
458
531
case TK_TIMES_EQ_B:
@@ -493,13 +566,15 @@ struct Sema {
493
566
throw ErrorReport (ident) << name << " already defined" ;
494
567
}
495
568
}
496
- TreeRef lookup (const Ident& ident, bool required) {
569
+ TreeRef lookup (Ident ident, bool required) {
497
570
TreeRef v = lookup (index_env, ident, false );
571
+ if (!v)
572
+ v = lookup (let_env, ident, false );
498
573
if (!v)
499
574
v = lookup (env, ident, required);
500
575
return v;
501
576
}
502
- TreeRef lookup (Env& the_env, const Ident& ident, bool required) {
577
+ TreeRef lookup (Env& the_env, Ident ident, bool required) {
503
578
std::string name = ident.name ();
504
579
auto it = the_env.find (name);
505
580
if (required && it == the_env.end ()) {
@@ -517,6 +592,7 @@ struct Sema {
517
592
518
593
std::vector<TreeRef> reduction_variables; // per-statement
519
594
Env index_env; // per-statement
595
+ Env let_env; // per-statement, used for where i = <exp>
520
596
521
597
Env env; // name -> type
522
598
Env annotated_output_types; // name -> type, for all annotated returns types
0 commit comments