@@ -493,55 +493,82 @@ struct Sema {
493
493
494
494
// Semantic checking for the statements/comprehensions in a TC Def.
495
495
TreeRef checkStmt (TreeRef stmt_) {
496
- auto stmt = Comprehension (stmt_);
496
+ if (stmt_->kind () == TK_COMPREHENSION) {
497
+ return checkComprehension (Comprehension (stmt_));
498
+ }
499
+ return checkFor (For (stmt_));
500
+ }
497
501
502
+ TreeRef checkFor (For f) {
503
+ if (lookup (f.index (), false )) {
504
+ throw ErrorReport (f) << " For loop index already defined" ;
505
+ }
506
+ TreeList stmts;
507
+ for (auto s : f.statements ()) {
508
+ if (s->kind () != TK_COMPREHENSION) {
509
+ throw ErrorReport (s) << " Nested \" for\" loops NYI" ;
510
+ }
511
+ stmts.push_back (checkComprehension (Comprehension (s)));
512
+ }
513
+ // Check the range constraint after all statements
514
+ // This way we don't need extra state to track indices coming from loops
515
+ // that may have already been defined.
516
+ checkRangeConstraint (f.rangeConstraint ());
517
+ return For::create (
518
+ f.range (),
519
+ f.index (),
520
+ f.rangeConstraint (),
521
+ List::create (f.range (), std::move (stmts)));
522
+ }
523
+
524
+ TreeRef checkComprehension (Comprehension comp) {
498
525
// register index variables (non-reductions)
499
- for (const auto & index : stmt .indices ()) {
526
+ for (const auto & index : comp .indices ()) {
500
527
std::string idx = index.name ();
501
528
auto typ = indexType (index);
502
529
insert (index_env, index, typ, true );
503
530
}
504
531
505
532
// check that the input is not used for output - inputs are immutable
506
- std::string name = stmt .ident ().name ();
533
+ std::string name = comp .ident ().name ();
507
534
if (inputParameters.count (name) > 0 ) {
508
- throw ErrorReport (stmt_ ) << " TC inputs are immutable" ;
535
+ throw ErrorReport (comp ) << " TC inputs are immutable" ;
509
536
}
510
537
511
538
// make dimension variables for each dimension of the output tensor
512
539
TreeList output_indices;
513
- int n = stmt .indices ().size ();
540
+ int n = comp .indices ().size ();
514
541
for (int i = 0 ; i < n; ++i) {
515
542
auto new_var =
516
- Ident::create (stmt .range (), name + " ." + std::to_string (i));
543
+ Ident::create (comp .range (), name + " ." + std::to_string (i));
517
544
output_indices.push_back (new_var);
518
545
}
519
546
520
547
// where clauses are checked _before_ the rhs because they
521
548
// introduce let bindings that are in scope for the rhs
522
- auto where_clauses_ = stmt .whereClauses ().map (
549
+ auto where_clauses_ = comp .whereClauses ().map (
523
550
[&](TreeRef rc) { return checkWhereClause (rc); });
524
551
525
- TreeRef rhs_ = checkExp (stmt .rhs (), true );
552
+ TreeRef rhs_ = checkExp (comp .rhs (), true );
526
553
TreeRef scalar_type = typeOfExpr (rhs_);
527
554
528
555
// if this statement will be returned and it is annotated in the return list
529
556
// with a type (e.g. float(A,B)) then force the tensor to be that type
530
557
// and check that the number of dimensions are consistent
531
- auto output_annotation = annotated_output_types.find (stmt .ident ().name ());
558
+ auto output_annotation = annotated_output_types.find (comp .ident ().name ());
532
559
if (output_annotation != annotated_output_types.end ()) {
533
560
auto tt = TensorType (output_annotation->second );
534
561
auto matched_type = match_types (scalar_type, tt.scalarTypeTree ());
535
562
if (tt.scalarTypeTree ()->kind () != matched_type->kind ()) {
536
- throw ErrorReport (stmt )
563
+ throw ErrorReport (comp )
537
564
<< " attempting to assign type "
538
565
<< kindToString (scalar_type->kind ()) << " to narrower type "
539
566
<< kindToString (tt.scalarTypeTree ()->kind ())
540
567
<< " without an explicit cast" ;
541
568
}
542
- if (tt.dims ().size () != stmt .indices ().size ()) {
543
- throw ErrorReport (stmt )
544
- << " tensor defined with " << stmt .indices ().size ()
569
+ if (tt.dims ().size () != comp .indices ().size ()) {
570
+ throw ErrorReport (comp )
571
+ << " tensor defined with " << comp .indices ().size ()
545
572
<< " dimensions but declared as an output with " << tt.dims ().size ()
546
573
<< " dimensions." ;
547
574
}
@@ -550,33 +577,33 @@ struct Sema {
550
577
// After checking rhs and before creating lhs, we check if it is a reduction
551
578
// without initialization (i.e., reduction operator without "!" suffix, and
552
579
// lhs not defined previously).
553
- if (isUninitializedReductionOperation (stmt .assignment ()) &&
554
- nullptr == lookup (stmt .ident (), false )) {
555
- ErrorReport err (stmt );
556
- std::string tk = kindToToken (stmt .assignment ()->kind ());
557
- err << " Reduction without initialization. If " << stmt .ident ().name ()
580
+ if (isUninitializedReductionOperation (comp .assignment ()) &&
581
+ nullptr == lookup (comp .ident (), false )) {
582
+ ErrorReport err (comp );
583
+ std::string tk = kindToToken (comp .assignment ()->kind ());
584
+ err << " Reduction without initialization. If " << comp .ident ().name ()
558
585
<< " is not pre-initialized before calling the TC function,"
559
586
<< " consider using the !-suffixed reduction operator " << tk
560
587
<< " ! instead of " << tk;
561
588
warn (err);
562
589
}
563
590
564
591
auto type = TensorType::create (
565
- stmt .range (),
592
+ comp .range (),
566
593
scalar_type,
567
- List::create (stmt .range (), std::move (output_indices)));
568
- insert (env, stmt .ident (), type, false );
594
+ List::create (comp .range (), std::move (output_indices)));
595
+ insert (env, comp .ident (), type, false );
569
596
570
597
// if we redefined an input, it is no longer valid for range expressions
571
- live_input_names.erase (stmt .ident ().name ());
598
+ live_input_names.erase (comp .ident ().name ());
572
599
573
- auto equivalent_statement_ = stmt .equivalent ().map ([&](Equivalent eq) {
600
+ auto equivalent_statement_ = comp .equivalent ().map ([&](Equivalent eq) {
574
601
auto indices_ = eq.accesses ().map (
575
602
[&](TreeRef index) { return checkExp (index, true ); });
576
603
return Equivalent::create (eq.range (), eq.name (), indices_);
577
604
});
578
605
579
- TreeRef assignment = stmt .assignment ();
606
+ TreeRef assignment = comp .assignment ();
580
607
// For semantic consistency we allow overwriting reductions like +=!
581
608
// to be used in the language when there are no actual reduction dimensions.
582
609
// Later compile stages assume that there is at least one reduction
@@ -586,26 +613,26 @@ struct Sema {
586
613
assignment = Compound::create (' =' , assignment->range (), {});
587
614
}
588
615
589
- if (reduction_variables.size () > 0 && stmt .assignment ()->kind () == ' =' ) {
590
- throw ErrorReport (stmt ) << " this statement includes reduction variable '"
616
+ if (reduction_variables.size () > 0 && comp .assignment ()->kind () == ' =' ) {
617
+ throw ErrorReport (comp ) << " this statement includes reduction variable '"
591
618
<< Ident (reduction_variables.back ()).name ()
592
619
<< " ' but does not specify a reduction." ;
593
620
}
594
621
TreeRef reduction_variable_list =
595
- List::create (stmt .ident ().range (), std::move (reduction_variables));
622
+ List::create (comp .ident ().range (), std::move (reduction_variables));
596
623
TreeRef result = Comprehension::create (
597
- stmt .range (),
598
- stmt .ident (),
599
- stmt .indices (),
600
- stmt .assignment (),
624
+ comp .range (),
625
+ comp .ident (),
626
+ comp .indices (),
627
+ comp .assignment (),
601
628
rhs_,
602
629
where_clauses_,
603
630
equivalent_statement_,
604
631
reduction_variable_list);
605
632
606
- if (nonTemporaries.count (stmt .ident ().name ()) == 0 ) {
607
- throw ErrorReport (stmt )
608
- << stmt .ident ().name ()
633
+ if (nonTemporaries.count (comp .ident ().name ()) == 0 ) {
634
+ throw ErrorReport (comp )
635
+ << comp .ident ().name ()
609
636
<< " is not listed as an input or output to this function. Temporaries tensors are not yet implemented" ;
610
637
}
611
638
0 commit comments