17
17
18
18
#include < algorithm>
19
19
#include < numeric>
20
+ #include < tuple>
20
21
#include < unordered_set>
21
22
22
23
#include " tc/core/constants.h"
@@ -238,7 +239,20 @@ isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) {
238
239
return context;
239
240
}
240
241
241
- isl::map extractAccess (
242
+ // Extract a tagged affine access relation from Halide IR.
243
+ // The relation is tagged with a unique identifier, i.e. it lives in the space
244
+ // [D[...] -> __tc_ref_#[]] -> A[]
245
+ // where # is a unique sequential number, D is the statement identifier
246
+ // extracted from "domain" and A is the tensor identifier constructed from
247
+ // "tensor". "accesses" map is updated to keep track of the Halide IR nodes in
248
+ // which a particular reference # appeared.
249
+ // Returns the access relation and a flag indicating whether this relation is
250
+ // exact or not. The relation is overapproximated (that is, not exact) if it
251
+ // represents a non-affine access, for example, an access with indirection such
252
+ // as O(Index(i)) = 42. In such overapproximated access relation, dimensions
253
+ // that correspond to affine subscripts are still exact while those that
254
+ // correspond to non-affine subscripts are not constrained.
255
+ std::pair<isl::map, bool > extractAccess (
242
256
isl::set domain,
243
257
const IRNode* op,
244
258
const std::string& tensor,
@@ -267,6 +281,7 @@ isl::map extractAccess(
267
281
isl::map map =
268
282
isl::map::universe (domainSpace.map_from_domain_and_range (rangeSpace));
269
283
284
+ bool exact = true ;
270
285
for (size_t i = 0 ; i < args.size (); i++) {
271
286
// Then add one equality constraint per dimension to encode the
272
287
// point in the allocation actually read/written for each point in
@@ -278,47 +293,64 @@ isl::map extractAccess(
278
293
isl::pw_aff (isl::local_space (rangeSpace), isl::dim_type::set, i);
279
294
// ... equals the coordinate accessed as a function of the domain.
280
295
auto domainPoint = halide2isl::makeIslAffFromExpr (domainSpace, args[i]);
281
- if (!domainPoint.is_null ()) {
296
+ if (!domainPoint) {
297
+ exact = false ;
298
+ } else {
282
299
map = map.intersect (isl::pw_aff (domainPoint).eq_map (rangePoint));
283
300
}
284
301
}
285
302
286
- return map;
303
+ return std::make_pair ( map, exact) ;
287
304
}
288
305
289
- std::pair< isl::union_map, isl::union_map>
306
+ std::tuple<isl::union_map, isl::union_map, isl::union_map>
290
307
extractAccesses (isl::set domain, const Stmt& s, AccessMap* accesses) {
291
308
class FindAccesses : public IRGraphVisitor {
292
309
using IRGraphVisitor::visit;
293
310
294
311
void visit (const Call* op) override {
295
312
IRGraphVisitor::visit (op);
296
313
if (op->call_type == Call::Halide || op->call_type == Call::Image) {
297
- reads = reads.unite (
298
- extractAccess (domain, op, op->name , op->args , accesses));
314
+ // Read relations can be safely overapproximated.
315
+ isl::map read;
316
+ std::tie (read, std::ignore) =
317
+ extractAccess (domain, op, op->name , op->args , accesses);
318
+ reads = reads.unite (read);
299
319
}
300
320
}
301
321
302
322
void visit (const Provide* op) override {
303
323
IRGraphVisitor::visit (op);
304
- writes =
305
- writes.unite (extractAccess (domain, op, op->name , op->args , accesses));
324
+
325
+ // If the write access relation is not exact, we consider that any
326
+ // element _may_ be written by the statement. If it is exact, then we
327
+ // can guarantee that all the elements specified by the relation _must_
328
+ // be written and any previously stored value will be killed.
329
+ isl::map write;
330
+ bool exact;
331
+ std::tie (write, exact) =
332
+ extractAccess (domain, op, op->name , op->args , accesses);
333
+ if (exact) {
334
+ mustWrites = mustWrites.unite (write);
335
+ }
336
+ mayWrites = mayWrites.unite (write);
306
337
}
307
338
308
339
const isl::set& domain;
309
340
AccessMap* accesses;
310
341
311
342
public:
312
- isl::union_map reads, writes ;
343
+ isl::union_map reads, mayWrites, mustWrites ;
313
344
314
345
FindAccesses (const isl::set& domain, AccessMap* accesses)
315
346
: domain(domain),
316
347
accesses (accesses),
317
348
reads(isl::union_map::empty(domain.get_space())),
318
- writes(isl::union_map::empty(domain.get_space())) {}
349
+ mayWrites(isl::union_map::empty(domain.get_space())),
350
+ mustWrites(isl::union_map::empty(domain.get_space())) {}
319
351
} finder(domain, accesses);
320
352
s.accept(&finder);
321
- return { finder.reads , finder.writes } ;
353
+ return std::make_tuple( finder.reads, finder.mayWrites, finder.mustWrites) ;
322
354
}
323
355
324
356
/*
@@ -343,7 +375,8 @@ isl::schedule makeScheduleTreeHelper(
343
375
isl::set set,
344
376
std::vector<std::string>& outer,
345
377
isl::union_map* reads,
346
- isl::union_map* writes,
378
+ isl::union_map* mayWrites,
379
+ isl::union_map* mustWrites,
347
380
AccessMap* accesses,
348
381
StatementMap* statements,
349
382
IteratorMap* iterators) {
@@ -389,7 +422,8 @@ isl::schedule makeScheduleTreeHelper(
389
422
set,
390
423
outerNext,
391
424
reads,
392
- writes,
425
+ mayWrites,
426
+ mustWrites,
393
427
accesses,
394
428
statements,
395
429
iterators);
@@ -422,7 +456,15 @@ isl::schedule makeScheduleTreeHelper(
422
456
std::vector<isl::schedule> schedules;
423
457
for (Stmt s : stmts) {
424
458
schedules.push_back (makeScheduleTreeHelper (
425
- s, set, outer, reads, writes, accesses, statements, iterators));
459
+ s,
460
+ set,
461
+ outer,
462
+ reads,
463
+ mayWrites,
464
+ mustWrites,
465
+ accesses,
466
+ statements,
467
+ iterators));
426
468
}
427
469
schedule = schedules[0 ].sequence (schedules[1 ]);
428
470
@@ -437,23 +479,25 @@ isl::schedule makeScheduleTreeHelper(
437
479
isl::set domain = set.set_tuple_id (id);
438
480
schedule = isl::schedule::from_domain (domain);
439
481
440
- isl::union_map newReads, newWrites ;
441
- std::tie (newReads, newWrites ) =
482
+ isl::union_map newReads, newMayWrites, newMustWrites ;
483
+ std::tie (newReads, newMayWrites, newMustWrites ) =
442
484
halide2isl::extractAccesses (domain, op, accesses);
443
485
444
486
*reads = reads->unite (newReads);
445
- *writes = writes->unite (newWrites);
487
+ *mayWrites = mayWrites->unite (newMayWrites);
488
+ *mustWrites = mustWrites->unite (newMustWrites);
446
489
447
490
} else {
448
491
LOG (FATAL) << " Unhandled Halide stmt: " << s;
449
492
}
450
493
return schedule;
451
- };
494
+ }
452
495
453
496
ScheduleTreeAndAccesses makeScheduleTree (isl::space paramSpace, const Stmt& s) {
454
497
ScheduleTreeAndAccesses result;
455
498
456
- result.writes = result.reads = isl::union_map::empty (paramSpace);
499
+ result.mayWrites = result.mustWrites = result.reads =
500
+ isl::union_map::empty (paramSpace);
457
501
458
502
// Walk the IR building a schedule tree
459
503
std::vector<std::string> outer;
@@ -462,7 +506,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
462
506
isl::set::universe (paramSpace),
463
507
outer,
464
508
&result.reads ,
465
- &result.writes ,
509
+ &result.mayWrites ,
510
+ &result.mustWrites ,
466
511
&result.accesses ,
467
512
&result.statements ,
468
513
&result.iterators );
0 commit comments