Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 1ee5f54

Browse files
committed
halide2isl: support min/max in upper/lower bounds
Loop bounds inferred by Halide may involve min or max operations, which are not convertible into affine functions. For the sake of bound computation, compute lists of affine functions by recursively flattening arguments of nested min and max operations. Use all affine functions from a list to define loop bounds. In particular, allow min in upper bounds and max in lower bounds to exploit the equivalences x < min(a, min(b, c)) <=> x < a and x < b and x < c, and x > max(a, max(b, c)) <=> x > a and x > b and x > c. Do not allow other cases.
1 parent a96a0e0 commit 1ee5f54

File tree

2 files changed

+129
-22
lines changed

2 files changed

+129
-22
lines changed

src/core/halide2isl.cc

Lines changed: 128 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,44 +82,141 @@ isl::aff makeIslAffFromInt(isl::space space, int64_t val) {
8282
return isl::aff(isl::local_space(space), v);
8383
}
8484

85-
isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) {
85+
std::vector<isl::aff> makeIslAffBoundsFromExpr(
86+
isl::space space,
87+
const Expr& e,
88+
bool allowMin,
89+
bool allowMax);
90+
91+
namespace {
92+
/*
93+
* Convert Halide binary expression "op" into a list of isl affine functions by
94+
* converting its LHS and RHS into lists of affs and concatenating those lists.
95+
* This is intended to be used with Min/Max operations in upper/lower bound
96+
* computations, respectively. Essentially, this allows for replacements
97+
* x < min(a,min(b,c)) <=> x < a AND x < b AND x < c
98+
* x > max(a,max(b,c)) <=> x > a AND x > b AND x > c
99+
*/
100+
template <typename T>
101+
inline std::vector<isl::aff>
102+
concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
103+
std::vector<isl::aff> result;
104+
105+
for (const auto& aff :
106+
makeIslAffBoundsFromExpr(space, op->a, allowMin, allowMax)) {
107+
result.push_back(aff);
108+
}
109+
for (const auto& aff :
110+
makeIslAffBoundsFromExpr(space, op->b, allowMin, allowMax)) {
111+
result.push_back(aff);
112+
}
113+
114+
return result;
115+
}
116+
117+
/*
118+
* Convert Halide binary expression "op" into an isl affine function by
119+
* converting its LHS and RHS into affs and combining them with "combine"
120+
* into a single expression. LHS and RHS are expected to only produce a single
121+
* expression.
122+
* This is intended for use with operations other than Min/Max that do not
123+
* commute nicely in bounds, for example
124+
* x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
125+
*/
126+
template <typename T>
127+
inline isl::aff combineSingleAffs(
128+
isl::space space,
129+
T op,
130+
isl::aff (isl::aff::*combine)(isl::aff) const) {
131+
auto left = makeIslAffBoundsFromExpr(space, op->a, false, false);
132+
auto right = makeIslAffBoundsFromExpr(space, op->b, false, false);
133+
CHECK_EQ(left.size(), 1);
134+
CHECK_EQ(right.size(), 1);
135+
136+
return (left[0].*combine)(right[0]);
137+
}
138+
139+
} // end namespace
140+
141+
/*
142+
* Convert Halide expression into list of isl affine expressions usable for
143+
* defining constraints. In particular, an expression starting with (nested)
144+
* Max operations can be used for lower bounds
145+
* x > max(a,b) <=> x > a AND x > b
146+
* while an expression starting with (nested) Min operations can be used for
147+
* upper bounds
148+
* x < min(a,b) <=> x < a AND x < b.
149+
* Arguments "allowMin" and "allowMax" control whether Min and Max operations,
150+
* respectively, are allowed to be present in the expression. Note that they
151+
* can only appear before any other operation and cannot appear together in an
152+
* expression.
153+
* If a Halide expression cannot be converted into a list of affine expressions,
154+
* return an empty list.
155+
*/
156+
std::vector<isl::aff> makeIslAffBoundsFromExpr(
157+
isl::space space,
158+
const Expr& e,
159+
bool allowMin,
160+
bool allowMax) {
161+
CHECK(!(allowMin && allowMax));
162+
163+
const Min* minOp = e.as<Min>();
164+
const Max* maxOp = e.as<Max>();
165+
86166
if (const Variable* op = e.as<Variable>()) {
87167
isl::local_space ls = isl::local_space(space);
88168
int pos = space.find_dim_by_name(isl::dim_type::param, op->name);
89169
if (pos >= 0) {
90-
return isl::aff(ls, isl::dim_type::param, pos);
170+
return {isl::aff(ls, isl::dim_type::param, pos)};
91171
} else {
172+
// FIXME: thou shalt not rely upon set dimension names
92173
pos = space.find_dim_by_name(isl::dim_type::set, op->name);
93174
if (pos >= 0) {
94-
return isl::aff(ls, isl::dim_type::set, pos);
175+
return {isl::aff(ls, isl::dim_type::set, pos)};
95176
}
96177
}
97178
LOG(FATAL) << "Variable not found in isl::space: " << space << ": " << op
98179
<< ": " << op->name << '\n';
99-
return isl::aff();
180+
return {};
181+
} else if (minOp != nullptr && allowMin) {
182+
return concatAffs(space, minOp, allowMin, allowMax);
183+
} else if (maxOp != nullptr && allowMax) {
184+
return concatAffs(space, maxOp, allowMin, allowMax);
100185
} else if (const Add* op = e.as<Add>()) {
101-
return makeIslAffFromExpr(space, op->a)
102-
.add(makeIslAffFromExpr(space, op->b));
186+
return {combineSingleAffs(space, op, &isl::aff::add)};
103187
} else if (const Sub* op = e.as<Sub>()) {
104-
return makeIslAffFromExpr(space, op->a)
105-
.sub(makeIslAffFromExpr(space, op->b));
188+
return {combineSingleAffs(space, op, &isl::aff::sub)};
106189
} else if (const Mul* op = e.as<Mul>()) {
107-
return makeIslAffFromExpr(space, op->a)
108-
.mul(makeIslAffFromExpr(space, op->b));
190+
return {combineSingleAffs(space, op, &isl::aff::mul)};
109191
} else if (const Div* op = e.as<Div>()) {
110-
return makeIslAffFromExpr(space, op->a)
111-
.div(makeIslAffFromExpr(space, op->b));
192+
return {combineSingleAffs(space, op, &isl::aff::div)};
112193
} else if (const Mod* op = e.as<Mod>()) {
194+
std::vector<isl::aff> result;
195+
// We cannot span multiple constraints if a modulo operation is involved.
196+
// x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
197+
auto lhs = makeIslAffBoundsFromExpr(space, e, false, false);
198+
CHECK_EQ(lhs.size(), 1);
113199
if (const int64_t* b = as_const_int(op->b)) {
114-
return makeIslAffFromExpr(space, op->a)
115-
.mod(isl::val(space.get_ctx(), *b));
200+
return {lhs[0].mod(isl::val(space.get_ctx(), *b))};
116201
}
117202
} else if (const int64_t* i = as_const_int(e)) {
118-
return makeIslAffFromInt(space, *i);
203+
return {makeIslAffFromInt(space, *i)};
119204
}
120205

206+
return {};
207+
}
208+
209+
isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) {
210+
auto list = makeIslAffBoundsFromExpr(space, e, false, false);
211+
CHECK_LE(list.size(), 1) << "Halide expr " << e
212+
<< " unrolled into more than 1 isl aff"
213+
<< " but min/max operations were disabled";
214+
121215
// Non-affine
122-
return isl::aff();
216+
if (list.size() == 0) {
217+
return isl::aff();
218+
}
219+
return list[0];
123220
}
124221

125222
isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable) {
@@ -247,10 +344,22 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
247344
isl::local_space(set.get_space()), isl::dim_type::set, thisLoopIdx);
248345

249346
// Then we add our new loop bound constraints.
250-
isl::aff lb = halide2isl::makeIslAffFromExpr(set.get_space(), op->min);
347+
auto lbs = halide2isl::makeIslAffBoundsFromExpr(
348+
set.get_space(), op->min, false, true);
349+
CHECK_GT(lbs.size(), 0)
350+
<< "could not obtain polyhedral lower bounds from " << op->min;
351+
for (auto lb : lbs) {
352+
set = set.intersect(loopVar.ge_set(lb));
353+
}
354+
251355
Expr max = simplify(op->min + op->extent - 1);
252-
isl::aff ub = halide2isl::makeIslAffFromExpr(set.get_space(), max);
253-
set = set.intersect(loopVar.ge_set(lb).intersect(ub.ge_set(loopVar)));
356+
auto ubs =
357+
halide2isl::makeIslAffBoundsFromExpr(set.get_space(), max, true, false);
358+
CHECK_GT(ubs.size(), 0)
359+
<< "could not obtain polyhedral upper bounds from " << max;
360+
for (auto ub : ubs) {
361+
set = set.intersect(ub.ge_set(loopVar));
362+
}
254363

255364
// Recursively descend.
256365
auto body = makeScheduleTreeHelper(

test/test_tc_mapper_bugs.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,6 @@ TEST(LayerNorm, ReferenceBelongsToTwoGroups) {
696696
atCompl.compile("layernorm", inputs, options);
697697
}
698698

699-
// #124
700699
TEST(Halide2Isl, MinInUpperBound) {
701700
at::Tensor mat1 = at::CUDA(at::kFloat).rand({1, 100, 184, 184});
702701
at::Tensor mat1_pad = at::CUDA(at::kFloat).rand({1, 100, 186, 186});
@@ -713,8 +712,7 @@ TEST(Halide2Isl, MinInUpperBound) {
713712

714713
tc::ATenCompilationUnit atCompl;
715714
atCompl.define(TC);
716-
EXPECT_THROW(
717-
atCompl.compile("graph2", inputs, options), isl::exception_invalid);
715+
atCompl.compile("graph2", inputs, options);
718716
}
719717

720718
int main(int argc, char** argv) {

0 commit comments

Comments
 (0)