Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit dbb909c

Browse files
authored
Merge pull request #127 from facebookresearch/support-min-max-bounds
[hotfix] Support min max bounds
2 parents cfbf098 + 1ee5f54 commit dbb909c

File tree

2 files changed

+147
-19
lines changed

2 files changed

+147
-19
lines changed

src/core/halide2isl.cc

Lines changed: 128 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,44 +82,141 @@ isl::aff makeIslAffFromInt(isl::space space, int64_t val) {
8282
return isl::aff(isl::local_space(space), v);
8383
}
8484

85-
isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) {
85+
std::vector<isl::aff> makeIslAffBoundsFromExpr(
86+
isl::space space,
87+
const Expr& e,
88+
bool allowMin,
89+
bool allowMax);
90+
91+
namespace {
92+
/*
93+
* Convert Halide binary expression "op" into a list of isl affine functions by
94+
* converting its LHS and RHS into lists of affs and concatenating those lists.
95+
* This is intended to be used with Min/Max operations in upper/lower bound
96+
* computations, respectively. Essentially, this allows for replacements
97+
* x < min(a,min(b,c)) <=> x < a AND x < b AND x < c
98+
* x > max(a,max(b,c)) <=> x > a AND x > b AND x > c
99+
*/
100+
template <typename T>
101+
inline std::vector<isl::aff>
102+
concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
103+
std::vector<isl::aff> result;
104+
105+
for (const auto& aff :
106+
makeIslAffBoundsFromExpr(space, op->a, allowMin, allowMax)) {
107+
result.push_back(aff);
108+
}
109+
for (const auto& aff :
110+
makeIslAffBoundsFromExpr(space, op->b, allowMin, allowMax)) {
111+
result.push_back(aff);
112+
}
113+
114+
return result;
115+
}
116+
117+
/*
118+
* Convert Halide binary expression "op" into an isl affine function by
119+
* converting its LHS and RHS into affs and combining them with "combine"
120+
* into a single expression. LHS and RHS are expected to only produce a single
121+
* expression.
122+
* This is intended for use with operations other than Min/Max that do not
123+
* commute nicely in bounds, for example
124+
* x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
125+
*/
126+
template <typename T>
127+
inline isl::aff combineSingleAffs(
128+
isl::space space,
129+
T op,
130+
isl::aff (isl::aff::*combine)(isl::aff) const) {
131+
auto left = makeIslAffBoundsFromExpr(space, op->a, false, false);
132+
auto right = makeIslAffBoundsFromExpr(space, op->b, false, false);
133+
CHECK_EQ(left.size(), 1);
134+
CHECK_EQ(right.size(), 1);
135+
136+
return (left[0].*combine)(right[0]);
137+
}
138+
139+
} // end namespace
140+
141+
/*
142+
* Convert Halide expression into list of isl affine expressions usable for
143+
* defining constraints. In particular, an expression starting with (nested)
144+
* Max operations can be used for lower bounds
145+
* x > max(a,b) <=> x > a AND x > b
146+
* while an expression starting with (nested) Min operations can be used for
147+
* upper bounds
148+
* x < min(a,b) <=> x < a AND x < b.
149+
* Arguments "allowMin" and "allowMax" control whether Min and Max operations,
150+
* respecitvely, are allowed to be present in the expression. Note that they
151+
* can only appear before any other operation and cannot appear together in an
152+
* expression.
153+
* If a Halide expression cannot be converted into a list of affine expressions,
154+
* return an empty list.
155+
*/
156+
std::vector<isl::aff> makeIslAffBoundsFromExpr(
157+
isl::space space,
158+
const Expr& e,
159+
bool allowMin,
160+
bool allowMax) {
161+
CHECK(!(allowMin && allowMax));
162+
163+
const Min* minOp = e.as<Min>();
164+
const Max* maxOp = e.as<Max>();
165+
86166
if (const Variable* op = e.as<Variable>()) {
87167
isl::local_space ls = isl::local_space(space);
88168
int pos = space.find_dim_by_name(isl::dim_type::param, op->name);
89169
if (pos >= 0) {
90-
return isl::aff(ls, isl::dim_type::param, pos);
170+
return {isl::aff(ls, isl::dim_type::param, pos)};
91171
} else {
172+
// FIXME: thou shalt not rely upon set dimension names
92173
pos = space.find_dim_by_name(isl::dim_type::set, op->name);
93174
if (pos >= 0) {
94-
return isl::aff(ls, isl::dim_type::set, pos);
175+
return {isl::aff(ls, isl::dim_type::set, pos)};
95176
}
96177
}
97178
LOG(FATAL) << "Variable not found in isl::space: " << space << ": " << op
98179
<< ": " << op->name << '\n';
99-
return isl::aff();
180+
return {};
181+
} else if (minOp != nullptr && allowMin) {
182+
return concatAffs(space, minOp, allowMin, allowMax);
183+
} else if (maxOp != nullptr && allowMax) {
184+
return concatAffs(space, maxOp, allowMin, allowMax);
100185
} else if (const Add* op = e.as<Add>()) {
101-
return makeIslAffFromExpr(space, op->a)
102-
.add(makeIslAffFromExpr(space, op->b));
186+
return {combineSingleAffs(space, op, &isl::aff::add)};
103187
} else if (const Sub* op = e.as<Sub>()) {
104-
return makeIslAffFromExpr(space, op->a)
105-
.sub(makeIslAffFromExpr(space, op->b));
188+
return {combineSingleAffs(space, op, &isl::aff::sub)};
106189
} else if (const Mul* op = e.as<Mul>()) {
107-
return makeIslAffFromExpr(space, op->a)
108-
.mul(makeIslAffFromExpr(space, op->b));
190+
return {combineSingleAffs(space, op, &isl::aff::mul)};
109191
} else if (const Div* op = e.as<Div>()) {
110-
return makeIslAffFromExpr(space, op->a)
111-
.div(makeIslAffFromExpr(space, op->b));
192+
return {combineSingleAffs(space, op, &isl::aff::div)};
112193
} else if (const Mod* op = e.as<Mod>()) {
194+
std::vector<isl::aff> result;
195+
// We cannot span multiple constraints if a modulo operation is involved.
196+
// x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
197+
auto lhs = makeIslAffBoundsFromExpr(space, e, false, false);
198+
CHECK_EQ(lhs.size(), 1);
113199
if (const int64_t* b = as_const_int(op->b)) {
114-
return makeIslAffFromExpr(space, op->a)
115-
.mod(isl::val(space.get_ctx(), *b));
200+
return {lhs[0].mod(isl::val(space.get_ctx(), *b))};
116201
}
117202
} else if (const int64_t* i = as_const_int(e)) {
118-
return makeIslAffFromInt(space, *i);
203+
return {makeIslAffFromInt(space, *i)};
119204
}
120205

206+
return {};
207+
}
208+
209+
isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) {
210+
auto list = makeIslAffBoundsFromExpr(space, e, false, false);
211+
CHECK_LE(list.size(), 1) << "Halide expr " << e
212+
<< " unrolled into more than 1 isl aff"
213+
<< " but min/max operations were disabled";
214+
121215
// Non-affine
122-
return isl::aff();
216+
if (list.size() == 0) {
217+
return isl::aff();
218+
}
219+
return list[0];
123220
}
124221

125222
isl::space makeParamSpace(isl::ctx ctx, const SymbolTable& symbolTable) {
@@ -247,10 +344,22 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
247344
isl::local_space(set.get_space()), isl::dim_type::set, thisLoopIdx);
248345

249346
// Then we add our new loop bound constraints.
250-
isl::aff lb = halide2isl::makeIslAffFromExpr(set.get_space(), op->min);
347+
auto lbs = halide2isl::makeIslAffBoundsFromExpr(
348+
set.get_space(), op->min, false, true);
349+
CHECK_GT(lbs.size(), 0)
350+
<< "could not obtain polyhedral lower bounds from " << op->min;
351+
for (auto lb : lbs) {
352+
set = set.intersect(loopVar.ge_set(lb));
353+
}
354+
251355
Expr max = simplify(op->min + op->extent - 1);
252-
isl::aff ub = halide2isl::makeIslAffFromExpr(set.get_space(), max);
253-
set = set.intersect(loopVar.ge_set(lb).intersect(ub.ge_set(loopVar)));
356+
auto ubs =
357+
halide2isl::makeIslAffBoundsFromExpr(set.get_space(), max, true, false);
358+
CHECK_GT(ubs.size(), 0)
359+
<< "could not obtain polyhedral upper bounds from " << max;
360+
for (auto ub : ubs) {
361+
set = set.intersect(ub.ge_set(loopVar));
362+
}
254363

255364
// Recursively descend.
256365
auto body = makeScheduleTreeHelper(

test/test_tc_mapper_bugs.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,25 @@ TEST(LayerNorm, ReferenceBelongsToTwoGroups) {
696696
atCompl.compile("layernorm", inputs, options);
697697
}
698698

699+
TEST(Halide2Isl, MinInUpperBound) {
700+
at::Tensor mat1 = at::CUDA(at::kFloat).rand({1, 100, 184, 184});
701+
at::Tensor mat1_pad = at::CUDA(at::kFloat).rand({1, 100, 186, 186});
702+
at::Tensor mat2 = at::CUDA(at::kFloat).rand({3, 3});
703+
std::vector<at::Tensor> inputs = {mat1, mat1_pad, mat2};
704+
705+
static constexpr auto TC = R"TC(
706+
def graph2(float(N, C, H, W) I, float(N, C, R, T) J, float(KH, KW) W1) -> (O, Out) {
707+
O(n, c, h, w) +=! J(n, c, h + kh, w + kw) * W1(kh, kw)
708+
Out(i, j) +=! I(n, i, h, w) * O(n, j, h, w)
709+
}
710+
)TC";
711+
auto options = tc::MappingOptions::makeNaiveMappingOptions();
712+
713+
tc::ATenCompilationUnit atCompl;
714+
atCompl.define(TC);
715+
atCompl.compile("graph2", inputs, options);
716+
}
717+
699718
int main(int argc, char** argv) {
700719
::testing::InitGoogleTest(&argc, argv);
701720
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)