@@ -174,16 +174,9 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
174
174
const Max* maxOp = e.as <Max>();
175
175
176
176
if (const Variable* op = e.as <Variable>()) {
177
- isl::local_space ls = isl::local_space (space);
178
- int pos = space.find_dim_by_name (isl::dim_type::param, op->name );
179
- if (pos >= 0 ) {
180
- return {isl::aff (ls, isl::dim_type::param, pos)};
181
- } else {
182
- // FIXME: thou shalt not rely upon set dimension names
183
- pos = space.find_dim_by_name (isl::dim_type::set, op->name );
184
- if (pos >= 0 ) {
185
- return {isl::aff (ls, isl::dim_type::set, pos)};
186
- }
177
+ isl::id id (space.get_ctx (), op->name );
178
+ if (space.has_param (id)) {
179
+ return {isl::aff::param_on_domain_space (space, id)};
187
180
}
188
181
LOG (FATAL) << " Variable not found in isl::space: " << space << " : " << op
189
182
<< " : " << op->name << ' \n ' ;
@@ -248,32 +241,28 @@ isl::set makeParamContext(isl::ctx ctx, const ParameterVector& params) {
248
241
return context;
249
242
}
250
243
244
+ namespace {
245
+
251
246
isl::map extractAccess (
252
- isl::set domain,
247
+ const IterationDomain& domain,
253
248
const IRNode* op,
254
249
const std::string& tensor,
255
250
const std::vector<Expr>& args,
256
251
AccessMap* accesses) {
257
252
// Make an isl::map representing this access. It maps from the iteration space
258
253
// to the tensor's storage space, using the coordinates accessed.
254
+ // First construct a set describing the accessed element
255
+ // in terms of the parameters (including those corresponding
256
+ // to the outer loop iterators) and then convert this set
257
+ // into a map in terms of the iteration domain.
259
258
260
- isl::space domainSpace = domain.get_space ();
261
- isl::space paramSpace = domainSpace.params ();
259
+ isl::space paramSpace = domain.paramSpace ;
262
260
isl::id tensorID (paramSpace.get_ctx (), tensor);
263
- auto rangeSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
261
+ auto tensorSpace = paramSpace.named_set_from_params_id (tensorID, args.size ());
264
262
265
- // Add a tag to the domain space so that we can maintain a mapping
266
- // between each access in the IR and the reads/writes maps.
267
- std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
268
- isl::id tagID (domain.get_ctx (), tag);
269
- accesses->emplace (op, tagID);
270
- isl::space tagSpace = paramSpace.named_set_from_params_id (tagID, 0 );
271
- domainSpace = domainSpace.product (tagSpace);
272
-
273
- // Start with a totally unconstrained relation - every point in
274
- // the iteration domain could write to every point in the allocation.
275
- isl::map map =
276
- isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
263
+ // Start with a totally unconstrained set - every point in
264
+ // the allocation could be accessed.
265
+ isl::set access = isl::set::universe (tensorSpace);
277
266
278
267
for (size_t i = 0 ; i < args.size (); i++) {
279
268
// Then add one equality constraint per dimension to encode the
@@ -283,19 +272,34 @@ isl::map extractAccess(
283
272
284
273
// The coordinate written to in the range ...
285
274
auto rangePoint =
286
- isl::pw_aff (isl::local_space (rangeSpace ), isl::dim_type::set, i);
287
- // ... equals the coordinate accessed as a function of the domain .
288
- auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace , args[i]);
275
+ isl::pw_aff (isl::local_space (tensorSpace ), isl::dim_type::set, i);
276
+ // ... equals the coordinate accessed as a function of the parameters .
277
+ auto domainPoint = halide2isl::makeIslAffFromExpr (tensorSpace , args[i]);
289
278
if (!domainPoint.is_null ()) {
290
- map = map .intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
279
+ access = access .intersect (isl::pw_aff (domainPoint).eq_set (rangePoint));
291
280
}
292
281
}
293
282
283
+ // Now convert the set into a relation with respect to the iteration domain.
284
+ auto map = access.unbind_params_insert_domain (domain.tuple );
285
+
286
+ // Add a tag to the domain space so that we can maintain a mapping
287
+ // between each access in the IR and the reads/writes maps.
288
+ std::string tag = " __tc_ref_" + std::to_string (accesses->size ());
289
+ isl::id tagID (domain.paramSpace .get_ctx (), tag);
290
+ accesses->emplace (op, tagID);
291
+ isl::space domainSpace = map.get_space ().domain ();
292
+ isl::space tagSpace = domainSpace.params ().named_set_from_params_id (tagID, 0 );
293
+ domainSpace = domainSpace.product (tagSpace).unwrap ();
294
+ map = map.preimage_domain (isl::multi_aff::domain_map (domainSpace));
295
+
294
296
return map;
295
297
}
296
298
297
- std::pair<isl::union_map, isl::union_map>
298
- extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
299
+ std::pair<isl::union_map, isl::union_map> extractAccesses (
300
+ const IterationDomain& domain,
301
+ const Stmt& s,
302
+ AccessMap* accesses) {
299
303
class FindAccesses : public IRGraphVisitor {
300
304
using IRGraphVisitor::visit;
301
305
@@ -313,28 +317,46 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
313
317
writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
314
318
}
315
319
316
- const isl::set & domain;
320
+ const IterationDomain & domain;
317
321
AccessMap* accesses;
318
322
319
323
public:
320
324
isl::union_map reads, writes;
321
325
322
- FindAccesses (const isl::set & domain, AccessMap* accesses)
326
+ FindAccesses (const IterationDomain & domain, AccessMap* accesses)
323
327
: domain(domain),
324
328
accesses (accesses),
325
- reads(isl::union_map::empty(domain.get_space())),
326
- writes(isl::union_map::empty(domain.get_space())) {}
329
+ reads(isl::union_map::empty(domain.tuple. get_space())),
330
+ writes(isl::union_map::empty(domain.tuple. get_space())) {}
327
331
} finder(domain, accesses);
328
332
s.accept(&finder);
329
333
return {finder.reads , finder.writes };
330
334
}
331
335
336
+ /*
337
+ * Take a parametric expression "f" and convert it into an expression
338
+ * on the iteration domains in "domain" by reinterpreting the parameters
339
+ * as set dimensions according to the corresponding tuples in "map".
340
+ */
341
+ isl::union_pw_aff
342
+ onDomains (isl::aff f, isl::union_set domain, const IterationDomainMap& map) {
343
+ auto upa = isl::union_pw_aff::empty (domain.get_space ());
344
+ for (auto set : domain.get_set_list ()) {
345
+ auto tuple = map.at (set.get_tuple_id ()).tuple ;
346
+ auto onSet = isl::union_pw_aff (f.unbind_params_insert_domain (tuple));
347
+ upa = upa.union_add (onSet);
348
+ }
349
+ return upa;
350
+ }
351
+
352
+ } // namespace
353
+
332
354
/*
333
355
* Helper function for extracting a schedule from a Halide Stmt,
334
356
* recursively descending over the Stmt.
335
357
* "s" is the current position in the recursive descent.
336
358
* "set" describes the bounds on the outer loop iterators.
337
- * "outer" contains the names of the outer loop iterators
359
+ * "outer" contains the identifiers of the outer loop iterators
338
360
* from outermost to innermost.
339
361
* Return the schedule corresponding to the subtree at "s".
340
362
*
@@ -343,81 +365,58 @@ extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
343
365
* (for the writes) to the corresponding tag in the access relations.
344
366
* "statements" collects the mapping from instance set tuple identifiers
345
367
* to the corresponding Provide node.
346
- * "iterators " collects the mapping from instance set tuple identifiers
347
- * to the corresponding outer loop iterator names, from outermost to innermost .
368
+ * "domains " collects the mapping from instance set tuple identifiers
369
+ * to the corresponding iteration domain information .
348
370
*/
349
371
isl::schedule makeScheduleTreeHelper (
350
372
const Stmt& s,
351
373
isl::set set,
352
- std::vector<std::string>& outer,
374
+ isl::id_list outer,
353
375
isl::union_map* reads,
354
376
isl::union_map* writes,
355
377
AccessMap* accesses,
356
378
StatementMap* statements,
357
- IteratorMap* iterators ) {
379
+ IterationDomainMap* domains ) {
358
380
isl::schedule schedule;
359
381
if (auto op = s.as <For>()) {
360
- // Add one additional dimension to our set of loop variables
361
- int thisLoopIdx = set.dim (isl::dim_type::set);
362
- set = set.add_dims (isl::dim_type::set, 1 );
363
-
364
- // Make an id for this loop var. For set dimensions this is
365
- // really just for pretty-printing.
382
+ // Make an id for this loop var. It starts out as a parameter.
366
383
isl::id id (set.get_ctx (), op->name );
367
- set = set.set_dim_id (isl::dim_type::set, thisLoopIdx, id);
384
+ auto space = set.get_space (). add_param ( id);
368
385
369
- // Construct a variable (affine function) that indexes the new dimension of
370
- // this space.
371
- isl::aff loopVar (
372
- isl::local_space (set.get_space ()), isl::dim_type::set, thisLoopIdx);
386
+ // Construct a variable (affine function) that references
387
+ // the new parameter.
388
+ auto loopVar = isl::aff::param_on_domain_space (space, id);
373
389
374
390
// Then we add our new loop bound constraints.
375
- auto lbs = halide2isl::makeIslAffBoundsFromExpr (
376
- set. get_space () , op->min , false , true );
391
+ auto lbs =
392
+ halide2isl::makeIslAffBoundsFromExpr (space , op->min , false , true );
377
393
TC_CHECK_GT (lbs.size (), 0u )
378
394
<< " could not obtain polyhedral lower bounds from " << op->min ;
379
395
for (auto lb : lbs) {
380
396
set = set.intersect (loopVar.ge_set (lb));
381
397
}
382
398
383
399
Expr max = simplify (op->min + op->extent - 1 );
384
- auto ubs =
385
- halide2isl::makeIslAffBoundsFromExpr (set.get_space (), max, true , false );
400
+ auto ubs = halide2isl::makeIslAffBoundsFromExpr (space, max, true , false );
386
401
TC_CHECK_GT (ubs.size (), 0u )
387
402
<< " could not obtain polyhedral upper bounds from " << max;
388
403
for (auto ub : ubs) {
389
404
set = set.intersect (ub.ge_set (loopVar));
390
405
}
391
406
392
407
// Recursively descend.
393
- auto outerNext = outer;
394
- outerNext.push_back (op->name );
408
+ auto outerNext = outer.add (isl::id (set.get_ctx (), op->name ));
395
409
auto body = makeScheduleTreeHelper (
396
- op->body ,
397
- set,
398
- outerNext,
399
- reads,
400
- writes,
401
- accesses,
402
- statements,
403
- iterators);
410
+ op->body , set, outerNext, reads, writes, accesses, statements, domains);
404
411
405
412
// Create an affine function that defines an ordering for all
406
413
// the statements in the body of this loop over the values of
407
- // this loop. For each statement in the children we want the
408
- // function that maps everything in its space to this
409
- // dimension. The spaces may be different, but they'll all have
410
- // this loop var at the same index.
411
- isl::multi_union_pw_aff mupa;
412
- body.get_domain ().foreach_set ([&](isl::set s) {
413
- isl::aff newLoopVar (
414
- isl::local_space (s.get_space ()), isl::dim_type::set, thisLoopIdx);
415
- if (mupa) {
416
- mupa = mupa.union_add (isl::union_pw_aff (isl::pw_aff (newLoopVar)));
417
- } else {
418
- mupa = isl::union_pw_aff (isl::pw_aff (newLoopVar));
419
- }
420
- });
414
+ // this loop. Start from a parametric expression equal
415
+ // to the current loop iterator and then convert it to
416
+ // a function on the statements in the domain of the body schedule.
417
+ auto aff = isl::aff::param_on_domain_space (space, id);
418
+ auto domain = body.get_domain ();
419
+ auto mupa = isl::multi_union_pw_aff (onDomains (aff, domain, *domains));
421
420
422
421
schedule = body.insert_partial_schedule (mupa);
423
422
} else if (auto op = s.as <Halide::Internal::Block>()) {
@@ -430,7 +429,7 @@ isl::schedule makeScheduleTreeHelper(
430
429
std::vector<isl::schedule> schedules;
431
430
for (Stmt stmt : stmts) {
432
431
schedules.push_back (makeScheduleTreeHelper (
433
- stmt, set, outer, reads, writes, accesses, statements, iterators ));
432
+ stmt, set, outer, reads, writes, accesses, statements, domains ));
434
433
}
435
434
schedule = schedules[0 ].sequence (schedules[1 ]);
436
435
@@ -441,13 +440,18 @@ isl::schedule makeScheduleTreeHelper(
441
440
size_t stmtIndex = statements->size ();
442
441
isl::id id (set.get_ctx (), kStatementLabel + std::to_string (stmtIndex));
443
442
statements->emplace (id, op);
444
- iterators->emplace (id, outer);
445
- isl::set domain = set.set_tuple_id (id);
443
+ auto tupleSpace = isl::space (set.get_ctx (), 0 );
444
+ tupleSpace = tupleSpace.named_set_from_params_id (id, outer.n ());
445
+ IterationDomain iterationDomain;
446
+ iterationDomain.paramSpace = set.get_space ();
447
+ iterationDomain.tuple = isl::multi_id (tupleSpace, outer);
448
+ domains->emplace (id, iterationDomain);
449
+ auto domain = set.unbind_params (iterationDomain.tuple );
446
450
schedule = isl::schedule::from_domain (domain);
447
451
448
452
isl::union_map newReads, newWrites;
449
453
std::tie (newReads, newWrites) =
450
- halide2isl:: extractAccesses (domain , op, accesses);
454
+ extractAccesses (iterationDomain , op, accesses);
451
455
452
456
*reads = reads->unite (newReads);
453
457
*writes = writes->unite (newWrites);
@@ -464,7 +468,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
464
468
result.writes = result.reads = isl::union_map::empty (paramSpace);
465
469
466
470
// Walk the IR building a schedule tree
467
- std::vector<std::string> outer;
471
+ isl::id_list outer (paramSpace. get_ctx (), 0 ) ;
468
472
auto schedule = makeScheduleTreeHelper (
469
473
s,
470
474
isl::set::universe (paramSpace),
@@ -473,7 +477,7 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
473
477
&result.writes ,
474
478
&result.accesses ,
475
479
&result.statements ,
476
- &result.iterators );
480
+ &result.domains );
477
481
478
482
result.tree = fromIslSchedule (schedule);
479
483
0 commit comments