Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit d9eab4b

Browse files
committed
teach sema to extract iteration variables from LHS subtrees
Sema needs the list of iteration variables on the Comprehension LHS to differentiate between reduction and non-reduction variables, the former appearing only on the RHS. Original implementation assumes Comprehension LHS is a tensor whose indices are Idents and ignores more complex constructs. With indirection support, comprehensions like O(A(i)) = B(i) are possible but i is interpreted as a reduction dimension by Sema. Traverse indices of the LHS Tensor in Comprehension recursively, inspecting subtrees of Access and Apply trees and collecting all Idents.
1 parent e9d1dc4 commit d9eab4b

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

tc/lang/sema.h

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,17 +437,35 @@ struct Sema {
437437
return checkRangeConstraint(RangeConstraint(ref));
438438
}
439439
}
440+
441+
private:
442+
// Traverse the list of trees, recursively descending into arguments of APPLY
443+
// and ACCESS subtrees and into all subtrees of different types (mostly
444+
// expressions), and collect names and types of IDENT subtrees in
445+
// "index_env". Expects to be called on the indices of the LHS tensor.
446+
template <typename Collection>
447+
void registerLHSIndices(const Collection& treeRefs) {
448+
for (const auto& treeRef : treeRefs) {
449+
if (treeRef->kind() == TK_IDENT) {
450+
std::string idx = Ident(treeRef).name();
451+
auto typ = indexType(treeRef);
452+
insert(index_env, Ident(treeRef), typ, true);
453+
} else if (treeRef->kind() == TK_APPLY) {
454+
registerLHSIndices(Apply(treeRef).arguments());
455+
} else if (treeRef->kind() == TK_ACCESS) {
456+
registerLHSIndices(Access(treeRef).arguments());
457+
} else {
458+
registerLHSIndices(treeRef->trees());
459+
}
460+
}
461+
}
462+
463+
public:
440464
TreeRef checkStmt(TreeRef stmt_) {
441465
auto stmt = Comprehension(stmt_);
442466

443467
// register index variables (non-reductions)
444-
for (const auto& index : stmt.indices()) {
445-
if (index->kind() == TK_IDENT) {
446-
std::string idx = Ident(index).name();
447-
auto typ = indexType(index);
448-
insert(index_env, Ident(index), typ, true);
449-
}
450-
}
468+
registerLHSIndices(stmt.indices());
451469

452470
// make dimension variables for each dimension of the output tensor
453471
std::string name = stmt.ident().name();
@@ -464,6 +482,7 @@ struct Sema {
464482

465483
// where clauses are checked _before_ the rhs because they
466484
// introduce let bindings that are in scope for the rhs
485+
//
467486
auto where_clauses_ = stmt.whereClauses().map(
468487
[&](TreeRef rc) { return checkWhereClause(rc); });
469488

test/test_core.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,19 @@ def gather(int32(N) A, int32(N) B) -> (O) {
346346
CHECK_EQ(scop->mustWrites, scop->mayWrites);
347347
}
348348

349+
TEST_F(TC2Isl, Computed) {
350+
string tc = R"TC(
351+
def gather(int32(N) A, int32(N) B) -> (O) {
352+
O(i - 2) = A(i) + B(i)
353+
}
354+
)TC";
355+
auto scop = MakeScop(tc);
356+
CHECK(!scop->mustWrites.is_empty())
357+
<< "expected non-empty must-writes for gather, got\n"
358+
<< scop->mustWrites;
359+
CHECK_EQ(scop->mustWrites, scop->mayWrites);
360+
}
361+
349362
// FIXME: range inference seems unaware of indirections on the LHS
350363
TEST_F(TC2Isl, DISABLED_MustWritesSubsetMayWrites) {
351364
string tc = R"TC(

0 commit comments

Comments
 (0)