Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit fd38b60

Browse files
committed
halide2isl: handle non-affine subexpressions
1ee5f54 modified the conversion from Halide Exprs to isl affs to add some basic support for min/max operations. When processing binary operators in Halide Exprs, it did not take into account that one of the operands may be non-affine, resulting in no affine boundary being computed. The binary combinator assumed both LHS and RHS bounds exist. Remove this assumption and remove an empty affine bound if either of LHS or RHS bound expressions is not affine.
1 parent b65f5ef commit fd38b60

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

tc/core/halide2isl.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,23 +117,29 @@ concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
117117
/*
118118
* Convert Halide binary expression "op" into an isl affine function by
119119
* converting its LHS and RHS into affs and combining them with "combine"
120-
* into a single expression. LHS and RHS are expected to only produce a single
121-
* expression.
120+
* into a single expression. LHS and RHS are expected to only produce at most
121+
* one expression. If either of them produces zero expressions, meaning the
122+
* bound is not affine, return an empty vector. Otherwise return a vector with
123+
* a single expression that is the result of applying LHS.combine(RHS).
122124
* This is intended for use with operations other than Min/Max that do not
123125
* commute nicely in bounds, for example
124126
* x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
125127
*/
126128
template <typename T>
127-
inline isl::aff combineSingleAffs(
129+
inline std::vector<isl::aff> combineSingleAffs(
128130
isl::space space,
129131
T op,
130132
isl::aff (isl::aff::*combine)(isl::aff) const) {
131133
auto left = makeIslAffBoundsFromExpr(space, op->a, false, false);
132134
auto right = makeIslAffBoundsFromExpr(space, op->b, false, false);
133-
CHECK_EQ(left.size(), 1u);
134-
CHECK_EQ(right.size(), 1u);
135+
CHECK_LE(left.size(), 1u);
136+
CHECK_LE(right.size(), 1u);
137+
138+
if (left.size() == 0 || right.size() == 0) {
139+
return {};
140+
}
135141

136-
return (left[0].*combine)(right[0]);
142+
return {(left[0].*combine)(right[0])};
137143
}
138144

139145
} // end namespace
@@ -183,13 +189,13 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
183189
} else if (maxOp != nullptr && allowMax) {
184190
return concatAffs(space, maxOp, allowMin, allowMax);
185191
} else if (const Add* op = e.as<Add>()) {
186-
return {combineSingleAffs(space, op, &isl::aff::add)};
192+
return combineSingleAffs(space, op, &isl::aff::add);
187193
} else if (const Sub* op = e.as<Sub>()) {
188-
return {combineSingleAffs(space, op, &isl::aff::sub)};
194+
return combineSingleAffs(space, op, &isl::aff::sub);
189195
} else if (const Mul* op = e.as<Mul>()) {
190-
return {combineSingleAffs(space, op, &isl::aff::mul)};
196+
return combineSingleAffs(space, op, &isl::aff::mul);
191197
} else if (const Div* op = e.as<Div>()) {
192-
return {combineSingleAffs(space, op, &isl::aff::div)};
198+
return combineSingleAffs(space, op, &isl::aff::div);
193199
} else if (const Mod* op = e.as<Mod>()) {
194200
std::vector<isl::aff> result;
195201
// We cannot span multiple constraints if a modulo operation is involved.

test/test_cuda_mapper.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,49 @@ TEST_F(PolyhedralMapperTest, ReductionMM2D) {
710710
EXPECT_TRUE(code.find("C[(t1 + c0)][(t0 + c1)] = (C") != std::string::npos);
711711
}
712712

713+
/*
714+
* Check that a subscript with affine and non-affine parts is handled by the
715+
* Halide to isl conversion, in particular that the conversion does not crash.
716+
*/
717+
TEST_F(PolyhedralMapperTest, NonAffineBoundLHSInBinOp) {
718+
string tc = R"TC(
719+
def shiftedLut(float(E, D) LUT, int32(B, L) I) -> (O) {
720+
O(i, j) +=! LUT(I(i, k) + 1, j)
721+
}
722+
)TC";
723+
// This triggers tc2halide conversion and should not throw.
724+
Prepare(tc);
725+
}
726+
727+
/*
728+
* Check that a subscript with affine and non-affine parts is handled by the
729+
* Halide to isl conversion, in particular that the conversion does not crash.
730+
*/
731+
TEST_F(PolyhedralMapperTest, NonAffineBoundRHSInBinOp) {
732+
string tc = R"TC(
733+
def shiftedLut(float(E, D) LUT, int32(B, L) I) -> (O) {
734+
O(i, j) +=! LUT(1 + j + I(i, k), j)
735+
}
736+
)TC";
737+
// This triggers tc2halide conversion and should not throw.
738+
Prepare(tc);
739+
}
740+
741+
/*
742+
* Check that a subscript with affine and non-affine parts is handled by the
743+
* Halide to isl conversion, in particular that the conversion does not crash.
744+
*/
745+
TEST_F(PolyhedralMapperTest, PerforatedConvolution) {
746+
string tc = R"TC(
747+
def perforatedConvolution(float(N, C, H, W) input, float(M, C, KH, KW) weights,
748+
int32(N, L) index) -> (output) {
749+
output(n, m, l) +=! input(n, c, index(n, l) + kh, index(n, l) + kw) * weights(m, c, kh, kw) where l in 0:L
750+
}
751+
)TC";
752+
// This triggers tc2halide conversion and should not throw.
753+
Prepare(tc);
754+
}
755+
713756
int main(int argc, char** argv) {
714757
::testing::InitGoogleTest(&argc, argv);
715758
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)