Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 4e83a3c

Browse files
Merge pull request #510 from facebookresearch/pr/unbind
makeIslAffFromExpr: only accept parametric expressions
2 parents 718a6d2 + 77c2df7 commit 4e83a3c

File tree

6 files changed

+125
-121
lines changed

6 files changed

+125
-121
lines changed

tc/core/halide2isl.cc

Lines changed: 90 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -174,16 +174,9 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
174174
const Max* maxOp = e.as<Max>();
175175

176176
if (const Variable* op = e.as<Variable>()) {
177-
isl::local_space ls = isl::local_space(space);
178-
int pos = space.find_dim_by_name(isl::dim_type::param, op->name);
179-
if (pos >= 0) {
180-
return {isl::aff(ls, isl::dim_type::param, pos)};
181-
} else {
182-
// FIXME: thou shalt not rely upon set dimension names
183-
pos = space.find_dim_by_name(isl::dim_type::set, op->name);
184-
if (pos >= 0) {
185-
return {isl::aff(ls, isl::dim_type::set, pos)};
186-
}
177+
isl::id id(space.get_ctx(), op->name);
178+
if (space.has_param(id)) {
179+
return {isl::aff::param_on_domain_space(space, id)};
187180
}
188181
LOG(FATAL) << "Variable not found in isl::space: " << space << ": " << op
189182
<< ": " << op->name << '\n';
@@ -248,32 +241,28 @@ isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
248241
return context;
249242
}
250243

244+
namespace {
245+
251246
isl::map extractAccess(
252-
isl::set domain,
247+
const IterationDomain& domain,
253248
const IRNode* op,
254249
const std::string& tensor,
255250
const std::vector<Expr>& args,
256251
AccessMap* accesses) {
257252
// Make an isl::map representing this access. It maps from the iteration space
258253
// to the tensor's storage space, using the coordinates accessed.
254+
// First construct a set describing the accessed element
255+
// in terms of the parameters (including those corresponding
256+
// to the outer loop iterators) and then convert this set
257+
// into a map in terms of the iteration domain.
259258

260-
isl::space domainSpace = domain.get_space();
261-
isl::space paramSpace = domainSpace.params();
259+
isl::space paramSpace = domain.paramSpace;
262260
isl::id tensorID(paramSpace.get_ctx(), tensor);
263-
auto rangeSpace = paramSpace.named_set_from_params_id(tensorID, args.size());
261+
auto tensorSpace = paramSpace.named_set_from_params_id(tensorID, args.size());
264262

265-
// Add a tag to the domain space so that we can maintain a mapping
266-
// between each access in the IR and the reads/writes maps.
267-
std::string tag = "__tc_ref_" + std::to_string(accesses->size());
268-
isl::id tagID(domain.get_ctx(), tag);
269-
accesses->emplace(op, tagID);
270-
isl::space tagSpace = paramSpace.named_set_from_params_id(tagID, 0);
271-
domainSpace = domainSpace.product(tagSpace);
272-
273-
// Start with a totally unconstrained relation - every point in
274-
// the iteration domain could write to every point in the allocation.
275-
isl::map map =
276-
isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace));
263+
// Start with a totally unconstrained set - every point in
264+
// the allocation could be accessed.
265+
isl::set access = isl::set::universe(tensorSpace);
277266

278267
for (size_t i = 0; i < args.size(); i++) {
279268
// Then add one equality constraint per dimension to encode the
@@ -283,19 +272,34 @@ isl::map extractAccess(
283272

284273
// The coordinate written to in the range ...
285274
auto rangePoint =
286-
isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i);
287-
// ... equals the coordinate accessed as a function of the domain.
288-
auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]);
275+
isl::pw_aff(isl::local_space(tensorSpace), isl::dim_type::set, i);
276+
// ... equals the coordinate accessed as a function of the parameters.
277+
auto domainPoint = halide2isl::makeIslAffFromExpr(tensorSpace, args[i]);
289278
if (!domainPoint.is_null()) {
290-
map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint));
279+
access = access.intersect(isl::pw_aff(domainPoint).eq_set(rangePoint));
291280
}
292281
}
293282

283+
// Now convert the set into a relation with respect to the iteration domain.
284+
auto map = access.unbind_params_insert_domain(domain.tuple);
285+
286+
// Add a tag to the domain space so that we can maintain a mapping
287+
// between each access in the IR and the reads/writes maps.
288+
std::string tag = "__tc_ref_" + std::to_string(accesses->size());
289+
isl::id tagID(domain.paramSpace.get_ctx(), tag);
290+
accesses->emplace(op, tagID);
291+
isl::space domainSpace = map.get_space().domain();
292+
isl::space tagSpace = domainSpace.params().named_set_from_params_id(tagID, 0);
293+
domainSpace = domainSpace.product(tagSpace).unwrap();
294+
map = map.preimage_domain(isl::multi_aff::domain_map(domainSpace));
295+
294296
return map;
295297
}
296298

297-
std::pair<isl::union_map, isl::union_map>
298-
extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
299+
std::pair<isl::union_map, isl::union_map> extractAccesses(
300+
const IterationDomain& domain,
301+
const Stmt& s,
302+
AccessMap* accesses) {
299303
class FindAccesses : public IRGraphVisitor {
300304
using IRGraphVisitor::visit;
301305

@@ -313,28 +317,46 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
313317
writes.unite(extractAccess(domain, op, op->name, op->args, accesses));
314318
}
315319

316-
const isl::set& domain;
320+
const IterationDomain& domain;
317321
AccessMap* accesses;
318322

319323
public:
320324
isl::union_map reads, writes;
321325

322-
FindAccesses(const isl::set& domain, AccessMap* accesses)
326+
FindAccesses(const IterationDomain& domain, AccessMap* accesses)
323327
: domain(domain),
324328
accesses(accesses),
325-
reads(isl::union_map::empty(domain.get_space())),
326-
writes(isl::union_map::empty(domain.get_space())) {}
329+
reads(isl::union_map::empty(domain.tuple.get_space())),
330+
writes(isl::union_map::empty(domain.tuple.get_space())) {}
327331
} finder(domain, accesses);
328332
s.accept(&finder);
329333
return {finder.reads, finder.writes};
330334
}
331335

336+
/*
337+
* Take a parametric expression "f" and convert it into an expression
338+
* on the iteration domains in "domain" by reinterpreting the parameters
339+
* as set dimensions according to the corresponding tuples in "map".
340+
*/
341+
isl::union_pw_aff
342+
onDomains(isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
343+
auto upa = isl::union_pw_aff::empty(domain.get_space());
344+
for (auto set : domain.get_set_list()) {
345+
auto tuple = map.at(set.get_tuple_id()).tuple;
346+
auto onSet = isl::union_pw_aff(f.unbind_params_insert_domain(tuple));
347+
upa = upa.union_add(onSet);
348+
}
349+
return upa;
350+
}
351+
352+
} // namespace
353+
332354
/*
333355
* Helper function for extracting a schedule from a Halide Stmt,
334356
* recursively descending over the Stmt.
335357
* "s" is the current position in the recursive descent.
336358
* "set" describes the bounds on the outer loop iterators.
337-
* "outer" contains the names of the outer loop iterators
359+
* "outer" contains the identifiers of the outer loop iterators
338360
* from outermost to innermost.
339361
* Return the schedule corresponding to the subtree at "s".
340362
*
@@ -343,81 +365,58 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
343365
* (for the writes) to the corresponding tag in the access relations.
344366
* "statements" collects the mapping from instance set tuple identifiers
345367
* to the corresponding Provide node.
346-
* "iterators" collects the mapping from instance set tuple identifiers
347-
* to the corresponding outer loop iterator names, from outermost to innermost.
368+
* "domains" collects the mapping from instance set tuple identifiers
369+
* to the corresponding iteration domain information.
348370
*/
349371
isl::schedule makeScheduleTreeHelper(
350372
const Stmt& s,
351373
isl::set set,
352-
std::vector<std::string>& outer,
374+
isl::id_list outer,
353375
isl::union_map* reads,
354376
isl::union_map* writes,
355377
AccessMap* accesses,
356378
StatementMap* statements,
357-
IteratorMap* iterators) {
379+
IterationDomainMap* domains) {
358380
isl::schedule schedule;
359381
if (auto op = s.as<For>()) {
360-
// Add one additional dimension to our set of loop variables
361-
int thisLoopIdx = set.dim(isl::dim_type::set);
362-
set = set.add_dims(isl::dim_type::set, 1);
363-
364-
// Make an id for this loop var. For set dimensions this is
365-
// really just for pretty-printing.
382+
// Make an id for this loop var. It starts out as a parameter.
366383
isl::id id(set.get_ctx(), op->name);
367-
set = set.set_dim_id(isl::dim_type::set, thisLoopIdx, id);
384+
auto space = set.get_space().add_param(id);
368385

369-
// Construct a variable (affine function) that indexes the new dimension of
370-
// this space.
371-
isl::aff loopVar(
372-
isl::local_space(set.get_space()), isl::dim_type::set, thisLoopIdx);
386+
// Construct a variable (affine function) that references
387+
// the new parameter.
388+
auto loopVar = isl::aff::param_on_domain_space(space, id);
373389

374390
// Then we add our new loop bound constraints.
375-
auto lbs = halide2isl::makeIslAffBoundsFromExpr(
376-
set.get_space(), op->min, false, true);
391+
auto lbs =
392+
halide2isl::makeIslAffBoundsFromExpr(space, op->min, false, true);
377393
TC_CHECK_GT(lbs.size(), 0u)
378394
<< "could not obtain polyhedral lower bounds from " << op->min;
379395
for (auto lb : lbs) {
380396
set = set.intersect(loopVar.ge_set(lb));
381397
}
382398

383399
Expr max = simplify(op->min + op->extent - 1);
384-
auto ubs =
385-
halide2isl::makeIslAffBoundsFromExpr(set.get_space(), max, true, false);
400+
auto ubs = halide2isl::makeIslAffBoundsFromExpr(space, max, true, false);
386401
TC_CHECK_GT(ubs.size(), 0u)
387402
<< "could not obtain polyhedral upper bounds from " << max;
388403
for (auto ub : ubs) {
389404
set = set.intersect(ub.ge_set(loopVar));
390405
}
391406

392407
// Recursively descend.
393-
auto outerNext = outer;
394-
outerNext.push_back(op->name);
408+
auto outerNext = outer.add(isl::id(set.get_ctx(), op->name));
395409
auto body = makeScheduleTreeHelper(
396-
op->body,
397-
set,
398-
outerNext,
399-
reads,
400-
writes,
401-
accesses,
402-
statements,
403-
iterators);
410+
op->body, set, outerNext, reads, writes, accesses, statements, domains);
404411

405412
// Create an affine function that defines an ordering for all
406413
// the statements in the body of this loop over the values of
407-
// this loop. For each statement in the children we want the
408-
// function that maps everything in its space to this
409-
// dimension. The spaces may be different, but they'll all have
410-
// this loop var at the same index.
411-
isl::multi_union_pw_aff mupa;
412-
body.get_domain().foreach_set([&](isl::set s) {
413-
isl::aff newLoopVar(
414-
isl::local_space(s.get_space()), isl::dim_type::set, thisLoopIdx);
415-
if (mupa) {
416-
mupa = mupa.union_add(isl::union_pw_aff(isl::pw_aff(newLoopVar)));
417-
} else {
418-
mupa = isl::union_pw_aff(isl::pw_aff(newLoopVar));
419-
}
420-
});
414+
// this loop. Start from a parametric expression equal
415+
// to the current loop iterator and then convert it to
416+
// a function on the statements in the domain of the body schedule.
417+
auto aff = isl::aff::param_on_domain_space(space, id);
418+
auto domain = body.get_domain();
419+
auto mupa = isl::multi_union_pw_aff(onDomains(aff, domain, *domains));
421420

422421
schedule = body.insert_partial_schedule(mupa);
423422
} else if (auto op = s.as<Halide::Internal::Block>()) {
@@ -430,7 +429,7 @@ isl::schedule makeScheduleTreeHelper(
430429
std::vector<isl::schedule> schedules;
431430
for (Stmt stmt : stmts) {
432431
schedules.push_back(makeScheduleTreeHelper(
433-
stmt, set, outer, reads, writes, accesses, statements, iterators));
432+
stmt, set, outer, reads, writes, accesses, statements, domains));
434433
}
435434
schedule = schedules[0].sequence(schedules[1]);
436435

@@ -441,13 +440,18 @@ isl::schedule makeScheduleTreeHelper(
441440
size_t stmtIndex = statements->size();
442441
isl::id id(set.get_ctx(), kStatementLabel + std::to_string(stmtIndex));
443442
statements->emplace(id, op);
444-
iterators->emplace(id, outer);
445-
isl::set domain = set.set_tuple_id(id);
443+
auto tupleSpace = isl::space(set.get_ctx(), 0);
444+
tupleSpace = tupleSpace.named_set_from_params_id(id, outer.n());
445+
IterationDomain iterationDomain;
446+
iterationDomain.paramSpace = set.get_space();
447+
iterationDomain.tuple = isl::multi_id(tupleSpace, outer);
448+
domains->emplace(id, iterationDomain);
449+
auto domain = set.unbind_params(iterationDomain.tuple);
446450
schedule = isl::schedule::from_domain(domain);
447451

448452
isl::union_map newReads, newWrites;
449453
std::tie(newReads, newWrites) =
450-
halide2isl::extractAccesses(domain, op, accesses);
454+
extractAccesses(iterationDomain, op, accesses);
451455

452456
*reads = reads->unite(newReads);
453457
*writes = writes->unite(newWrites);
@@ -464,7 +468,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
464468
result.writes = result.reads = isl::union_map::empty(paramSpace);
465469

466470
// Walk the IR building a schedule tree
467-
std::vector<std::string> outer;
471+
isl::id_list outer(paramSpace.get_ctx(), 0);
468472
auto schedule = makeScheduleTreeHelper(
469473
s,
470474
isl::set::universe(paramSpace),
@@ -473,7 +477,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
473477
&result.writes,
474478
&result.accesses,
475479
&result.statements,
476-
&result.iterators);
480+
&result.domains);
477481

478482
result.tree = fromIslSchedule(schedule);
479483

tc/core/halide2isl.h

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,25 @@ isl::aff makeIslAffFromInt(isl::space space, int64_t i);
5353

5454
// Make an affine function over a space from a Halide Expr. Returns a
5555
// null isl::aff if the expression is not affine. Fails if Variable
56-
// does not correspond to a parameter or set dimension of the space.
56+
// does not correspond to a parameter of the space.
57+
// Note that the input space can be either a parameter space or
58+
// a set space, but the expression can only reference
59+
// the parameters in the space.
5760
isl::aff makeIslAffFromExpr(isl::space space, const Halide::Expr& e);
5861

59-
typedef std::unordered_map<isl::id, std::vector<std::string>, isl::IslIdIslHash>
60-
IteratorMap;
62+
// Iteration domain information associated to a statement identifier.
63+
struct IterationDomain {
64+
// All parameters active at the point where the iteration domain
65+
// was created, including those corresponding to outer loop iterators.
66+
isl::space paramSpace;
67+
// The identifier tuple corresponding to the iteration domain.
68+
// The identifiers in the tuple are the outer loop iterators,
69+
// from outermost to innermost.
70+
isl::multi_id tuple;
71+
};
72+
73+
typedef std::unordered_map<isl::id, IterationDomain, isl::IslIdIslHash>
74+
IterationDomainMap;
6175
typedef std::unordered_map<isl::id, Halide::Internal::Stmt, isl::IslIdIslHash>
6276
StatementMap;
6377
typedef std::unordered_map<const Halide::Internal::IRNode*, isl::id> AccessMap;
@@ -81,9 +95,9 @@ struct ScheduleTreeAndAccesses {
8195
/// refered to above.
8296
StatementMap statements;
8397

84-
/// The correspondence between statement ids and the outer loop iterators
98+
/// The correspondence between statement ids and the iteration domain
8599
/// of the corresponding leaf Stmt.
86-
IteratorMap iterators;
100+
IterationDomainMap domains;
87101
};
88102

89103
/// Make a schedule tree from a Halide Stmt, along with auxiliary data

tc/core/polyhedral/codegen_llvm.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -637,19 +637,18 @@ isl::ast_node collectIteratorMaps(
637637
auto stmtId = expr.get_arg(0).as<isl::ast_expr_id>().get_id();
638638
TC_CHECK_EQ(0u, iteratorMaps.count(stmtId)) << "entry exists: " << stmtId;
639639
auto iteratorMap = isl::pw_multi_aff(scheduleMap.reverse());
640-
auto iterators = scop.halide.iterators.at(stmtId);
640+
auto tuple = scop.halide.domains.at(stmtId).tuple;
641641
auto& stmtIteratorMap = iteratorMaps[stmtId];
642-
for (size_t i = 0; i < iterators.size(); ++i) {
642+
for (int i = 0; i < tuple.size(); ++i) {
643643
auto expr = build.expr_from(iteratorMap.get_pw_aff(i));
644-
stmtIteratorMap.emplace(iterators[i], expr);
644+
stmtIteratorMap.emplace(tuple.get_id(i).get_name(), expr);
645645
}
646646
auto& subscripts = stmtSubscripts[stmtId];
647647
auto provide =
648648
scop.halide.statements.at(stmtId).as<Halide::Internal::Provide>();
649649
for (auto e : provide->args) {
650650
const auto& map = iteratorMap;
651-
auto space = map.get_space().params();
652-
auto aff = scop.makeIslAffFromStmtExpr(stmtId, space, e);
651+
auto aff = scop.makeIslAffFromStmtExpr(stmtId, e);
653652
auto pulled = isl::pw_aff(aff).pullback(map);
654653
TC_CHECK_EQ(pulled.n_piece(), 1);
655654
subscripts.push_back(build.expr_from(pulled));

tc/core/polyhedral/cuda/codegen.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,7 @@ struct CodegenStatementContext : CodegenContext {
136136
// of the variables does not correspond to a parameter or
137137
// an instance identifier of the statement.
138138
isl::aff makeIslAffFromExpr(const Halide::Expr& e) const {
139-
auto space = iteratorMap().get_space().params();
140-
return scop().makeIslAffFromStmtExpr(statementId(), space, e);
139+
return scop().makeIslAffFromStmtExpr(statementId(), e);
141140
}
142141

143142
isl::id astNodeId;

0 commit comments

Comments
 (0)