Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit cec3512

Browse files
Ensure input immutability
We have been talking about input immutability for a while in TC but it was enver actually implemented. This adds a simple check to semantic analysis as well as a unit test to expect that input parameters are immutable. Note that this is not foolproof because one may always abuse tmp/output tensors to be a view into a subportion of an input tensor but this is not intended to be commonplace. In the future we may want to relax input immutability.
1 parent bfd6d1f commit cec3512

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

include/tc/lang/sema.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ static inline TreeRef match_types(TreeRef a, TreeRef b) {
155155
/// - replace TK_APPLY with TK_BUILT_IN for built in functions
156156
/// - checks that all variables are defined, and creates index/reduction
157157
/// variable objects.
158+
/// - checks that input variables are readonly.
158159
struct Sema {
159160
std::unordered_map<TreeRef, TreeRef> expr_to_type;
160161

@@ -349,10 +350,13 @@ struct Sema {
349350
}
350351
}
351352

352-
for (auto p : func.params())
353+
for (auto p : func.params()) {
353354
nonTemporaries.insert(p.ident().name());
354-
for (auto r : func.returns())
355+
inputParameters.insert(p.ident().name());
356+
}
357+
for (auto r : func.returns()) {
355358
nonTemporaries.insert(r.ident().name());
359+
}
356360

357361
auto statements_ =
358362
checkList(func.statements(), [&](TreeRef r) { return checkStmt(r); });
@@ -445,6 +449,9 @@ struct Sema {
445449

446450
// make dimension variables for each dimension of the output tensor
447451
std::string name = stmt.ident().name();
452+
if (inputParameters.count(name) > 0) {
453+
throw ErrorReport(stmt_) << "TC inputs are immutable";
454+
}
448455
TreeList output_indices;
449456
int n = stmt.indices().size();
450457
for (int i = 0; i < n; ++i) {
@@ -614,6 +621,7 @@ struct Sema {
614621
// allowed
615622
std::unordered_set<std::string> live_input_names;
616623

624+
std::unordered_set<std::string> inputParameters;
617625
std::unordered_set<std::string> nonTemporaries;
618626
};
619627
} // namespace lang

test/test_tc2halide.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,15 @@ def fun(float(M, N) I) -> (O1, O2, O3) {
197197
Check(tc, {123, 13});
198198
}
199199

200+
TEST_F(TC2Isl, MutableInput) {
201+
string tc = R"TC(
202+
def foo(float(N) A) -> (B) {
203+
A(i) = A(i) + 42
204+
B(k) +=! A(i) where k in 0:1
205+
}
206+
)TC";
207+
EXPECT_THROW(Check(tc, {123}), ::lang::ErrorReport);
208+
}
200209
int main(int argc, char** argv) {
201210
::testing::InitGoogleTest(&argc, argv);
202211
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)