Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 931eeed

Browse files
author
Sven Verdoolaege
committed
[RFC] use templated isl types halide2isl
Templated isl types require the user to specify the domain and range universes of isl objects, allowing the compiler to check whether it makes sense to combine pairs of objects. This RFC only converts isPromotableToRegistersBelow and some related functions to illustrate the effect. The isPromotableToRegistersBelow was already applying operations correctly, so the code itself did not require any changes. However, one variable was reused to store different types of intermediate result and this one had to be split up into several variables because they now have different types.
1 parent a804fea commit 931eeed

File tree

3 files changed

+80
-72
lines changed

3 files changed

+80
-72
lines changed

tc/core/halide2isl.cc

Lines changed: 71 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "tc/core/check.h"
2222
#include "tc/core/constants.h"
2323
#include "tc/core/polyhedral/body.h"
24+
#include "tc/core/polyhedral/domain_types.h"
2425
#include "tc/core/polyhedral/schedule_isl_conversion.h"
2526
#include "tc/core/polyhedral/schedule_transforms.h"
2627
#include "tc/core/polyhedral/schedule_tree.h"
@@ -80,14 +81,14 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
8081
return builder.table;
8182
}
8283

83-
isl::aff makeIslAffFromInt(isl::space space, int64_t val) {
84+
isl::AffOn<> makeIslAffFromInt(isl::Space<> space, int64_t val) {
8485
isl::val v = isl::val(space.get_ctx(), val);
85-
return isl::aff(isl::local_space(space), v);
86+
return isl::AffOn<>(isl::aff(isl::local_space(space), v));
8687
}
8788

88-
std::vector<isl::aff> makeIslAffBoundsFromExpr(
89-
isl::space space,
90-
const Expr& e,
89+
std::vector<isl::AffOn<>> makeIslAffBoundsFromExpr(
90+
isl::Space<> space,
91+
const Halide::Expr& e,
9192
bool allowMin,
9293
bool allowMax);
9394

@@ -101,9 +102,9 @@ namespace {
101102
* x > max(a,max(b,c)) <=> x > a AND x > b AND x > c
102103
*/
103104
template <typename T>
104-
inline std::vector<isl::aff>
105-
concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
106-
std::vector<isl::aff> result;
105+
inline std::vector<isl::AffOn<>>
106+
concatAffs(isl::Space<> space, T op, bool allowMin, bool allowMax) {
107+
std::vector<isl::AffOn<>> result;
107108

108109
for (const auto& aff :
109110
makeIslAffBoundsFromExpr(space, op->a, allowMin, allowMax)) {
@@ -129,10 +130,10 @@ concatAffs(isl::space space, T op, bool allowMin, bool allowMax) {
129130
* x < a + max(b,c) NOT <=> x < a + b AND x < a + c for negative values.
130131
*/
131132
template <typename T>
132-
inline std::vector<isl::aff> combineSingleAffs(
133-
isl::space space,
133+
inline std::vector<isl::AffOn<>> combineSingleAffs(
134+
isl::Space<> space,
134135
T op,
135-
isl::aff (isl::aff::*combine)(isl::aff) const) {
136+
isl::AffOn<> (isl::AffOn<>::*combine)(const isl::AffOn<>&) const) {
136137
auto left = makeIslAffBoundsFromExpr(space, op->a, false, false);
137138
auto right = makeIslAffBoundsFromExpr(space, op->b, false, false);
138139
TC_CHECK_LE(left.size(), 1u);
@@ -162,9 +163,9 @@ inline std::vector<isl::aff> combineSingleAffs(
162163
* If a Halide expression cannot be converted into a list of affine expressions,
163164
* return an empty list.
164165
*/
165-
std::vector<isl::aff> makeIslAffBoundsFromExpr(
166-
isl::space space,
167-
const Expr& e,
166+
std::vector<isl::AffOn<>> makeIslAffBoundsFromExpr(
167+
isl::Space<> space,
168+
const Halide::Expr& e,
168169
bool allowMin,
169170
bool allowMax) {
170171
TC_CHECK(!(allowMin && allowMax));
@@ -178,7 +179,7 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
178179
if (const Variable* op = e.as<Variable>()) {
179180
isl::id id(space.get_ctx(), op->name);
180181
if (space.has_param(id)) {
181-
return {isl::aff::param_on_domain_space(space, id)};
182+
return {isl::AffOn<>::param_on_domain_space(space, id)};
182183
}
183184
LOG(FATAL) << "Variable not found in isl::space: " << space << ": " << op
184185
<< ": " << op->name << '\n';
@@ -188,13 +189,13 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
188189
} else if (maxOp != nullptr && allowMax) {
189190
return concatAffs(space, maxOp, allowMin, allowMax);
190191
} else if (const Add* op = e.as<Add>()) {
191-
return combineSingleAffs(space, op, &isl::aff::add);
192+
return combineSingleAffs(space, op, &isl::AffOn<>::add);
192193
} else if (const Sub* op = e.as<Sub>()) {
193-
return combineSingleAffs(space, op, &isl::aff::sub);
194+
return combineSingleAffs(space, op, &isl::AffOn<>::sub);
194195
} else if (const Mul* op = e.as<Mul>()) {
195-
return combineSingleAffs(space, op, &isl::aff::mul);
196+
return combineSingleAffs(space, op, &isl::AffOn<>::mul);
196197
} else if (const Div* op = e.as<Div>()) {
197-
return combineSingleAffs(space, op, &isl::aff::div);
198+
return combineSingleAffs(space, op, &isl::AffOn<>::div);
198199
} else if (const Mod* op = e.as<Mod>()) {
199200
std::vector<isl::aff> result;
200201
// We cannot span multiple constraints if a modulo operation is involved.
@@ -211,45 +212,50 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
211212
return {};
212213
}
213214

214-
isl::aff makeIslAffFromExpr(isl::space space, const Expr& e) {
215+
isl::AffOn<> makeIslAffFromExpr(isl::Space<> space, const Halide::Expr& e) {
215216
auto list = makeIslAffBoundsFromExpr(space, e, false, false);
216217
TC_CHECK_LE(list.size(), 1u)
217218
<< "Halide expr " << e << " unrolled into more than 1 isl aff"
218219
<< " but min/max operations were disabled";
219220

220221
// Non-affine
221222
if (list.size() == 0) {
222-
return isl::aff();
223+
return isl::AffOn<>();
223224
}
224225
return list[0];
225226
}
226227

227-
isl::space makeParamSpace(isl::ctx ctx, const ParameterVector& params) {
228+
isl::Space<> makeParamSpace(isl::ctx ctx, const ParameterVector& params) {
228229
auto space = isl::space(ctx, 0);
229230
// set parameter names
230231
for (auto p : params) {
231232
space = space.add_param(isl::id(ctx, p.name()));
232233
}
233-
return space;
234+
return isl::Space<>(space);
234235
}
235236

236-
isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
237+
isl::Set<> makeParamContext(isl::ctx ctx, const ParameterVector& params) {
237238
auto space = makeParamSpace(ctx, params);
238-
auto context = isl::set::universe(space);
239+
auto context = isl::Set<>::universe(space);
239240
for (auto p : params) {
240-
isl::aff a(isl::aff::param_on_domain_space(space, isl::id(ctx, p.name())));
241-
context = context & (a >= 0);
241+
auto a(isl::AffOn<>::param_on_domain_space(space, isl::id(ctx, p.name())));
242+
context = context & isl::PwAffOn<>(a).nonneg_set();
242243
}
243244
return context;
244245
}
245246

246247
namespace {
247248

248-
isl::map extractAccess(
249+
template <typename Domain, typename Range>
250+
static isl::MultiAff<isl::Pair<Domain, Range>, Domain> domainMap(isl::Space<Domain, Range> space) {
251+
return isl::MultiAff<isl::Pair<Domain, Range>, Domain>::domain_map(space);
252+
}
253+
254+
isl::Map<isl::Pair<Statement, Tag>, Tensor> extractAccess(
249255
const IterationDomain& domain,
250256
const IRNode* op,
251257
const std::string& tensor,
252-
const std::vector<Expr>& args,
258+
const std::vector<Halide::Expr>& args,
253259
AccessMap* accesses) {
254260
// Make an isl::map representing this access. It maps from the iteration space
255261
// to the tensor's storage space, using the coordinates accessed.
@@ -258,16 +264,16 @@ isl::map extractAccess(
258264
// to the outer loop iterators) and then convert this set
259265
// into a map in terms of the iteration domain.
260266

261-
auto paramSpace = isl::Space<>(domain.paramSpace);
267+
auto paramSpace = domain.paramSpace;
262268
isl::id tensorID(paramSpace.get_ctx(), tensor);
263269
auto tensorTuple = constructTensorTuple(paramSpace, tensorID, args.size());
264270
auto tensorSpace = tensorTuple.get_space();
265271

266272
// Start with a totally unconstrained set - every point in
267273
// the allocation could be accessed.
268-
isl::set access = isl::set::universe(tensorSpace);
274+
auto access = isl::Set<Tensor>::universe(tensorSpace);
269275

270-
auto identity = isl::multi_aff::identity(tensorSpace.map_from_set());
276+
auto identity = isl::MultiAff<Tensor, Tensor>::identity(tensorSpace.map_from_set());
271277
for (size_t i = 0; i < args.size(); i++) {
272278
// Then add one equality constraint per dimension to encode the
273279
// point in the allocation actually read/written for each point in
@@ -279,8 +285,8 @@ isl::map extractAccess(
279285
// ... equals the coordinate accessed as a function of the parameters.
280286
auto domainPoint = halide2isl::makeIslAffFromExpr(paramSpace, args[i]);
281287
if (!domainPoint.is_null()) {
282-
domainPoint = domainPoint.unbind_params_insert_domain(tensorTuple);
283-
access = access.intersect(domainPoint.eq_set(rangePoint));
288+
auto domainPoint2 = domainPoint.unbind_params_insert_domain(tensorTuple);
289+
access = access.intersect(domainPoint2.eq_set(rangePoint));
284290
}
285291
}
286292

@@ -292,15 +298,13 @@ isl::map extractAccess(
292298
std::string tag = "__tc_ref_" + std::to_string(accesses->size());
293299
isl::id tagID(domain.paramSpace.get_ctx(), tag);
294300
accesses->emplace(op, tagID);
295-
isl::space domainSpace = map.get_space().domain();
296-
isl::space tagSpace = domainSpace.params().add_named_tuple_id_ui(tagID, 0);
297-
domainSpace = domainSpace.product(tagSpace).unwrap();
298-
map = map.preimage_domain(isl::multi_aff::domain_map(domainSpace));
299-
300-
return map;
301+
auto domainSpace = map.get_space().domain();
302+
auto tagSpace = domainSpace.params().add_named_tuple_id_ui<Tag>(tagID, 0);
303+
auto taggedSpace = domainSpace.product(tagSpace).unwrap<Statement, Tag>();
304+
return map.preimage_domain(domainMap(taggedSpace));
301305
}
302306

303-
std::pair<isl::union_map, isl::union_map> extractAccesses(
307+
std::pair<isl::UnionMap<isl::Pair<Statement, Tag>, Tensor>, isl::UnionMap<isl::Pair<Statement, Tag>, Tensor>> extractAccesses(
304308
const IterationDomain& domain,
305309
const Stmt& s,
306310
AccessMap* accesses) {
@@ -325,7 +329,7 @@ std::pair<isl::union_map, isl::union_map> extractAccesses(
325329
AccessMap* accesses;
326330

327331
public:
328-
isl::union_map reads, writes;
332+
isl::UnionMap<isl::Pair<Statement, Tag>, Tensor> reads, writes;
329333

330334
FindAccesses(const IterationDomain& domain, AccessMap* accesses)
331335
: domain(domain),
@@ -355,24 +359,24 @@ bool isReductionUpdate(const Provide* op) {
355359
* then converted into an expression on that iteration domain
356360
* by reinterpreting the parameters as input dimensions.
357361
*/
358-
static isl::multi_aff mapToOther(
362+
template <typename Other>
363+
static isl::MultiAff<Statement, Other> mapToOther(
359364
const IterationDomain& iterationDomain,
360365
std::unordered_set<std::string> skip,
361366
isl::id id) {
362367
auto ctx = iterationDomain.tuple.get_ctx();
363-
auto list = isl::aff_list(ctx, 0);
368+
auto list = isl::AffListOn<Statement>(isl::aff_list(ctx, 0));
364369
for (auto id : iterationDomain.tuple.get_id_list()) {
365370
if (skip.count(id.get_name()) == 1) {
366371
continue;
367372
}
368-
auto aff = isl::aff::param_on_domain_space(iterationDomain.paramSpace, id);
369-
aff = aff.unbind_params_insert_domain(iterationDomain.tuple);
370-
list = list.add(aff);
373+
auto aff = isl::AffOn<>::param_on_domain_space(iterationDomain.paramSpace, id);
374+
list = list.add(aff.unbind_params_insert_domain(iterationDomain.tuple));
371375
}
372376
auto domainSpace = iterationDomain.tuple.get_space();
373-
auto space = domainSpace.params().add_named_tuple_id_ui(id, list.size());
374-
space = domainSpace.product(space).unwrap();
375-
return isl::multi_aff(space, list);
377+
auto space = domainSpace.params().add_named_tuple_id_ui<Other>(id, list.size());
378+
auto productSpace = domainSpace.product(space).template unwrap<Statement, Other>();
379+
return isl::MultiAff<Statement, Other>(productSpace, list);
376380
}
377381

378382
/*
@@ -392,7 +396,7 @@ static isl::multi_aff mapToOther(
392396
* that all statement instances that belong to the same reduction
393397
* write to the same tensor element.
394398
*/
395-
isl::union_map extractReduction(
399+
isl::UnionMap<Statement, Reduction> extractReduction(
396400
const IterationDomain& iterationDomain,
397401
const Provide* op,
398402
size_t index) {
@@ -409,16 +413,19 @@ isl::union_map extractReduction(
409413
} finder;
410414

411415
if (!isReductionUpdate(op)) {
412-
return isl::union_map::empty(iterationDomain.tuple.get_space().params());
416+
auto space = iterationDomain.tuple.get_space().params();
417+
return isl::UnionMap<Statement, Reduction>::empty(space);
413418
}
414419
op->accept(&finder);
415420
if (finder.reductionVars.size() == 0) {
416-
return isl::union_map::empty(iterationDomain.tuple.get_space().params());
421+
auto space = iterationDomain.tuple.get_space().params();
422+
return isl::UnionMap<Statement, Reduction>(isl::union_map::empty(space));
417423
}
418424
auto ctx = iterationDomain.tuple.get_ctx();
419425
isl::id id(ctx, kReductionLabel + op->name + "_" + std::to_string(index));
420-
auto reduction = mapToOther(iterationDomain, finder.reductionVars, id);
421-
return isl::union_map(isl::map(reduction));
426+
auto reduction = mapToOther<Reduction>(iterationDomain, finder.reductionVars, id);
427+
return isl::UnionMap<Statement, Reduction>(
428+
isl::Map<Statement, Reduction>(reduction));
422429
}
423430

424431
/*
@@ -458,7 +465,7 @@ onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
458465
*/
459466
isl::schedule makeScheduleTreeHelper(
460467
const Stmt& s,
461-
isl::set set,
468+
isl::Set<> set,
462469
isl::id_list outer,
463470
Body* body,
464471
AccessMap* accesses,
@@ -472,7 +479,7 @@ isl::schedule makeScheduleTreeHelper(
472479

473480
// Construct a variable (affine function) that references
474481
// the new parameter.
475-
auto loopVar = isl::aff::param_on_domain_space(space, id);
482+
auto loopVar = isl::AffOn<>::param_on_domain_space(space, id);
476483

477484
// Then we add our new loop bound constraints.
478485
auto lbs =
@@ -483,7 +490,7 @@ isl::schedule makeScheduleTreeHelper(
483490
set = set.intersect(loopVar.ge_set(lb));
484491
}
485492

486-
Expr max = simplify(op->min + op->extent - 1);
493+
Halide::Expr max = simplify(op->min + op->extent - 1);
487494
auto ubs = halide2isl::makeIslAffBoundsFromExpr(space, max, true, false);
488495
TC_CHECK_GT(ubs.size(), 0u)
489496
<< "could not obtain polyhedral upper bounds from " << max;
@@ -527,16 +534,16 @@ isl::schedule makeScheduleTreeHelper(
527534
size_t stmtIndex = statements->size();
528535
isl::id id(set.get_ctx(), kStatementLabel + std::to_string(stmtIndex));
529536
statements->emplace(id, op);
530-
auto tupleSpace = isl::space(set.get_ctx(), 0);
531-
tupleSpace = tupleSpace.add_named_tuple_id_ui(id, outer.size());
537+
auto space = isl::Space<>(isl::space(set.get_ctx(), 0));
538+
auto tupleSpace = space.add_named_tuple_id_ui<Statement>(id, outer.size());
532539
IterationDomain iterationDomain;
533540
iterationDomain.paramSpace = set.get_space();
534-
iterationDomain.tuple = isl::multi_id(tupleSpace, outer);
541+
iterationDomain.tuple = isl::MultiId<Statement>(tupleSpace, outer);
535542
domains->emplace(id, iterationDomain);
536543
auto domain = set.unbind_params(iterationDomain.tuple);
537544
schedule = isl::schedule::from_domain(domain);
538545

539-
isl::union_map newReads, newWrites;
546+
isl::UnionMap<isl::Pair<Statement, Tag>, Tensor> newReads, newWrites;
540547
std::tie(newReads, newWrites) =
541548
extractAccesses(iterationDomain, op, accesses);
542549
// A tensor may be involved in multiple reductions.
@@ -553,7 +560,7 @@ isl::schedule makeScheduleTreeHelper(
553560
return schedule;
554561
};
555562

556-
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
563+
ScheduleTreeAndAccesses makeScheduleTree(isl::Space<> paramSpace, const Stmt& s) {
557564
ScheduleTreeAndAccesses result;
558565

559566
Body body(paramSpace);
@@ -562,7 +569,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
562569
isl::id_list outer(paramSpace.get_ctx(), 0);
563570
auto schedule = makeScheduleTreeHelper(
564571
s,
565-
isl::set::universe(paramSpace),
572+
isl::Set<>::universe(paramSpace),
566573
outer,
567574
&body,
568575
&result.accesses,

0 commit comments

Comments
 (0)