@@ -204,6 +204,17 @@ struct Sema {
204
204
expectBool (exp, typeOfExpr (exp)->kind ());
205
205
return exp;
206
206
}
207
+ TreeRef lookupVarOrCreateIndex (Ident ident) {
208
+ TreeRef type = lookup (ident, false );
209
+ if (!type) {
210
+ // variable exp is not defined, so a reduction variable is created
211
+ // a reduction variable index i
212
+ type = indexType (ident);
213
+ insert (index_env, ident, type, true );
214
+ reduction_variables.push_back (ident);
215
+ }
216
+ return type;
217
+ }
207
218
TreeRef checkExp (TreeRef exp, bool allow_access) {
208
219
switch (exp->kind ()) {
209
220
case TK_APPLY: {
@@ -250,14 +261,7 @@ struct Sema {
250
261
} break ;
251
262
case TK_IDENT: {
252
263
auto ident = Ident (exp);
253
- TreeRef type = lookup (ident, false );
254
- if (!type) {
255
- // variable exp is not defined, so a reduction variable is created
256
- // a reduction variable index i
257
- type = indexType (exp);
258
- insert (index_env, ident, type, true );
259
- reduction_variables.push_back (exp);
260
- }
264
+ auto type = lookupVarOrCreateIndex (ident);
261
265
if (type->kind () == TK_TENSOR_TYPE) {
262
266
auto tt = TensorType (type);
263
267
if (tt.dims ().size () != 0 ) {
@@ -397,6 +401,33 @@ struct Sema {
397
401
}
398
402
return List::create (list->range (), std::move (r));
399
403
}
404
+ TreeRef checkRangeConstraint (RangeConstraint rc) {
405
+ // RCs are checked _before_ the rhs of the TC, so
406
+ // it is possible the index is not in the environment yet
407
+ // calling lookupOrCreate ensures it exists
408
+ lookupVarOrCreateIndex (rc.ident ());
409
+ // calling looking directly in the index_env ensures that
410
+ // we are actually constraining an index and not some other variable
411
+ lookup (index_env, rc.ident (), true );
412
+ auto s = expectIntegral (checkExp (rc.start (), false ));
413
+ auto e = expectIntegral (checkExp (rc.end (), false ));
414
+ return RangeConstraint::create (rc.range (), rc.ident (), s, e);
415
+ }
416
+ TreeRef checkLet (Let l) {
417
+ auto rhs = checkExp (l.rhs (), true );
418
+ insert (let_env, l.name (), typeOfExpr (rhs), true );
419
+ return Let::create (l.range (), l.name (), rhs);
420
+ }
421
+ TreeRef checkWhereClause (TreeRef ref) {
422
+ if (ref->kind () == TK_LET) {
423
+ return checkLet (Let (ref));
424
+ } else if (ref->kind () == TK_EXISTS) {
425
+ auto exp = checkExp (Exists (ref).exp (), true );
426
+ return Exists::create (ref->range (), exp);
427
+ } else {
428
+ return checkRangeConstraint (RangeConstraint (ref));
429
+ }
430
+ }
400
431
TreeRef checkStmt (TreeRef stmt_) {
401
432
auto stmt = Comprehension (stmt_);
402
433
@@ -417,6 +448,11 @@ struct Sema {
417
448
output_indices.push_back (new_var);
418
449
}
419
450
451
+ // where clauses are checked _before_ the rhs because they
452
+ // introduce let bindings that are in scope for the rhs
453
+ auto where_clauses_ = stmt.whereClauses ().map (
454
+ [&](const TreeRef& rc) { return checkWhereClause (rc); });
455
+
420
456
TreeRef rhs_ = checkExp (stmt.rhs (), true );
421
457
TreeRef scalar_type = typeOfExpr (rhs_);
422
458
@@ -451,14 +487,6 @@ struct Sema {
451
487
// if we redefined an input, it is no longer valid for range expressions
452
488
live_input_names.erase (stmt.ident ().name ());
453
489
454
- auto range_constraints =
455
- stmt.rangeConstraints ().map ([&](const RangeConstraint& rc) {
456
- lookup (index_env, rc.ident (), true );
457
- auto s = expectIntegral (checkExp (rc.start (), false ));
458
- auto e = expectIntegral (checkExp (rc.end (), false ));
459
- return RangeConstraint::create (rc.range (), rc.ident (), s, e);
460
- });
461
-
462
490
auto equivalent_statement_ =
463
491
stmt.equivalent ().map ([&](const Equivalent& eq) {
464
492
auto indices_ = eq.accesses ().map (
@@ -489,10 +517,13 @@ struct Sema {
489
517
stmt.indices (),
490
518
stmt.assignment (),
491
519
rhs_,
492
- range_constraints ,
520
+ where_clauses_ ,
493
521
equivalent_statement_,
494
522
reduction_variable_list);
523
+ // clear the per-statement environments to get ready for the next statement
495
524
index_env.clear ();
525
+ let_env.clear ();
526
+
496
527
return result;
497
528
}
498
529
bool isNotInplace (const TreeRef& assignment) {
@@ -538,6 +569,8 @@ struct Sema {
538
569
}
539
570
TreeRef lookup (const Ident& ident, bool required) {
540
571
TreeRef v = lookup (index_env, ident, false );
572
+ if (!v)
573
+ v = lookup (let_env, ident, false );
541
574
if (!v)
542
575
v = lookup (env, ident, required);
543
576
return v;
@@ -560,6 +593,7 @@ struct Sema {
560
593
561
594
std::vector<TreeRef> reduction_variables; // per-statement
562
595
Env index_env; // per-statement
596
+ Env let_env; // per-statement, used for where i = <exp>
563
597
564
598
Env env; // name -> type
565
599
Env annotated_output_types; // name -> type, for all annotated returns types
0 commit comments