Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit e9d1dc4

Browse files
committed
extract may/must writes from Halide IR
Previous commits introduced may/must writes in Scop and dependence analysis. Extract those from Halide IR. Change extractAccess to return a flag indicating whether the affine access relation is constructed is exact or not. Exact relations correspond to must writes since we statically know which tensor elements are written. Inexact relations overapproximate non-affine accesses and should be treated as may writes, assuming the tensor elements are not necessarily written.
1 parent 16ff65a commit e9d1dc4

File tree

4 files changed

+122
-23
lines changed

4 files changed

+122
-23
lines changed

tc/core/halide2isl.cc

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <algorithm>
1919
#include <numeric>
20+
#include <tuple>
2021
#include <unordered_set>
2122

2223
#include "tc/core/constants.h"
@@ -238,7 +239,20 @@ isl::set makeParamContext(isl::ctx ctx, const SymbolTable& symbolTable) {
238239
return context;
239240
}
240241

241-
isl::map extractAccess(
242+
// Extract a tagged affine access relation from Halide IR.
243+
// The relation is tagged with a unique identifier, i.e. it lives in the space
244+
// [D[...] -> __tc_ref_#[]] -> A[]
245+
// where # is a unique sequential number, D is the statement identifier
246+
// extracted from "domain" and A is the tensor identifier constructed from
247+
// "tensor". "accesses" map is updated to keep track of the Halide IR nodes in
248+
// which a particular reference # appeared.
249+
// Returns the access relation and a flag indicating whether this relation is
250+
// exact or not. The relation is overapproximated (that is, not exact) if it
251+
// represents a non-affine access, for example, an access with indirection such
252+
// as O(Index(i)) = 42. In such overapproximated access relation, dimensions
253+
// that correspond to affine subscripts are still exact while those that
254+
// correspond to non-affine subscripts are not constrained.
255+
std::pair<isl::map, bool> extractAccess(
242256
isl::set domain,
243257
const IRNode* op,
244258
const std::string& tensor,
@@ -267,6 +281,7 @@ isl::map extractAccess(
267281
isl::map map =
268282
isl::map::universe(domainSpace.map_from_domain_and_range(rangeSpace));
269283

284+
bool exact = true;
270285
for (size_t i = 0; i < args.size(); i++) {
271286
// Then add one equality constraint per dimension to encode the
272287
// point in the allocation actually read/written for each point in
@@ -278,47 +293,64 @@ isl::map extractAccess(
278293
isl::pw_aff(isl::local_space(rangeSpace), isl::dim_type::set, i);
279294
// ... equals the coordinate accessed as a function of the domain.
280295
auto domainPoint = halide2isl::makeIslAffFromExpr(domainSpace, args[i]);
281-
if (!domainPoint.is_null()) {
296+
if (!domainPoint) {
297+
exact = false;
298+
} else {
282299
map = map.intersect(isl::pw_aff(domainPoint).eq_map(rangePoint));
283300
}
284301
}
285302

286-
return map;
303+
return std::make_pair(map, exact);
287304
}
288305

289-
std::pair<isl::union_map, isl::union_map>
306+
std::tuple<isl::union_map, isl::union_map, isl::union_map>
290307
extractAccesses(isl::set domain, const Stmt& s, AccessMap* accesses) {
291308
class FindAccesses : public IRGraphVisitor {
292309
using IRGraphVisitor::visit;
293310

294311
void visit(const Call* op) override {
295312
IRGraphVisitor::visit(op);
296313
if (op->call_type == Call::Halide || op->call_type == Call::Image) {
297-
reads = reads.unite(
298-
extractAccess(domain, op, op->name, op->args, accesses));
314+
// Read relations can be safely overapproximated.
315+
isl::map read;
316+
std::tie(read, std::ignore) =
317+
extractAccess(domain, op, op->name, op->args, accesses);
318+
reads = reads.unite(read);
299319
}
300320
}
301321

302322
void visit(const Provide* op) override {
303323
IRGraphVisitor::visit(op);
304-
writes =
305-
writes.unite(extractAccess(domain, op, op->name, op->args, accesses));
324+
325+
// If the write access relation is not exact, we consider that any
326+
// element _may_ be written by the statement. If it is exact, then we
327+
// can guarantee that all the elements specified by the relation _must_
328+
// be written and any previously stored value will be killed.
329+
isl::map write;
330+
bool exact;
331+
std::tie(write, exact) =
332+
extractAccess(domain, op, op->name, op->args, accesses);
333+
if (exact) {
334+
mustWrites = mustWrites.unite(write);
335+
}
336+
mayWrites = mayWrites.unite(write);
306337
}
307338

308339
const isl::set& domain;
309340
AccessMap* accesses;
310341

311342
public:
312-
isl::union_map reads, writes;
343+
isl::union_map reads, mayWrites, mustWrites;
313344

314345
FindAccesses(const isl::set& domain, AccessMap* accesses)
315346
: domain(domain),
316347
accesses(accesses),
317348
reads(isl::union_map::empty(domain.get_space())),
318-
writes(isl::union_map::empty(domain.get_space())) {}
349+
mayWrites(isl::union_map::empty(domain.get_space())),
350+
mustWrites(isl::union_map::empty(domain.get_space())) {}
319351
} finder(domain, accesses);
320352
s.accept(&finder);
321-
return {finder.reads, finder.writes};
353+
return std::make_tuple(finder.reads, finder.mayWrites, finder.mustWrites);
322354
}
323355

324356
/*
@@ -343,7 +375,8 @@ isl::schedule makeScheduleTreeHelper(
343375
isl::set set,
344376
std::vector<std::string>& outer,
345377
isl::union_map* reads,
346-
isl::union_map* writes,
378+
isl::union_map* mayWrites,
379+
isl::union_map* mustWrites,
347380
AccessMap* accesses,
348381
StatementMap* statements,
349382
IteratorMap* iterators) {
@@ -389,7 +422,8 @@ isl::schedule makeScheduleTreeHelper(
389422
set,
390423
outerNext,
391424
reads,
392-
writes,
425+
mayWrites,
426+
mustWrites,
393427
accesses,
394428
statements,
395429
iterators);
@@ -422,7 +456,15 @@ isl::schedule makeScheduleTreeHelper(
422456
std::vector<isl::schedule> schedules;
423457
for (Stmt s : stmts) {
424458
schedules.push_back(makeScheduleTreeHelper(
425-
s, set, outer, reads, writes, accesses, statements, iterators));
459+
s,
460+
set,
461+
outer,
462+
reads,
463+
mayWrites,
464+
mustWrites,
465+
accesses,
466+
statements,
467+
iterators));
426468
}
427469
schedule = schedules[0].sequence(schedules[1]);
428470

@@ -437,23 +479,25 @@ isl::schedule makeScheduleTreeHelper(
437479
isl::set domain = set.set_tuple_id(id);
438480
schedule = isl::schedule::from_domain(domain);
439481

440-
isl::union_map newReads, newWrites;
441-
std::tie(newReads, newWrites) =
482+
isl::union_map newReads, newMayWrites, newMustWrites;
483+
std::tie(newReads, newMayWrites, newMustWrites) =
442484
halide2isl::extractAccesses(domain, op, accesses);
443485

444486
*reads = reads->unite(newReads);
445-
*writes = writes->unite(newWrites);
487+
*mayWrites = mayWrites->unite(newMayWrites);
488+
*mustWrites = mustWrites->unite(newMustWrites);
446489

447490
} else {
448491
LOG(FATAL) << "Unhandled Halide stmt: " << s;
449492
}
450493
return schedule;
451-
};
494+
}
452495

453496
ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
454497
ScheduleTreeAndAccesses result;
455498

456-
result.writes = result.reads = isl::union_map::empty(paramSpace);
499+
result.mayWrites = result.mustWrites = result.reads =
500+
isl::union_map::empty(paramSpace);
457501

458502
// Walk the IR building a schedule tree
459503
std::vector<std::string> outer;
@@ -462,7 +506,8 @@ ScheduleTreeAndAccesses makeScheduleTree(isl::space paramSpace, const Stmt& s) {
462506
isl::set::universe(paramSpace),
463507
outer,
464508
&result.reads,
465-
&result.writes,
509+
&result.mayWrites,
510+
&result.mustWrites,
466511
&result.accesses,
467512
&result.statements,
468513
&result.iterators);

tc/core/halide2isl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ struct ScheduleTreeAndAccesses {
7070
/// Union maps describing the reads and writes done. Uses the ids in
7171
/// the schedule tree to denote the containing Stmt, and tags each
7272
/// access with a unique reference id of the form __tc_ref_N.
73-
isl::union_map reads, writes;
73+
isl::union_map reads, mayWrites, mustWrites;
7474

7575
/// The correspondence between from Call and Provide nodes and the
7676
/// reference ids in the reads and writes maps.

tc/core/polyhedral/scop.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ ScopUPtr Scop::makeScop(
6161
auto tree = halide2isl::makeScheduleTree(paramSpace, components.stmt);
6262
scop->scheduleTreeUPtr = std::move(tree.tree);
6363
scop->reads = tree.reads;
64-
scop->mayWrites = tree.writes;
65-
scop->mustWrites = isl::union_map::empty(scop->mayWrites.get_space());
64+
scop->mayWrites = tree.mayWrites;
65+
scop->mustWrites = tree.mustWrites;
6666
scop->halide.statements = std::move(tree.statements);
6767
scop->halide.accesses = std::move(tree.accesses);
6868
scop->halide.reductions = halide2isl::findReductions(components.stmt);

test/test_core.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,10 @@ struct TC2Isl : public ::testing::Test {
165165
auto scheduleHalide = polyhedral::detail::fromIslSchedule(
166166
polyhedral::detail::toIslSchedule(scop->scheduleRoot()).reset_user());
167167
}
168+
169+
std::unique_ptr<polyhedral::Scop> MakeScop(const std::string& tc) {
170+
return polyhedral::Scop::makeScop(isl::with_exceptions::globalIslCtx(), tc);
171+
}
168172
};
169173

170174
TEST_F(TC2Isl, Copy1D) {
@@ -313,6 +317,56 @@ def fun(float(M, N) I) -> (O1, O2, O3) {
313317
Check(tc, {123, 13});
314318
}
315319

320+
// FIXME: range inference seems unaware of indirections on the LHS
321+
TEST_F(TC2Isl, DISABLED_MayWritesOnly) {
322+
string tc = R"TC(
323+
def scatter(int32(N) A, int32(M) B) -> (O) {
324+
O(A(i)) = B(i)
325+
}
326+
)TC";
327+
auto scop = MakeScop(tc);
328+
CHECK(scop->mustWrites.is_empty())
329+
<< "expected empty must-writes for scatter, got\n"
330+
<< scop->mustWrites;
331+
CHECK(!scop->mayWrites.is_empty())
332+
<< "expected non-empty may-writes for scatter, got\n"
333+
<< scop->mayWrites;
334+
}
335+
336+
TEST_F(TC2Isl, AllMustWrites) {
337+
string tc = R"TC(
338+
def gather(int32(N) A, int32(N) B) -> (O) {
339+
O(i) = A(B(i)) where i in 0:N
340+
}
341+
)TC";
342+
auto scop = MakeScop(tc);
343+
CHECK(!scop->mustWrites.is_empty())
344+
<< "expected non-empty must-writes for gather, got\n"
345+
<< scop->mustWrites;
346+
CHECK_EQ(scop->mustWrites, scop->mayWrites);
347+
}
348+
349+
// FIXME: range inference seems unaware of indirections on the LHS
350+
TEST_F(TC2Isl, DISABLED_MustWritesSubsetMayWrites) {
351+
string tc = R"TC(
352+
def scatter_gather(int32(N) A, int32(N) B) -> (O1,O2) {
353+
O1(i) = A(B(i)) where i in 0:N
354+
O2(A(i)) = B(i)
355+
}
356+
)TC";
357+
auto scop = MakeScop(tc);
358+
CHECK(!scop->mustWrites.is_empty()) << "expected non-empty must-writes, got\n"
359+
<< scop->mustWrites;
360+
CHECK(!scop->mayWrites.is_empty()) << "expected non-empty may-writes, got\n"
361+
<< scop->mustWrites;
362+
CHECK(scop->mustWrites.is_subset(scop->mayWrites))
363+
<< scop->mustWrites << " is expected to be a subsetset of "
364+
<< scop->mayWrites;
365+
CHECK(!scop->mayWrites.subtract(scop->mustWrites).is_empty())
366+
<< scop->mustWrites << "is expected to be a strict subset of "
367+
<< scop->mayWrites;
368+
}
369+
316370
int main(int argc, char** argv) {
317371
::testing::InitGoogleTest(&argc, argv);
318372
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)