@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
62
62
}
63
63
}
64
64
65
+ // Translate the TC def input params to corresponding Halide components.
66
+ // params, inputs will be populated here.
65
67
void translateParam (
66
68
const lang::Param& p,
67
69
map<string, Parameter>* params,
68
70
vector<ImageParam>* inputs) {
71
+ // Check if the param has already been converted to halide components.
69
72
if (params->find (p.ident ().name ()) != params->end ()) {
70
73
return ;
71
- } else {
72
- lang::TensorType type = p.tensorType ();
73
- int dimensions = (int )type.dims ().size ();
74
- ImageParam imageParam (
75
- translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
76
- inputs->push_back (imageParam);
77
- vector<Expr> dims;
78
- for (auto d_ : type.dims ()) {
79
- if (d_->kind () == lang::TK_IDENT) {
80
- auto d = lang::Ident (d_);
81
- auto it = params->find (d.name ());
82
- Parameter p;
83
- if (it != params->end ()) {
84
- p = it->second ;
85
- } else {
86
- p = Parameter (Int (32 ), false , 0 , d.name (), true );
87
- (*params)[d.name ()] = p;
88
- }
89
- dims.push_back (Variable::make (Int (32 ), p.name (), p));
74
+ }
75
+ lang::TensorType type = p.tensorType ();
76
+ int dimensions = (int )type.dims ().size ();
77
+ ImageParam imageParam (
78
+ translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79
+ inputs->push_back (imageParam);
80
+ vector<Expr> dims;
81
+ for (auto d_ : type.dims ()) {
82
+ if (d_->kind () == lang::TK_IDENT) {
83
+ auto d = lang::Ident (d_);
84
+ auto it = params->find (d.name ());
85
+ Parameter p;
86
+ if (it != params->end ()) {
87
+ p = it->second ;
90
88
} else {
91
- CHECK (d_->kind () == lang::TK_CONST);
92
- int32_t value = lang::Const (d_).value ();
93
- dims.push_back (Expr (value));
89
+ p = Parameter (Int (32 ), false , 0 , d.name (), true );
90
+ (*params)[d.name ()] = p;
94
91
}
92
+ dims.push_back (Variable::make (Int (32 ), p.name (), p));
93
+ } else {
94
+ CHECK (d_->kind () == lang::TK_CONST);
95
+ int32_t value = lang::Const (d_).value ();
96
+ dims.push_back (Expr (value));
95
97
}
98
+ }
96
99
97
- for (int i = 0 ; i < imageParam.dimensions (); i++) {
98
- imageParam.dim (i).set_bounds (0 , dims[i]);
99
- }
100
- (*params)[imageParam.name ()] = imageParam.parameter ();
100
+ for (int i = 0 ; i < imageParam.dimensions (); i++) {
101
+ imageParam.dim (i).set_bounds (0 , dims[i]);
101
102
}
103
+ (*params)[imageParam.name ()] = imageParam.parameter ();
102
104
}
103
105
104
106
void translateOutput (
@@ -156,6 +158,8 @@ Expr translateExpr(
156
158
return t (0 ) * t (1 );
157
159
case ' /' :
158
160
return t (0 ) / t (1 );
161
+ case ' %' :
162
+ return t (0 ) % t (1 );
159
163
case lang::TK_MIN:
160
164
return min (t (0 ), t (1 ));
161
165
case lang::TK_MAX:
@@ -488,22 +492,25 @@ Expr reductionUpdate(Expr e) {
488
492
return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
489
493
}
490
494
495
+ // Translate a single TC comprehension/statement to Halide components: funcs,
496
+ // bounds, reductions.
497
+ //
491
498
// Note that the function definitions created by translateComprehension may
492
499
// contain kReductionUpdate intrinsics. These may have to be removed
493
500
// in order to be able to apply internal Halide analysis passes on them.
494
501
void translateComprehension (
495
- const lang::Comprehension& c ,
502
+ const lang::Comprehension& comprehension ,
496
503
const map<string, Parameter>& params,
497
504
bool throwWarnings,
498
505
map<string, Function>* funcs,
499
506
FunctionBounds* bounds) {
500
507
Function f;
501
- auto it = funcs->find (c .ident ().name ());
508
+ auto it = funcs->find (comprehension .ident ().name ());
502
509
if (it != funcs->end ()) {
503
510
f = it->second ;
504
511
} else {
505
- f = Function (c .ident ().name ());
506
- (*funcs)[c .ident ().name ()] = f;
512
+ f = Function (comprehension .ident ().name ());
513
+ (*funcs)[comprehension .ident ().name ()] = f;
507
514
}
508
515
// Function is the internal Halide IR type for a pipeline
509
516
// stage. Func is the front-end class that wraps it. Here it's
@@ -512,7 +519,7 @@ void translateComprehension(
512
519
513
520
vector<Var> lhs;
514
521
vector<Expr> lhs_as_exprs;
515
- for (lang::Ident id : c .indices ()) {
522
+ for (lang::Ident id : comprehension .indices ()) {
516
523
lhs.push_back (Var (id.name ()));
517
524
lhs_as_exprs.push_back (lhs.back ());
518
525
}
@@ -521,17 +528,17 @@ void translateComprehension(
521
528
// in the future we may consider using Halide Let bindings when they
522
529
// are supported later
523
530
map<string, Expr> lets;
524
- for (auto wc : c .whereClauses ()) {
531
+ for (auto wc : comprehension .whereClauses ()) {
525
532
if (wc->kind () == lang::TK_LET) {
526
533
auto let = lang::Let (wc);
527
534
lets[let.name ().name ()] = translateExpr (let.rhs (), params, *funcs, lets);
528
535
}
529
536
}
530
537
531
- Expr rhs = translateExpr (c .rhs (), params, *funcs, lets);
538
+ Expr rhs = translateExpr (comprehension .rhs (), params, *funcs, lets);
532
539
533
540
std::vector<Expr> all_exprs;
534
- for (auto wc : c .whereClauses ()) {
541
+ for (auto wc : comprehension .whereClauses ()) {
535
542
if (wc->kind () == lang::TK_EXISTS) {
536
543
all_exprs.push_back (
537
544
translateExpr (lang::Exists (wc).exp (), params, *funcs, lets));
@@ -555,7 +562,7 @@ void translateComprehension(
555
562
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
556
563
// for the reduction and then applies the reduction.
557
564
bool should_zero = false ;
558
- switch (c .assignment ()->kind ()) {
565
+ switch (comprehension .assignment ()->kind ()) {
559
566
case lang::TK_PLUS_EQ_B:
560
567
should_zero = true ; // fallthrough
561
568
case lang::TK_PLUS_EQ:
@@ -587,12 +594,13 @@ void translateComprehension(
587
594
case ' =' :
588
595
break ;
589
596
default :
590
- throw lang::ErrorReport (c) << " Unimplemented reduction "
591
- << c.assignment ()->range ().text () << " \n " ;
597
+ throw lang::ErrorReport (comprehension)
598
+ << " Unimplemented reduction "
599
+ << comprehension.assignment ()->range ().text () << " \n " ;
592
600
}
593
601
594
602
// Tag reductions as such
595
- if (c .assignment ()->kind () != ' =' ) {
603
+ if (comprehension .assignment ()->kind () != ' =' ) {
596
604
rhs = reductionUpdate (rhs);
597
605
}
598
606
@@ -632,7 +640,7 @@ void translateComprehension(
632
640
Scope<Interval> solution;
633
641
634
642
// Put anything explicitly specified with a 'where' class in the solution
635
- for (auto constraint_ : c .whereClauses ()) {
643
+ for (auto constraint_ : comprehension .whereClauses ()) {
636
644
if (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
637
645
continue ;
638
646
auto constraint = lang::RangeConstraint (constraint_);
@@ -653,7 +661,8 @@ void translateComprehension(
653
661
654
662
// Infer the rest
655
663
all_exprs.push_back (rhs);
656
- forwardBoundsInference (all_exprs, *bounds, c, throwWarnings, &solution);
664
+ forwardBoundsInference (
665
+ all_exprs, *bounds, comprehension, throwWarnings, &solution);
657
666
658
667
// TODO: What if subsequent updates have incompatible bounds
659
668
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -664,7 +673,7 @@ void translateComprehension(
664
673
665
674
for (Var v : lhs) {
666
675
if (!solution.contains (v.name ())) {
667
- throw lang::ErrorReport (c )
676
+ throw lang::ErrorReport (comprehension )
668
677
<< " Free variable " << v
669
678
<< " was not solved in range inference. May not be used right-hand side" ;
670
679
}
@@ -688,7 +697,7 @@ void translateComprehension(
688
697
for (size_t i = 0 ; i < unbound.size (); i++) {
689
698
auto v = unbound[unbound.size () - 1 - i];
690
699
if (!solution.contains (v->name )) {
691
- throw lang::ErrorReport (c )
700
+ throw lang::ErrorReport (comprehension )
692
701
<< " Free variable " << v << " is unconstrained. "
693
702
<< " Use a 'where' clause to set its range." ;
694
703
}
@@ -736,6 +745,7 @@ void translateComprehension(
736
745
stage.reorder (loop_nest);
737
746
}
738
747
748
+ // Translate a semantically checked TC def to HalideComponents struct.
739
749
HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
740
750
map<string, Function> funcs;
741
751
HalideComponents components;
@@ -895,6 +905,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
895
905
lang::Def (lang::Sema ().checkFunction (treeRef)), throwWarnings);
896
906
}
897
907
908
+ // NOTE: there is no guarantee here that the tc string has only one def. It
909
+ // could have many defs. Only first def will be converted in that case.
898
910
HalideComponents
899
911
translate (isl::ctx ctx, const std::string& tc, bool throwWarnings) {
900
912
LOG_IF (INFO, tc::FLAGS_debug_halide) << tc;
0 commit comments