@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
62
62
}
63
63
}
64
64
65
+ // translate the TC def input params to corresponding Halide components.
66
+ // params, inputs will be populated here
65
67
void translateParam (
66
68
const lang::Param& p,
67
69
map<string, Parameter>* params,
68
70
vector<ImageParam>* inputs) {
71
+ // check if the param is already converted to halide components
69
72
if (params->find (p.ident ().name ()) != params->end ()) {
70
73
return ;
71
- } else {
72
- lang::TensorType type = p.tensorType ();
73
- int dimensions = (int )type.dims ().size ();
74
- ImageParam imageParam (
75
- translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
76
- inputs->push_back (imageParam);
77
- vector<Expr> dims;
78
- for (auto d_ : type.dims ()) {
79
- if (d_->kind () == lang::TK_IDENT) {
80
- auto d = lang::Ident (d_);
81
- auto it = params->find (d.name ());
82
- Parameter p;
83
- if (it != params->end ()) {
84
- p = it->second ;
85
- } else {
86
- p = Parameter (Int (32 ), false , 0 , d.name (), true );
87
- (*params)[d.name ()] = p;
88
- }
89
- dims.push_back (Variable::make (Int (32 ), p.name (), p));
74
+ }
75
+ lang::TensorType type = p.tensorType ();
76
+ int dimensions = (int )type.dims ().size ();
77
+ ImageParam imageParam (
78
+ translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79
+ inputs->push_back (imageParam);
80
+ vector<Expr> dims;
81
+ for (auto d_ : type.dims ()) {
82
+ if (d_->kind () == lang::TK_IDENT) {
83
+ auto d = lang::Ident (d_);
84
+ auto it = params->find (d.name ());
85
+ Parameter p;
86
+ if (it != params->end ()) {
87
+ p = it->second ;
90
88
} else {
91
- CHECK (d_->kind () == lang::TK_CONST);
92
- int32_t value = lang::Const (d_).value ();
93
- dims.push_back (Expr (value));
89
+ p = Parameter (Int (32 ), false , 0 , d.name (), true );
90
+ (*params)[d.name ()] = p;
94
91
}
92
+ dims.push_back (Variable::make (Int (32 ), p.name (), p));
93
+ } else {
94
+ CHECK (d_->kind () == lang::TK_CONST);
95
+ int32_t value = lang::Const (d_).value ();
96
+ dims.push_back (Expr (value));
95
97
}
98
+ }
96
99
97
- for (int i = 0 ; i < imageParam.dimensions (); i++) {
98
- imageParam.dim (i).set_bounds (0 , dims[i]);
99
- }
100
- (*params)[imageParam.name ()] = imageParam.parameter ();
100
+ for (int i = 0 ; i < imageParam.dimensions (); i++) {
101
+ imageParam.dim (i).set_bounds (0 , dims[i]);
101
102
}
103
+ (*params)[imageParam.name ()] = imageParam.parameter ();
102
104
}
103
105
104
106
void translateOutput (
@@ -156,6 +158,8 @@ Expr translateExpr(
156
158
return t (0 ) * t (1 );
157
159
case ' /' :
158
160
return t (0 ) / t (1 );
161
+ case ' %' :
162
+ return t (0 ) % t (1 );
159
163
case lang::TK_MIN:
160
164
return min (t (0 ), t (1 ));
161
165
case lang::TK_MAX:
@@ -492,20 +496,25 @@ Expr reductionUpdate(Expr e) {
492
496
return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
493
497
}
494
498
499
+ // translate a single TC comprehension/statement to Halide component.
500
+ // funcs, bounds, reductions will be populated
495
501
void translateComprehension (
496
- const lang::Comprehension& c ,
502
+ const lang::Comprehension& comprehension ,
497
503
const map<string, Parameter>& params,
498
504
bool throwWarnings,
499
505
map<string, Function>* funcs,
500
506
FunctionBounds* bounds,
501
507
vector<Function>* reductions) {
508
+ // Function is the internal Halide IR type for a pipeline
509
+ // stage. Func is the front-end class that wraps it. Here it's
510
+ // convenient to use both. Why? what is not exposed in Func?
502
511
Function f;
503
- auto it = funcs->find (c .ident ().name ());
512
+ auto it = funcs->find (comprehension .ident ().name ());
504
513
if (it != funcs->end ()) {
505
514
f = it->second ;
506
515
} else {
507
- f = Function (c .ident ().name ());
508
- (*funcs)[c .ident ().name ()] = f;
516
+ f = Function (comprehension .ident ().name ());
517
+ (*funcs)[comprehension .ident ().name ()] = f;
509
518
}
510
519
// Function is the internal Halide IR type for a pipeline
511
520
// stage. Func is the front-end class that wraps it. Here it's
@@ -514,7 +523,7 @@ void translateComprehension(
514
523
515
524
vector<Var> lhs;
516
525
vector<Expr> lhs_as_exprs;
517
- for (lang::Ident id : c .indices ()) {
526
+ for (lang::Ident id : comprehension .indices ()) {
518
527
lhs.push_back (Var (id.name ()));
519
528
lhs_as_exprs.push_back (lhs.back ());
520
529
}
@@ -523,17 +532,17 @@ void translateComprehension(
523
532
// in the future we may consider using Halide Let bindings when they
524
533
// are supported later
525
534
map<string, Expr> lets;
526
- for (auto wc : c .whereClauses ()) {
535
+ for (auto wc : comprehension .whereClauses ()) {
527
536
if (wc->kind () == lang::TK_LET) {
528
537
auto let = lang::Let (wc);
529
538
lets[let.name ().name ()] = translateExpr (let.rhs (), params, *funcs, lets);
530
539
}
531
540
}
532
541
533
- Expr rhs = translateExpr (c .rhs (), params, *funcs, lets);
542
+ Expr rhs = translateExpr (comprehension .rhs (), params, *funcs, lets);
534
543
535
544
std::vector<Expr> all_exprs;
536
- for (auto wc : c .whereClauses ()) {
545
+ for (auto wc : comprehension .whereClauses ()) {
537
546
if (wc->kind () == lang::TK_EXISTS) {
538
547
all_exprs.push_back (
539
548
translateExpr (lang::Exists (wc).exp (), params, *funcs, lets));
@@ -557,7 +566,7 @@ void translateComprehension(
557
566
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
558
567
// for the reduction and then applies the reduction.
559
568
bool should_zero = false ;
560
- switch (c .assignment ()->kind ()) {
569
+ switch (comprehension .assignment ()->kind ()) {
561
570
case lang::TK_PLUS_EQ_B:
562
571
should_zero = true ; // fallthrough
563
572
case lang::TK_PLUS_EQ:
@@ -589,11 +598,12 @@ void translateComprehension(
589
598
case ' =' :
590
599
break ;
591
600
default :
592
- throw lang::ErrorReport (c) << " Unimplemented reduction "
593
- << c.assignment ()->range ().text () << " \n " ;
601
+ throw lang::ErrorReport (comprehension)
602
+ << " Unimplemented reduction "
603
+ << comprehension.assignment ()->range ().text () << " \n " ;
594
604
}
595
605
596
- if (c .assignment ()->kind () != ' =' ) {
606
+ if (comprehension .assignment ()->kind () != ' =' ) {
597
607
reductions->push_back (f);
598
608
}
599
609
@@ -633,7 +643,7 @@ void translateComprehension(
633
643
Scope<Interval> solution;
634
644
635
645
// Put anything explicitly specified with a 'where' class in the solution
636
- for (auto constraint_ : c .whereClauses ()) {
646
+ for (auto constraint_ : comprehension .whereClauses ()) {
637
647
if (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
638
648
continue ;
639
649
auto constraint = lang::RangeConstraint (constraint_);
@@ -654,7 +664,8 @@ void translateComprehension(
654
664
655
665
// Infer the rest
656
666
all_exprs.push_back (rhs);
657
- forwardBoundsInference (all_exprs, *bounds, c, throwWarnings, &solution);
667
+ forwardBoundsInference (
668
+ all_exprs, *bounds, comprehension, throwWarnings, &solution);
658
669
659
670
// TODO: What if subsequent updates have incompatible bounds
660
671
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -665,7 +676,7 @@ void translateComprehension(
665
676
666
677
for (Var v : lhs) {
667
678
if (!solution.contains (v.name ())) {
668
- throw lang::ErrorReport (c )
679
+ throw lang::ErrorReport (comprehension )
669
680
<< " Free variable " << v
670
681
<< " was not solved in range inference. May not be used right-hand side" ;
671
682
}
@@ -689,7 +700,7 @@ void translateComprehension(
689
700
for (size_t i = 0 ; i < unbound.size (); i++) {
690
701
auto v = unbound[unbound.size () - 1 - i];
691
702
if (!solution.contains (v->name )) {
692
- throw lang::ErrorReport (c )
703
+ throw lang::ErrorReport (comprehension )
693
704
<< " Free variable " << v << " is unconstrained. "
694
705
<< " Use a 'where' clause to set its range." ;
695
706
}
@@ -737,6 +748,7 @@ void translateComprehension(
737
748
stage.reorder (loop_nest);
738
749
}
739
750
751
+ // translate a semantically checked TC def to Halide components struct
740
752
HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
741
753
map<string, Function> funcs;
742
754
HalideComponents components;
@@ -956,6 +968,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
956
968
lang::Def (lang::Sema ().checkFunction (treeRef)), throwWarnings);
957
969
}
958
970
971
+ // NOTE: there is no guarantee here that the tc string has only one def. It
972
+ // could have many defs. Only first def will be converted in that case.
959
973
HalideComponents
960
974
translate (isl::ctx ctx, const std::string& tc, bool throwWarnings) {
961
975
LOG_IF (INFO, tc::FLAGS_debug_halide) << tc;
0 commit comments