@@ -251,31 +251,25 @@ isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
251
251
namespace {
252
252
253
253
isl::map extractAccess (
254
- isl::set domain,
254
+ const IterationDomain& domain,
255
255
const IRNode* op,
256
256
const std::string& tensor,
257
257
const std::vector<Expr>& args,
258
258
AccessMap* accesses) {
259
259
// Make an isl::map representing this access. It maps from the iteration space
260
260
// to the tensor's storage space, using the coordinates accessed.
261
+ // First construct a set describing the accessed element
262
+ // in terms of the parameters (including those corresponding
263
+ // to the outer loop iterators) and then convert this set
264
+ // into a map in terms of the iteration domain.
261
265
262
- isl::space domainSpace = domain.get_space ();
263
- isl::space paramSpace = domainSpace.params ();
266
+ isl::space paramSpace = domain.paramSpace ;
264
267
isl::id tensorID (paramSpace.get_ctx (), tensor);
265
- auto rangeSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
268
+ auto tensorSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
266
269
267
- // Add a tag to the domain space so that we can maintain a mapping
268
- // between each access in the IR and the reads/writes maps.
269
- std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
270
- isl::id tagID (domain.get_ctx (), tag);
271
- accesses->emplace (op, tagID);
272
- isl::space tagSpace = paramSpace.named_set_from_params_id (tagID, 0 );
273
- domainSpace = domainSpace.product (tagSpace);
274
-
275
- // Start with a totally unconstrained relation - every point in
276
- // the iteration domain could write to every point in the allocation.
277
- isl::map map =
278
- isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
270
+ // Start with a totally unconstrained set - every point in
271
+ // the allocation could be accessed.
272
+ isl::set access = isl::set::universe (tensorSpace);
279
273
280
274
for (size_t i = 0 ; i < args.size (); i++) {
281
275
// Then add one equality constraint per dimension to encode the
@@ -285,19 +279,34 @@ isl::map extractAccess(
285
279
286
280
// The coordinate written to in the range ...
287
281
auto rangePoint =
288
- isl::pw_aff (isl::local_space (rangeSpace ), isl::dim_type::set, i);
289
- // ... equals the coordinate accessed as a function of the domain .
290
- auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace , args[i]);
282
+ isl::pw_aff (isl::local_space (tensorSpace ), isl::dim_type::set, i);
283
+ // ... equals the coordinate accessed as a function of the parameters .
284
+ auto domainPoint = halide2isl::makeIslAffFromExpr (tensorSpace , args[i]);
291
285
if (!domainPoint.is_null ()) {
292
- map = map .intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
286
+ access = access .intersect (isl::pw_aff (domainPoint).eq_set (rangePoint));
293
287
}
294
288
}
295
289
290
+ // Now convert the set into a relation with respect to the iteration domain.
291
+ auto map = access.unbind_params_insert_domain (domain.tuple );
292
+
293
+ // Add a tag to the domain space so that we can maintain a mapping
294
+ // between each access in the IR and the reads/writes maps.
295
+ std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
296
+ isl::id tagID (domain.paramSpace .get_ctx (), tag);
297
+ accesses->emplace (op, tagID);
298
+ isl::space domainSpace = map.get_space ().domain ();
299
+ isl::space tagSpace = domainSpace.params ().named_set_from_params_id (tagID, 0 );
300
+ domainSpace = domainSpace.product (tagSpace).unwrap ();
301
+ map = map.preimage_domain (isl::multi_aff::domain_map (domainSpace));
302
+
296
303
return map;
297
304
}
298
305
299
- std::pair<isl::union_map, isl::union_map>
300
- extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
306
+ std::pair<isl::union_map, isl::union_map> extractAccesses (
307
+ const IterationDomain& domain,
308
+ const Stmt& s,
309
+ AccessMap* accesses) {
301
310
class FindAccesses : public IRGraphVisitor {
302
311
using IRGraphVisitor::visit;
303
312
@@ -315,17 +324,17 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
315
324
writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
316
325
}
317
326
318
- const isl::set & domain;
327
+ const IterationDomain & domain;
319
328
AccessMap* accesses;
320
329
321
330
public:
322
331
isl::union_map reads, writes;
323
332
324
- FindAccesses (const isl::set & domain, AccessMap* accesses)
333
+ FindAccesses (const IterationDomain & domain, AccessMap* accesses)
325
334
: domain(domain),
326
335
accesses (accesses),
327
- reads(isl::union_map::empty(domain.get_space())),
328
- writes(isl::union_map::empty(domain.get_space())) {}
336
+ reads(isl::union_map::empty(domain.tuple. get_space())),
337
+ writes(isl::union_map::empty(domain.tuple. get_space())) {}
329
338
} finder(domain, accesses);
330
339
s.accept(&finder);
331
340
return {finder.reads , finder.writes };
@@ -440,7 +449,8 @@ isl::schedule makeScheduleTreeHelper(
440
449
schedule = isl::schedule::from_domain (domain);
441
450
442
451
isl::union_map newReads, newWrites;
443
- std::tie (newReads, newWrites) = extractAccesses (domain, op, accesses);
452
+ std::tie (newReads, newWrites) =
453
+ extractAccesses (iterationDomain, op, accesses);
444
454
445
455
*reads = reads->unite (newReads);
446
456
*writes = writes->unite (newWrites);
0 commit comments