@@ -33,6 +33,21 @@ using std::vector;
33
33
34
34
namespace {
35
35
36
+ using FunctionBounds = map<Function, map<string, Interval>, Function::Compare>;
37
+
38
+ struct TranslationUnit {
39
+ HalideComponents components;
40
+ Scope<Interval> enclosingLoopIndices;
41
+ map<string, Function> funcs;
42
+ FunctionBounds bounds;
43
+ bool throwWarnings;
44
+ };
45
+
46
+ void translateFor (const lang::For& f, TranslationUnit* tu);
47
+ void translateComprehension (
48
+ const lang::Comprehension& comprehension,
49
+ TranslationUnit* tu);
50
+
36
51
Type translateScalarType (int tcType) {
37
52
switch (tcType) {
38
53
case lang::TK_BOOL:
@@ -264,8 +279,6 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
264
279
return finder.result ;
265
280
}
266
281
267
- typedef map<Function, map<string, Interval>, Function::Compare> FunctionBounds;
268
-
269
282
void forwardBoundsInference (
270
283
const std::vector<Expr>& exprs,
271
284
const FunctionBounds& bounds,
@@ -500,6 +513,29 @@ Expr reductionUpdate(Expr e) {
500
513
return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
501
514
}
502
515
516
+ void translateStatement (const lang::TreeRef& stmt, TranslationUnit* tu) {
517
+ if (stmt->kind () == lang::TK_COMPREHENSION) {
518
+ translateComprehension (lang::Comprehension (stmt), tu);
519
+ } else {
520
+ CHECK_EQ (stmt->kind (), lang::TK_FOR);
521
+ translateFor (lang::For (stmt), tu);
522
+ }
523
+ }
524
+
525
+ void translateFor (const lang::For& f, TranslationUnit* pTU) {
526
+ const map<string, Parameter>& params = pTU->components .params ;
527
+ auto constraint = lang::RangeConstraint (f.rangeConstraint ());
528
+ Interval i;
529
+ const map<string, Expr> lets;
530
+ i.min = translateExpr (constraint.start (), params, pTU->funcs , lets);
531
+ i.max = translateExpr (constraint.end (), params, pTU->funcs , lets) - 1 ;
532
+ pTU->enclosingLoopIndices .push (f.index ().name (), i);
533
+ for (auto stm : f.statements ()) {
534
+ translateStatement (stm, pTU);
535
+ }
536
+ pTU->enclosingLoopIndices .pop (f.index ().name ());
537
+ }
538
+
503
539
// Translate a single TC comprehension/statement to Halide components: funcs,
504
540
// bounds, reductions.
505
541
//
@@ -508,10 +544,11 @@ Expr reductionUpdate(Expr e) {
508
544
// in order to be able to apply internal Halide analysis passes on them.
509
545
void translateComprehension (
510
546
const lang::Comprehension& comprehension,
511
- const map<string, Parameter>& params,
512
- bool throwWarnings,
513
- map<string, Function>* funcs,
514
- FunctionBounds* bounds) {
547
+ TranslationUnit* pTU) {
548
+ const map<string, Parameter>& params = pTU->components .params ;
549
+ bool throwWarnings = pTU->throwWarnings ;
550
+ map<string, Function>* funcs = &pTU->funcs ;
551
+ FunctionBounds* bounds = &pTU->bounds ;
515
552
Function f;
516
553
auto it = funcs->find (comprehension.ident ().name ());
517
554
if (it != funcs->end ()) {
@@ -647,6 +684,12 @@ void translateComprehension(
647
684
// demand).
648
685
Scope<Interval> solution;
649
686
687
+ // Copy information from enclosing "for" loops
688
+ for (auto entry = pTU->enclosingLoopIndices .cbegin ();
689
+ entry != pTU->enclosingLoopIndices .cend ();
690
+ ++entry) {
691
+ solution.push (entry.name (), entry.value ());
692
+ }
650
693
// Put anything explicitly specified with a 'where' class in the solution
651
694
for (auto constraint_ : comprehension.whereClauses ()) {
652
695
if (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
@@ -656,6 +699,11 @@ void translateComprehension(
656
699
i.min = translateExpr (constraint.start (), params, *funcs, lets);
657
700
i.max = translateExpr (constraint.end (), params, *funcs, lets) - 1 ;
658
701
702
+ if (solution.contains (constraint.ident ().name ())) {
703
+ throw lang::ErrorReport (constraint_)
704
+ << " Multiple range constraints per index NYI" ;
705
+ }
706
+
659
707
// TODO: In the future we'll want to make any non-trivial bounds
660
708
// into hidden scalar parameters, and just pass variables to the
661
709
// polyhedral layer instead of potentially complex
@@ -755,25 +803,20 @@ void translateComprehension(
755
803
756
804
// Translate a semantically checked TC def to HalideComponents struct.
757
805
HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
758
- map<string, Function> funcs;
759
- HalideComponents components;
760
- components.def = def;
761
- FunctionBounds bounds;
806
+ TranslationUnit tu;
807
+ tu.components .def = def;
808
+ tu.throwWarnings = throwWarnings;
762
809
763
810
for (auto p : def.params ()) {
764
- translateParam (p, &components.params , &components.inputs );
811
+ translateParam (p, &tu. components .params , &tu. components .inputs );
765
812
}
766
- for (auto c : def.statements ()) {
767
- translateComprehension (
768
- lang::Comprehension (c),
769
- components.params ,
770
- throwWarnings,
771
- &funcs,
772
- &bounds);
813
+ // Semantically valid TCs include at most one outer sequential loop for now
814
+ for (auto stm : def.statements ()) {
815
+ translateStatement (stm, &tu);
773
816
}
774
817
vector<Function> outputs;
775
818
for (auto p : def.returns ()) {
776
- translateOutput (p, funcs, &outputs);
819
+ translateOutput (p, tu. funcs , &outputs);
777
820
}
778
821
779
822
// Now apply an extremely simplified version of Halide lowering
@@ -804,11 +847,12 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
804
847
// used in the pipelines we construct here, so just make a host target.
805
848
Target target (" host" );
806
849
Stmt s = schedule_functions (outputs, fused_groups, env, target, any_memoized);
850
+ LOG_IF (ERROR, tc::FLAGS_debug_halide) << s;
807
851
// we insert these to allow for inplace mutation of in/out tensors
808
852
s = remove_undef (s);
809
853
// Apply forward bounds inference results. This replaces the usual Halide
810
854
// bounds inference.
811
- for (auto p : bounds) {
855
+ for (auto p : tu. bounds ) {
812
856
const Function& f = p.first ;
813
857
for (auto b : p.second ) {
814
858
const string& var = b.first ;
@@ -893,20 +937,20 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
893
937
};
894
938
s = SubstituteAllLets ().mutate (s);
895
939
896
- components.stmt = s;
940
+ tu. components .stmt = s;
897
941
898
942
for (Function f : outputs) {
899
943
OutputImageParam o = Func (f).output_buffers ()[0 ];
900
944
// Apply forward bounds inference results to the output buffers.
901
- const auto & b = bounds[f];
945
+ const auto & b = tu. bounds [f];
902
946
for (int i = 0 ; i < o.dimensions (); i++) {
903
947
const Interval& bound = b.at (f.args ()[i]);
904
948
o.dim (i).set_bounds (bound.min , simplify (bound.max - bound.min + 1 ));
905
949
}
906
- components.outputs .push_back (o);
950
+ tu. components .outputs .push_back (o);
907
951
}
908
952
909
- return components;
953
+ return tu. components ;
910
954
}
911
955
} // namespace
912
956
0 commit comments