@@ -82,44 +82,141 @@ isl::aff makeIslAffFromInt(isl::space space, int64_t val) {
82
82
return isl::aff (isl::local_space (space), v);
83
83
}
84
84
85
- isl::aff makeIslAffFromExpr (isl::space space, const Expr& e) {
85
+ std::vector<isl::aff> makeIslAffBoundsFromExpr (
86
+ isl::space space,
87
+ const Expr& e,
88
+ bool allowMin,
89
+ bool allowMax);
90
+
91
+ namespace {
92
+ /*
93
+ * Convert Halide binary expression "op" into a list of isl affine functions by
94
+ * converting its LHS and RHS into lists of affs and concatenating those lists.
95
+ * This is intended to be used with Min/Max operations in upper/lower bound
96
+ * computations, respectively. Essentially, this allows for replacements
97
+ * x < min(a,min(b,c)) <=> x < a AND x < b AND x < c
98
+ * x > max(a,max(b,c)) <=> x > a AND x > b AND x > c
99
+ */
100
+ template <typename T>
101
+ inline std::vector<isl::aff>
102
+ concatAffs (isl::space space, T op, bool allowMin, bool allowMax) {
103
+ std::vector<isl::aff> result;
104
+
105
+ for (const auto & aff :
106
+ makeIslAffBoundsFromExpr (space, op->a , allowMin, allowMax)) {
107
+ result.push_back (aff);
108
+ }
109
+ for (const auto & aff :
110
+ makeIslAffBoundsFromExpr (space, op->b , allowMin, allowMax)) {
111
+ result.push_back (aff);
112
+ }
113
+
114
+ return result;
115
+ }
116
+
117
+ /*
118
+ * Convert Halide binary expression "op" into an isl affine function by
119
+ * converting its LHS and RHS into affs and combining them with "combine"
120
+ * into a single expression. LHS and RHS are expected to only produce a single
121
+ * expression.
122
+ * This is intended for use with operations other than Min/Max that do not
123
+ * commute nicely in bounds, for example
124
+ * x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
125
+ */
126
+ template <typename T>
127
+ inline isl::aff combineSingleAffs (
128
+ isl::space space,
129
+ T op,
130
+ isl::aff (isl::aff::*combine)(isl::aff) const ) {
131
+ auto left = makeIslAffBoundsFromExpr (space, op->a , false , false );
132
+ auto right = makeIslAffBoundsFromExpr (space, op->b , false , false );
133
+ CHECK_EQ (left.size (), 1 );
134
+ CHECK_EQ (right.size (), 1 );
135
+
136
+ return (left[0 ].*combine)(right[0 ]);
137
+ }
138
+
139
+ } // end namespace
140
+
141
+ /*
142
+ * Convert Halide expression into list of isl affine expressions usable for
143
+ * defining constraints. In particular, an expression starting with (nested)
144
+ * Max operations can be used for lower bounds
145
+ * x > max(a,b) <=> x > a AND x > b
146
+ * while an expression starting with (nested) Min operations can be used for
147
+ * upper bounds
148
+ * x < min(a,b) <=> x < a AND x < b.
149
+ * Arguments "allowMin" and "allowMax" control whether Min and Max operations,
150
+ * respecitvely, are allowed to be present in the expression. Note that they
151
+ * can only appear before any other operation and cannot appear together in an
152
+ * expression.
153
+ * If a Halide expression cannot be converted into a list of affine expressions,
154
+ * return an empty list.
155
+ */
156
+ std::vector<isl::aff> makeIslAffBoundsFromExpr (
157
+ isl::space space,
158
+ const Expr& e,
159
+ bool allowMin,
160
+ bool allowMax) {
161
+ CHECK (!(allowMin && allowMax));
162
+
163
+ const Min* minOp = e.as <Min>();
164
+ const Max* maxOp = e.as <Max>();
165
+
86
166
if (const Variable* op = e.as <Variable>()) {
87
167
isl::local_space ls = isl::local_space (space);
88
168
int pos = space.find_dim_by_name (isl::dim_type::param, op->name );
89
169
if (pos >= 0 ) {
90
- return isl::aff (ls, isl::dim_type::param, pos);
170
+ return { isl::aff (ls, isl::dim_type::param, pos)} ;
91
171
} else {
172
+ // FIXME: thou shalt not rely upon set dimension names
92
173
pos = space.find_dim_by_name (isl::dim_type::set, op->name );
93
174
if (pos >= 0 ) {
94
- return isl::aff (ls, isl::dim_type::set, pos);
175
+ return { isl::aff (ls, isl::dim_type::set, pos)} ;
95
176
}
96
177
}
97
178
LOG (FATAL) << " Variable not found in isl::space: " << space << " : " << op
98
179
<< " : " << op->name << ' \n ' ;
99
- return isl::aff ();
180
+ return {};
181
+ } else if (minOp != nullptr && allowMin) {
182
+ return concatAffs (space, minOp, allowMin, allowMax);
183
+ } else if (maxOp != nullptr && allowMax) {
184
+ return concatAffs (space, maxOp, allowMin, allowMax);
100
185
} else if (const Add* op = e.as <Add>()) {
101
- return makeIslAffFromExpr (space, op->a )
102
- .add (makeIslAffFromExpr (space, op->b ));
186
+ return {combineSingleAffs (space, op, &isl::aff::add)};
103
187
} else if (const Sub* op = e.as <Sub>()) {
104
- return makeIslAffFromExpr (space, op->a )
105
- .sub (makeIslAffFromExpr (space, op->b ));
188
+ return {combineSingleAffs (space, op, &isl::aff::sub)};
106
189
} else if (const Mul* op = e.as <Mul>()) {
107
- return makeIslAffFromExpr (space, op->a )
108
- .mul (makeIslAffFromExpr (space, op->b ));
190
+ return {combineSingleAffs (space, op, &isl::aff::mul)};
109
191
} else if (const Div* op = e.as <Div>()) {
110
- return makeIslAffFromExpr (space, op->a )
111
- .div (makeIslAffFromExpr (space, op->b ));
192
+ return {combineSingleAffs (space, op, &isl::aff::div)};
112
193
} else if (const Mod* op = e.as <Mod>()) {
194
+ std::vector<isl::aff> result;
195
+ // We cannot span multiple constraints if a modulo operation is involved.
196
+ // x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
197
+ auto lhs = makeIslAffBoundsFromExpr (space, e, false , false );
198
+ CHECK_EQ (lhs.size (), 1 );
113
199
if (const int64_t * b = as_const_int (op->b )) {
114
- return makeIslAffFromExpr (space, op->a )
115
- .mod (isl::val (space.get_ctx (), *b));
200
+ return {lhs[0 ].mod (isl::val (space.get_ctx (), *b))};
116
201
}
117
202
} else if (const int64_t * i = as_const_int (e)) {
118
- return makeIslAffFromInt (space, *i);
203
+ return { makeIslAffFromInt (space, *i)} ;
119
204
}
120
205
206
+ return {};
207
+ }
208
+
209
+ isl::aff makeIslAffFromExpr (isl::space space, const Expr& e) {
210
+ auto list = makeIslAffBoundsFromExpr (space, e, false , false );
211
+ CHECK_LE (list.size (), 1 ) << " Halide expr " << e
212
+ << " unrolled into more than 1 isl aff"
213
+ << " but min/max operations were disabled" ;
214
+
121
215
// Non-affine
122
- return isl::aff ();
216
+ if (list.size () == 0 ) {
217
+ return isl::aff ();
218
+ }
219
+ return list[0 ];
123
220
}
124
221
125
222
isl::space makeParamSpace (isl::ctx ctx, const SymbolTable& symbolTable) {
@@ -247,10 +344,22 @@ ScheduleTreeAndDomain makeScheduleTreeHelper(
247
344
isl::local_space (set.get_space ()), isl::dim_type::set, thisLoopIdx);
248
345
249
346
// Then we add our new loop bound constraints.
250
- isl::aff lb = halide2isl::makeIslAffFromExpr (set.get_space (), op->min );
347
+ auto lbs = halide2isl::makeIslAffBoundsFromExpr (
348
+ set.get_space (), op->min , false , true );
349
+ CHECK_GT (lbs.size (), 0 )
350
+ << " could not obtain polyhedral lower bounds from " << op->min ;
351
+ for (auto lb : lbs) {
352
+ set = set.intersect (loopVar.ge_set (lb));
353
+ }
354
+
251
355
Expr max = simplify (op->min + op->extent - 1 );
252
- isl::aff ub = halide2isl::makeIslAffFromExpr (set.get_space (), max);
253
- set = set.intersect (loopVar.ge_set (lb).intersect (ub.ge_set (loopVar)));
356
+ auto ubs =
357
+ halide2isl::makeIslAffBoundsFromExpr (set.get_space (), max, true , false );
358
+ CHECK_GT (ubs.size (), 0 )
359
+ << " could not obtain polyhedral upper bounds from " << max;
360
+ for (auto ub : ubs) {
361
+ set = set.intersect (ub.ge_set (loopVar));
362
+ }
254
363
255
364
// Recursively descend.
256
365
auto body = makeScheduleTreeHelper (
0 commit comments