Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 68efb12

Browse files
authored
Merge pull request #325 from dingobye/uninitialized_reduction
Warning for uninitialized reductions.
2 parents a933afc + 3c18ab3 commit 68efb12

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

tc/lang/sema.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,20 @@ struct Sema {
490490
}
491491
}
492492

493+
// After checking rhs and before creating lhs, we check if it is a reduction
494+
// without initialization (i.e., reduction operator without "!" suffix, and
495+
// lhs not defined previously).
496+
if (isUninitializedReductionOperation(stmt.assignment()) &&
497+
nullptr == lookup(stmt.ident(), false)) {
498+
ErrorReport err(stmt);
499+
std::string tk = kindToToken(stmt.assignment()->kind());
500+
err << "Reduction without initialization. If " << stmt.ident().name()
501+
<< " is not pre-initialized before calling the TC function,"
502+
<< " consider using the !-suffixed reduction operator " << tk
503+
<< "! instead of " << tk;
504+
warn(err);
505+
}
506+
493507
auto type = TensorType::create(
494508
stmt.range(),
495509
scalar_type,
@@ -544,6 +558,17 @@ struct Sema {
544558

545559
return result;
546560
}
561+
static bool isUninitializedReductionOperation(TreeRef assignment) {
562+
switch (assignment->kind()) {
563+
case TK_PLUS_EQ:
564+
case TK_TIMES_EQ:
565+
case TK_MIN_EQ:
566+
case TK_MAX_EQ:
567+
return true;
568+
default:
569+
return false;
570+
}
571+
}
547572
bool isNotInplace(TreeRef assignment) {
548573
switch (assignment->kind()) {
549574
case TK_PLUS_EQ_B:

test/test_lang.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,23 @@ void assertSemaThrows(const std::string& errcontents, const std::string& text) {
143143
}
144144
ASSERT(threw);
145145
}
146+
void assertSemaWarns(const std::string& warncontents, const std::string& text) {
147+
Parser parser(text);
148+
auto p = parser.parseFunction();
149+
Sema sem;
150+
// Redirect std::cerr to buffer in order to capture warning message.
151+
std::stringstream buffer;
152+
std::streambuf* old = std::cerr.rdbuf(buffer.rdbuf());
153+
sem.checkFunction(p);
154+
std::cerr.rdbuf(old);
155+
// Match the captured warning message with the expected one.
156+
// Empty warncontents denotes no warning should be raised.
157+
if (warncontents.empty()) {
158+
ASSERT(buffer.str().find("WARNING:") == std::string::npos);
159+
} else {
160+
ASSERT(buffer.str().find(warncontents) != std::string::npos);
161+
}
162+
}
146163
TreeRef loadText(const std::string& text) {
147164
return Sema().checkFunction(Parser(text).parseFunction());
148165
}
@@ -315,6 +332,57 @@ int main(int argc, char** argv) {
315332
O(i) = A(i,j)
316333
}
317334
)");
335+
assertSemaWarns(
336+
"", // Legal reduction with initialization.
337+
R"(
338+
def fun(float(M,N) A) -> (O) {
339+
O(i) +=! A(i,j)
340+
}
341+
)");
342+
assertSemaWarns(
343+
"", // Legal reduction that writes to an already initialized output.
344+
R"(
345+
def fun(float(M,N) A, float(M) B) -> (O) {
346+
O(i) = B(i)
347+
O(i) *= A(i,j)
348+
}
349+
)");
350+
assertSemaWarns(
351+
"", // Legal reduction that writes to an already initialized output.
352+
R"(
353+
def fun(float(M,N) A, float(M,N) B) -> (O) {
354+
O(i) max=! A(i,j)
355+
O(i) min= B(i,j)
356+
}
357+
)");
358+
assertSemaWarns(
359+
"+=! instead of +=",
360+
R"(
361+
def fun(float(M,N) A) -> (O) {
362+
O(i) += A(i,j)
363+
}
364+
)");
365+
assertSemaWarns(
366+
"*=! instead of *=",
367+
R"(
368+
def fun(float(M,N) A) -> (O) {
369+
O(i) *= A(i,j)
370+
}
371+
)");
372+
assertSemaWarns(
373+
"max=! instead of max=",
374+
R"(
375+
def fun(float(M,N) A) -> (O) {
376+
O(i) max= A(i,j)
377+
}
378+
)");
379+
assertSemaWarns(
380+
"min=! instead of min=",
381+
R"(
382+
def fun(float(M,N) A) -> (O) {
383+
O(i) min= A(i,j)
384+
}
385+
)");
318386

319387
auto option_one = R"(
320388
def fun(float(B, N, M) X, float(B, M, K) Y) -> (Z)

0 commit comments

Comments
 (0)