Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 8a4d9f1

Browse files
author
Sven Verdoolaege
committed
halide2isl::extractAccesses: construct accesses in parameter space
This will allow makeIslAffBoundsFromExpr to be changed to only deal with parameters in the next commit.
1 parent a6618b1 commit 8a4d9f1

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

tc/core/halide2isl.cc

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -251,31 +251,25 @@ isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
251251
namespace {
252252

253253
isl::map extractAccess(
254-
isl::set domain,
254+
const IterationDomain& domain,
255255
const IRNode* op,
256256
const std::string& tensor,
257257
const std::vector<Expr>& args,
258258
AccessMap* accesses) {
259259
// Make an isl::map representing this access. It maps from the iteration space
260260
// to the tensor's storage space, using the coordinates accessed.
261+
// First construct a set describing the accessed element
262+
// in terms of the parameters (including those corresponding
263+
// to the outer loop iterators) and then convert this set
264+
// into a map in terms of the iteration domain.
261265

262-
isl::space domainSpace = domain.get_space();
263-
isl::space paramSpace = domainSpace.params();
266+
isl::space paramSpace = domain.paramSpace;
264267
isl::id tensorID(paramSpace.get_ctx(), tensor);
265-
auto rangeSpace = paramSpace.named_set_from_params_id(tensorID, args.size());
268+
auto tensorSpace = paramSpace.named_set_from_params_id(tensorID, args.size());
266269

267-
// Add a tag to the domain space so that we can maintain a mapping
268-
// between each access in the IR and the reads/writes maps.
269-
std::string tag = "__tc_ref_" + std::to_string(accesses->size());
270-
isl::id tagID(domain.get_ctx(), tag);
271-
accesses->emplace(op, tagID);
272-
isl::space tagSpace = paramSpace.named_set_from_params_id(tagID, 0);
273-
domainSpace = domainSpace.product(tagSpace);
274-
275-
// Start with a totally unconstrained relation - every point in
276-
// the iteration domain could write to every point in the allocation.
277-
isl::map map =
278-
isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace));
270+
// Start with a totally unconstrained set - every point in
271+
// the allocation could be accessed.
272+
isl::set access = isl::set::universe(tensorSpace);
279273

280274
for (size_t i = 0; i < args.size(); i++) {
281275
// Then add one equality constraint per dimension to encode the
@@ -285,19 +279,34 @@ isl::map extractAccess(
285279

286280
// The coordinate written to in the range ...
287281
auto rangePoint =
288-
isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i);
289-
// ... equals the coordinate accessed as a function of the domain.
290-
auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]);
282+
isl::pw_aff(isl::local_space(tensorSpace), isl::dim_type::set, i);
283+
// ... equals the coordinate accessed as a function of the parameters.
284+
auto domainPoint = halide2isl::makeIslAffFromExpr(tensorSpace, args[i]);
291285
if (!domainPoint.is_null()) {
292-
map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint));
286+
access = access.intersect(isl::pw_aff(domainPoint).eq_set(rangePoint));
293287
}
294288
}
295289

290+
// Now convert the set into a relation with respect to the iteration domain.
291+
auto map = access.unbind_params_insert_domain(domain.tuple);
292+
293+
// Add a tag to the domain space so that we can maintain a mapping
294+
// between each access in the IR and the reads/writes maps.
295+
std::string tag = "__tc_ref_" + std::to_string(accesses->size());
296+
isl::id tagID(domain.paramSpace.get_ctx(), tag);
297+
accesses->emplace(op, tagID);
298+
isl::space domainSpace = map.get_space().domain();
299+
isl::space tagSpace = domainSpace.params().named_set_from_params_id(tagID, 0);
300+
domainSpace = domainSpace.product(tagSpace).unwrap();
301+
map = map.preimage_domain(isl::multi_aff::domain_map(domainSpace));
302+
296303
return map;
297304
}
298305

299-
std::pair<isl::union_map, isl::union_map>
300-
extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
306+
std::pair<isl::union_map, isl::union_map> extractAccesses(
307+
const IterationDomain& domain,
308+
const Stmt& s,
309+
AccessMap* accesses) {
301310
class FindAccesses : public IRGraphVisitor {
302311
using IRGraphVisitor::visit;
303312

@@ -315,17 +324,17 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
315324
writes.unite(extractAccess(domain, op, op->name, op->args, accesses));
316325
}
317326

318-
const isl::set& domain;
327+
const IterationDomain& domain;
319328
AccessMap* accesses;
320329

321330
public:
322331
isl::union_map reads, writes;
323332

324-
FindAccesses(const isl::set& domain, AccessMap* accesses)
333+
FindAccesses(const IterationDomain& domain, AccessMap* accesses)
325334
: domain(domain),
326335
accesses(accesses),
327-
reads(isl::union_map::empty(domain.get_space())),
328-
writes(isl::union_map::empty(domain.get_space())) {}
336+
reads(isl::union_map::empty(domain.tuple.get_space())),
337+
writes(isl::union_map::empty(domain.tuple.get_space())) {}
329338
} finder(domain, accesses);
330339
s.accept(&finder);
331340
return {finder.reads, finder.writes};
@@ -440,7 +449,8 @@ isl::schedule makeScheduleTreeHelper(
440449
schedule = isl::schedule::from_domain(domain);
441450

442451
isl::union_map newReads, newWrites;
443-
std::tie(newReads, newWrites) = extractAccesses(domain, op, accesses);
452+
std::tie(newReads, newWrites) =
453+
extractAccesses(iterationDomain, op, accesses);
444454

445455
*reads = reads->unite(newReads);
446456
*writes = writes->unite(newWrites);

0 commit comments

Comments
 (0)