Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 77ddb29

Browse files
Merge pull request #378 from facebookresearch/fix-indexing-more
halide2isl: handle non-affine subexpressions
2 parents 0ccddc5 + fd38b60 commit 77ddb29

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

tc/core/halide2isl.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,23 +117,29 @@ concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
117117
/*
118118
* Convert Halide binary expression "op" into an isl affine function by
119119
* converting its LHS and RHS into affs and combining them with "combine"
120-
* into a single expression. LHS and RHS are expected to only produce a single
121-
* expression.
120+
* into a single expression. LHS and RHS are expected to only produce at most
121+
* one expression. If either of them produces zero expressions, meaning the
122+
* bound is not affine, return an empty vector. Otherwise return a vector with
123+
* a single expression that is the result of applying LHS.combine(RHS).
122124
* This is intended for use with operations other than Min/Max that do not
123125
* commute nicely in bounds, for example
124126
* x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
125127
*/
126128
template <typename T>
127-
inline isl::aff combineSingleAffs(
129+
inline std::vector<isl::aff> combineSingleAffs(
128130
isl::space space,
129131
T op,
130132
isl::aff (isl::aff::*combine)(isl::aff) const) {
131133
auto left = makeIslAffBoundsFromExpr(space, op->a, false, false);
132134
auto right = makeIslAffBoundsFromExpr(space, op->b, false, false);
133-
CHECK_EQ(left.size(), 1u);
134-
CHECK_EQ(right.size(), 1u);
135+
CHECK_LE(left.size(), 1u);
136+
CHECK_LE(right.size(), 1u);
137+
138+
if (left.size() == 0 || right.size() == 0) {
139+
return {};
140+
}
135141

136-
return (left[0].*combine)(right[0]);
142+
return {(left[0].*combine)(right[0])};
137143
}
138144

139145
} // end namespace
@@ -183,13 +189,13 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
183189
} else if (maxOp != nullptr && allowMax) {
184190
return concatAffs(space, maxOp, allowMin, allowMax);
185191
} else if (const Add* op = e.as<Add>()) {
186-
return {combineSingleAffs(space, op, &isl::aff::add)};
192+
return combineSingleAffs(space, op, &isl::aff::add);
187193
} else if (const Sub* op = e.as<Sub>()) {
188-
return {combineSingleAffs(space, op, &isl::aff::sub)};
194+
return combineSingleAffs(space, op, &isl::aff::sub);
189195
} else if (const Mul* op = e.as<Mul>()) {
190-
return {combineSingleAffs(space, op, &isl::aff::mul)};
196+
return combineSingleAffs(space, op, &isl::aff::mul);
191197
} else if (const Div* op = e.as<Div>()) {
192-
return {combineSingleAffs(space, op, &isl::aff::div)};
198+
return combineSingleAffs(space, op, &isl::aff::div);
193199
} else if (const Mod* op = e.as<Mod>()) {
194200
std::vector<isl::aff> result;
195201
// We cannot span multiple constraints if a modulo operation is involved.

test/test_cuda_mapper.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,49 @@ TEST_F(PolyhedralMapperTest, ReductionMM2D) {
710710
EXPECT_TRUE(code.find("C[(t1 + c0)][(t0 + c1)] = (C") != std::string::npos);
711711
}
712712

713+
/*
714+
* Check that a subscript with affine and non-affine parts is handled by the
715+
* Halide to isl conversion, in particular that the conversion does not crash.
716+
*/
717+
TEST_F(PolyhedralMapperTest, NonAffineBoundLHSInBinOp) {
718+
string tc = R"TC(
719+
def shiftedLut(float(E, D) LUT, int32(B, L) I) -> (O) {
720+
O(i, j) +=! LUT(I(i, k) + 1, j)
721+
}
722+
)TC";
723+
// This triggers tc2halide conversion and should not throw.
724+
Prepare(tc);
725+
}
726+
727+
/*
728+
* Check that a subscript with affine and non-affine parts is handled by the
729+
* Halide to isl conversion, in particular that the conversion does not crash.
730+
*/
731+
TEST_F(PolyhedralMapperTest, NonAffineBoundRHSInBinOp) {
732+
string tc = R"TC(
733+
def shiftedLut(float(E, D) LUT, int32(B, L) I) -> (O) {
734+
O(i, j) +=! LUT(1 + j + I(i, k), j)
735+
}
736+
)TC";
737+
// This triggers tc2halide conversion and should not throw.
738+
Prepare(tc);
739+
}
740+
741+
/*
742+
* Check that a subscript with affine and non-affine parts is handled by the
743+
* Halide to isl conversion, in particular that the conversion does not crash.
744+
*/
745+
TEST_F(PolyhedralMapperTest, PerforatedConvolution) {
746+
string tc = R"TC(
747+
def perforatedConvolution(float(N, C, H, W) input, float(M, C, KH, KW) weights,
748+
int32(N, L) index) -> (output) {
749+
output(n, m, l) +=! input(n, c, index(n, l) + kh, index(n, l) + kw) * weights(m, c, kh, kw) where l in 0:L
750+
}
751+
)TC";
752+
// This triggers tc2halide conversion and should not throw.
753+
Prepare(tc);
754+
}
755+
713756
int main(int argc, char** argv) {
714757
::testing::InitGoogleTest(&argc, argv);
715758
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)