Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1f60110

Browse files
[Broken][WIP] Add Halide support for "for" loops
Disclaimer: this is WIP and does not work yet. In particular the ScheduleTree generated is incorrect (not yet properly scoped under the for loop). Additionally, the example `For4InvalidHalide` seems to hit deeper structural issues of the high-level HalideIR that traditional polyhedral dependence analysis has no issue with (cc @abadams). This commit tests that the for-loop language construct properly propagates to the Halide bounds inference. This commit also explicitly adds checks that only a single RangeConstraint (`index in min:max`) is used for each index. This requirement is added because it would be very surprising to mix explicit ranges coming from for loops with where clauses in a comprehension. In particular, the current behavior in Halide inference is to only keep the last RangeConstraint which systematically discards the information specified in the loop. `tc2halide` even says: ```// TODO: What if subsequent updates have incompatible bounds // (e.g. an in-place stencil)?. The .bound directive will use the // bounds of the last stage for all stages. ``` As a consequence we add an explicit check that multiple bounds specifications may not coexist.
1 parent b7359f4 commit 1f60110

File tree

3 files changed

+211
-26
lines changed

3 files changed

+211
-26
lines changed

tc/core/tc2halide.cc

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,21 @@ using std::vector;
3333

3434
namespace {
3535

36+
using FunctionBounds = map<Function, map<string, Interval>, Function::Compare>;
37+
38+
struct TranslationUnit {
39+
HalideComponents components;
40+
Scope<Interval> enclosingLoopIndices;
41+
map<string, Function> funcs;
42+
FunctionBounds bounds;
43+
bool throwWarnings;
44+
};
45+
46+
void translateFor(const lang::For& f, TranslationUnit* tu);
47+
void translateComprehension(
48+
const lang::Comprehension& comprehension,
49+
TranslationUnit* tu);
50+
3651
Type translateScalarType(int tcType) {
3752
switch (tcType) {
3853
case lang::TK_BOOL:
@@ -264,8 +279,6 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
264279
return finder.result;
265280
}
266281

267-
typedef map<Function, map<string, Interval>, Function::Compare> FunctionBounds;
268-
269282
void forwardBoundsInference(
270283
const std::vector<Expr>& exprs,
271284
const FunctionBounds& bounds,
@@ -500,6 +513,29 @@ Expr reductionUpdate(Expr e) {
500513
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
501514
}
502515

516+
// Dispatch one TC statement to its translator: comprehensions go to
// translateComprehension, everything else must be a "for" loop and goes to
// translateFor.
void translateStatement(const lang::TreeRef& stmt, TranslationUnit* tu) {
  const bool isComprehension = stmt->kind() == lang::TK_COMPREHENSION;
  if (isComprehension) {
    translateComprehension(lang::Comprehension(stmt), tu);
    return;
  }
  // Any other statement kind is unexpected here; fail the check.
  CHECK_EQ(stmt->kind(), lang::TK_FOR);
  translateFor(lang::For(stmt), tu);
}
524+
525+
// Translate a TC "for" loop: record the loop index range in
// pTU->enclosingLoopIndices, translate every statement of the loop body with
// that range visible, then drop the range again.
void translateFor(const lang::For& f, TranslationUnit* pTU) {
  const map<string, Parameter>& tensorParams = pTU->components.params;
  // The range expressions are translated with no let bindings in scope.
  const map<string, Expr> noLets;
  auto range = lang::RangeConstraint(f.rangeConstraint());

  // Interval bounds are inclusive, hence the "- 1" on the translated end
  // expression of the range.
  Interval loopRange;
  loopRange.min = translateExpr(range.start(), tensorParams, pTU->funcs, noLets);
  loopRange.max =
      translateExpr(range.end(), tensorParams, pTU->funcs, noLets) - 1;

  const string indexName = f.index().name();
  pTU->enclosingLoopIndices.push(indexName, loopRange);
  for (auto bodyStatement : f.statements()) {
    translateStatement(bodyStatement, pTU);
  }
  pTU->enclosingLoopIndices.pop(indexName);
}
538+
503539
// Translate a single TC comprehension/statement to Halide components: funcs,
504540
// bounds, reductions.
505541
//
@@ -508,10 +544,11 @@ Expr reductionUpdate(Expr e) {
508544
// in order to be able to apply internal Halide analysis passes on them.
509545
void translateComprehension(
510546
const lang::Comprehension& comprehension,
511-
const map<string, Parameter>& params,
512-
bool throwWarnings,
513-
map<string, Function>* funcs,
514-
FunctionBounds* bounds) {
547+
TranslationUnit* pTU) {
548+
const map<string, Parameter>& params = pTU->components.params;
549+
bool throwWarnings = pTU->throwWarnings;
550+
map<string, Function>* funcs = &pTU->funcs;
551+
FunctionBounds* bounds = &pTU->bounds;
515552
Function f;
516553
auto it = funcs->find(comprehension.ident().name());
517554
if (it != funcs->end()) {
@@ -647,6 +684,12 @@ void translateComprehension(
647684
// demand).
648685
Scope<Interval> solution;
649686

687+
// Copy information from enclosing "for" loops
688+
for (auto entry = pTU->enclosingLoopIndices.cbegin();
689+
entry != pTU->enclosingLoopIndices.cend();
690+
++entry) {
691+
solution.push(entry.name(), entry.value());
692+
}
650693
// Put anything explicitly specified with a 'where' clause in the solution
651694
for (auto constraint_ : comprehension.whereClauses()) {
652695
if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT)
@@ -656,6 +699,11 @@ void translateComprehension(
656699
i.min = translateExpr(constraint.start(), params, *funcs, lets);
657700
i.max = translateExpr(constraint.end(), params, *funcs, lets) - 1;
658701

702+
if (solution.contains(constraint.ident().name())) {
703+
throw lang::ErrorReport(constraint_)
704+
<< "Multiple range constraints per index NYI";
705+
}
706+
659707
// TODO: In the future we'll want to make any non-trivial bounds
660708
// into hidden scalar parameters, and just pass variables to the
661709
// polyhedral layer instead of potentially complex
@@ -755,25 +803,20 @@ void translateComprehension(
755803

756804
// Translate a semantically checked TC def to HalideComponents struct.
757805
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
758-
map<string, Function> funcs;
759-
HalideComponents components;
760-
components.def = def;
761-
FunctionBounds bounds;
806+
TranslationUnit tu;
807+
tu.components.def = def;
808+
tu.throwWarnings = throwWarnings;
762809

763810
for (auto p : def.params()) {
764-
translateParam(p, &components.params, &components.inputs);
811+
translateParam(p, &tu.components.params, &tu.components.inputs);
765812
}
766-
for (auto c : def.statements()) {
767-
translateComprehension(
768-
lang::Comprehension(c),
769-
components.params,
770-
throwWarnings,
771-
&funcs,
772-
&bounds);
813+
// Semantically valid TCs include at most one outer sequential loop for now
814+
for (auto stm : def.statements()) {
815+
translateStatement(stm, &tu);
773816
}
774817
vector<Function> outputs;
775818
for (auto p : def.returns()) {
776-
translateOutput(p, funcs, &outputs);
819+
translateOutput(p, tu.funcs, &outputs);
777820
}
778821

779822
// Now apply an extremely simplified version of Halide lowering
@@ -804,11 +847,12 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
804847
// used in the pipelines we construct here, so just make a host target.
805848
Target target("host");
806849
Stmt s = schedule_functions(outputs, fused_groups, env, target, any_memoized);
850+
LOG_IF(ERROR, tc::FLAGS_debug_halide) << s;
807851
// we insert these to allow for inplace mutation of in/out tensors
808852
s = remove_undef(s);
809853
// Apply forward bounds inference results. This replaces the usual Halide
810854
// bounds inference.
811-
for (auto p : bounds) {
855+
for (auto p : tu.bounds) {
812856
const Function& f = p.first;
813857
for (auto b : p.second) {
814858
const string& var = b.first;
@@ -893,20 +937,20 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
893937
};
894938
s = SubstituteAllLets().mutate(s);
895939

896-
components.stmt = s;
940+
tu.components.stmt = s;
897941

898942
for (Function f : outputs) {
899943
OutputImageParam o = Func(f).output_buffers()[0];
900944
// Apply forward bounds inference results to the output buffers.
901-
const auto& b = bounds[f];
945+
const auto& b = tu.bounds[f];
902946
for (int i = 0; i < o.dimensions(); i++) {
903947
const Interval& bound = b.at(f.args()[i]);
904948
o.dim(i).set_bounds(bound.min, simplify(bound.max - bound.min + 1));
905949
}
906-
components.outputs.push_back(o);
950+
tu.components.outputs.push_back(o);
907951
}
908952

909-
return components;
953+
return tu.components;
910954
}
911955
} // namespace
912956

test/test_inference.cc

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ using namespace std;
2424
using namespace lang;
2525

2626
struct InferenceTest : public ::testing::Test {
27-
void Check(const string& tc, const string& expected) {
27+
void Check(const string& tc, const string& expected, bool warn = true) {
2828
auto halideComponents =
29-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tc, true);
29+
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tc, warn);
3030

3131
stringstream ss;
3232
// Ordered map for repro
@@ -529,6 +529,94 @@ def fun(float(N) size) -> (O) {
529529
EXPECT_THROW(Check(tc, {}), ::lang::ErrorReport);
530530
}
531531

532+
TEST_F(InferenceTest, For1) {
533+
string tc = R"TC(
534+
def fun(float(M, N) X) -> (R1) {
535+
for t in 0:123 {
536+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0
537+
}
538+
}
539+
)TC";
540+
Check(tc, R"HALIDE(mins:
541+
X@[0; 0; ]
542+
R1@[0; 0; 0; ]
543+
extents:
544+
X@[M; N; ]
545+
R1@[123; M; N; ]
546+
)HALIDE");
547+
}
548+
549+
TEST_F(InferenceTest, For2) {
550+
string tc = R"TC(
551+
def fun(float(M, N) X, float(T) Meta) -> (R1, R2) {
552+
for t in 0:T {
553+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0
554+
R2(t, m, n) = R1(t, m, n)
555+
}
556+
}
557+
)TC";
558+
Check(
559+
tc,
560+
R"HALIDE(mins:
561+
X@[0; 0; ]
562+
Meta@[0; ]
563+
R1@[0; 0; 0; ]
564+
R2@[0; 0; 0; ]
565+
extents:
566+
X@[M; N; ]
567+
Meta@[T; ]
568+
R1@[T; M; N; ]
569+
R2@[T; M; N; ]
570+
)HALIDE");
571+
}
572+
573+
TEST_F(InferenceTest, For3) {
574+
string tc = R"TC(
575+
def fun(float(M, N) X, float(T) Meta) -> (R1, R2) {
576+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0 where t in 0:T
577+
R2(t, m, n) = 0.0 where t in 0:T, m in 0:M, n in 0:N
578+
for t in 1:T {
579+
R1(t, m, n) += R1(t-1, m, n)
580+
R2(t, m, n) += R1(t, m, n)
581+
}
582+
}
583+
)TC";
584+
Check(
585+
tc,
586+
R"HALIDE(mins:
587+
X@[0; 0; ]
588+
Meta@[0; ]
589+
R1@[1; 0; 0; ]
590+
R2@[1; 0; 0; ]
591+
extents:
592+
X@[M; N; ]
593+
Meta@[T; ]
594+
R1@[(T + -1); M; N; ]
595+
R2@[(T + -1); M; N; ]
596+
)HALIDE",
597+
false); // nowarn
598+
}
599+
600+
TEST_F(InferenceTest, MultipleRangeConstraints) {
601+
string tc = R"TC(
602+
def fun(float(N, M) I) -> (O) {
603+
O(i, j) = I(i, j) where i in 3:N-1, i in 1:N
604+
}
605+
)TC";
606+
EXPECT_THROW(Check(tc, {}), ::lang::ErrorReport);
607+
}
608+
609+
TEST_F(InferenceTest, MultipleRangeConstraintsFor) {
610+
string tc = R"TC(
611+
def fun(float(M, N) X) -> (R1) {
612+
for t in 0:123 {
613+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0 where t in 0:N
614+
}
615+
}
616+
)TC";
617+
EXPECT_THROW(Check(tc, {}), ::lang::ErrorReport);
618+
}
619+
532620
int main(int argc, char** argv) {
533621
::testing::InitGoogleTest(&argc, argv);
534622
::gflags::ParseCommandLineFlags(&argc, &argv, true);

test/test_tc2halide.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ struct TC2Isl : public ::testing::Test {
3838
isl::with_exceptions::globalIslCtx(), halide);
3939
auto scheduleHalide = scop->scheduleRoot();
4040

41+
LOG_IF(ERROR, FLAGS_debug_tc_mapper) << *scheduleHalide;
42+
4143
polyhedral::detail::validateSchedule(scheduleHalide);
4244
}
4345
};
@@ -198,6 +200,57 @@ def foo(float(N) A) -> (B) {
198200
EXPECT_THROW(Check(tc), ::lang::ErrorReport);
199201
}
200202

203+
TEST_F(TC2Isl, For1) {
204+
string tc = R"TC(
205+
def fun(float(M, N) X) -> (R1) {
206+
for t in 0:123 {
207+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0
208+
}
209+
}
210+
)TC";
211+
Check(tc);
212+
}
213+
214+
TEST_F(TC2Isl, For2) {
215+
string tc = R"TC(
216+
def fun(float(M, N) X, float(T) Meta) -> (R1, R2) {
217+
for t in 0:T {
218+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0
219+
R2(t, m, n) = R1(t, m, n)
220+
}
221+
}
222+
)TC";
223+
Check(tc);
224+
}
225+
226+
TEST_F(TC2Isl, For3) {
227+
string tc = R"TC(
228+
def fun(float(M, N) X, float(T) Meta) -> (R1, R2) {
229+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0 where t in 0:T
230+
R2(t, m, n) = 0.0 where t in 0:T, m in 0:M, n in 0:N
231+
for t in 1:T {
232+
R1(t, m, n) += R1(t-1, m, n)
233+
R2(t, m, n) += R1(t, m, n)
234+
}
235+
}
236+
)TC";
237+
Check(tc);
238+
}
239+
240+
TEST_F(TC2Isl, For4InvalidHalide) {
241+
string tc = R"TC(
242+
def fun(float(M, N) X, float(T) Meta) -> (R1, R2) {
243+
R1(t, m, n) = (t == 0) ? X(m, n) : 0.0 where t in 0:T
244+
R2(t, m, n) = 0.0 where t in 0:T, m in 0:M, n in 0:N
245+
for t in 1:T {
246+
R1(t, m, n) += R1(t-1, m, n) + R2(t-1, m, n)
247+
R2(t, m, n) += R1(t, m, n)
248+
}
249+
}
250+
)TC";
251+
EXPECT_THROW(Check(tc), std::exception);
252+
}
253+
201254
TEST_F(TC2Isl, Types) {
202255
for (auto type : {"bool",
203256
"uint8",

0 commit comments

Comments
 (0)