Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 989c9b9

Browse files
committed
Backporting comparison operators support
1 parent de59ea2 commit 989c9b9

File tree

6 files changed

+116
-5
lines changed

6 files changed

+116
-5
lines changed

include/tc/core/libraries.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ constexpr auto boundsAsTemplate = R"C(
140140
template<typename T> inline __device__ T floord(T n, T d) {
141141
return n < 0 ? - (-n + d - 1)/d : n / d;
142142
}
143+
#define if_then_else(cond,a,b) (cond) ? (a) : (b);
143144
)C";
144145
} // namespace cpp
145146

include/tc/lang/lexer.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,15 @@ namespace lang {
7777
_(TK_UINT64, "uint64", "uint64") \
7878
_(TK_BOOL, "bool", "bool") \
7979
_(TK_CAST, "cast", "") \
80-
_(TK_IN, "in", "in")
80+
_(TK_IN, "in", "in") \
81+
_(TK_GE, "ge", ">=") \
82+
_(TK_LE, "le", "<=") \
83+
_(TK_EQ, "eq", "==") \
84+
_(TK_NE, "neq", "!=") \
85+
_(TK_AND, "and", "&&") \
86+
_(TK_OR, "or", "||")
8187

82-
static const char* valid_single_char_tokens = "+-*/()[]?:,={}>";
88+
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
8389

8490
enum TokenKind {
8591
// we use characters to represent themselves so skip all valid characters
@@ -121,11 +127,14 @@ struct SharedParserData {
121127
// listed in increasing order of precedence
122128
std::vector<std::vector<int>> binary_ops = {
123129
{'?'},
130+
{TK_OR},
131+
{TK_AND},
132+
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
124133
{'+', '-'},
125134
{'*', '/'},
126135
};
127136
std::vector<std::vector<int>> unary_ops = {
128-
{'-'},
137+
{'-', '!'},
129138
};
130139

131140
std::stringstream ss;

include/tc/lang/sema.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,16 @@ struct Sema {
194194
}
195195
return e;
196196
}
197+
void expectBool(TreeRef anchor, int token) {
198+
if (token != TK_BOOL) {
199+
throw ErrorReport(anchor)
200+
<< "expected boolean but found " << kindToString(token);
201+
}
202+
}
203+
TreeRef expectBool(TreeRef exp) {
204+
expectBool(exp, typeOfExpr(exp)->kind());
205+
return exp;
206+
}
197207
TreeRef checkExp(TreeRef exp, bool allow_access) {
198208
switch (exp->kind()) {
199209
case TK_APPLY: {
@@ -205,6 +215,7 @@ struct Sema {
205215
throw ErrorReport(exp)
206216
<< "tensor accesses cannot be used in this context";
207217
}
218+
208219
// also handle built-in functions log, exp, etc.
209220
auto ident = a.name();
210221
if (builtin_functions.count(ident.name()) > 0) {
@@ -276,6 +287,35 @@ struct Sema {
276287
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
277288
return withType(nexp, matchAllTypes(nexp));
278289
} break;
290+
case TK_EQ:
291+
case TK_NE:
292+
case TK_GE:
293+
case TK_LE:
294+
case '<':
295+
case '>': {
296+
auto nexp =
297+
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
298+
// make sure the types match but the return type
299+
// is always bool
300+
matchAllTypes(nexp);
301+
return withType(nexp, boolType(exp));
302+
} break;
303+
case TK_AND:
304+
case TK_OR:
305+
case '!': {
306+
auto nexp =
307+
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
308+
expectBool(exp, matchAllTypes(nexp)->kind());
309+
return withType(nexp, boolType(exp));
310+
} break;
311+
case '?': {
312+
auto nexp =
313+
exp->map([&](TreeRef c) { return checkExp(c, allow_access); });
314+
expectBool(nexp->tree(0));
315+
auto rtype =
316+
match_types(typeOfExpr(nexp->tree(1)), typeOfExpr(nexp->tree(2)));
317+
return withType(nexp, rtype);
318+
}
279319
case TK_CONST: {
280320
auto c = Const(exp);
281321
return withType(exp, c.type());
@@ -322,6 +362,9 @@ struct Sema {
322362
TreeRef floatType(TreeRef anchor) {
323363
return c(TK_FLOAT, anchor->range(), {});
324364
}
365+
TreeRef boolType(TreeRef anchor) {
366+
return c(TK_BOOL, anchor->range(), {});
367+
}
325368
void checkDim(const Ident& dim) {
326369
insert(env, dim, dimType(dim), false);
327370
}

src/core/tc2halide.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,24 @@ Expr translateExpr(
162162
{cond, true_val, false_val},
163163
Call::Intrinsic);
164164
}
165+
case lang::TK_EQ:
166+
return t(0) == t(1);
167+
case lang::TK_NE:
168+
return t(0) != t(1);
169+
case lang::TK_LE:
170+
return t(0) <= t(1);
171+
case lang::TK_GE:
172+
return t(0) >= t(1);
173+
case '<':
174+
return t(0) < t(1);
175+
case '>':
176+
return t(0) > t(1);
177+
case '!':
178+
return !t(0);
179+
case lang::TK_AND:
180+
return t(0) && t(1);
181+
case lang::TK_OR:
182+
return t(0) || t(1);
165183
case lang::TK_BUILT_IN: {
166184
auto b = lang::BuiltIn(expr);
167185
vector<Expr> exprs;

test/test_corner_cases.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,46 @@ TEST(FailTest, DISABLED_E14) {
175175
{I(10)});
176176
}
177177

178+
TEST(FailTest, E15){
179+
#define GEN_COMPARATOR(op) \
180+
{ \
181+
auto a = F(); \
182+
auto b = F(); \
183+
auto c = F(1); \
184+
Succeed( \
185+
"def f(float a, float b) -> (c) { c(i) = float(a " #op \
186+
" b) where i in 0:1 }", \
187+
{a, b}, \
188+
{c}); \
189+
auto r = at::Scalar(a).toFloat() op at::Scalar(b).toFloat(); \
190+
CHECK_EQ(r, at::Scalar(c[0]).toFloat()); \
191+
}
192+
193+
GEN_COMPARATOR(<=) GEN_COMPARATOR(>=) GEN_COMPARATOR(==) GEN_COMPARATOR(!=)
194+
GEN_COMPARATOR(<) GEN_COMPARATOR(>)
195+
196+
}
197+
198+
TEST(FailTest, E16) {
199+
#define GEN_BOOLS(op) \
200+
{ \
201+
auto a = F(); \
202+
auto b = F(); \
203+
auto c = F(1); \
204+
Succeed( \
205+
"def f(float a, float b) -> (c) { c(i) = float(!(a < .5) " #op \
206+
" b > .5) where i in 0:1 }", \
207+
{a, b}, \
208+
{c}); \
209+
auto r = !(at::Scalar(a).toFloat() < .5) op at::Scalar(b).toFloat() > .5; \
210+
; \
211+
CHECK_EQ(r, at::Scalar(c[0]).toFloat()); \
212+
}
213+
214+
GEN_BOOLS(||)
215+
GEN_BOOLS(&&)
216+
}
217+
178218
int main(int argc, char** argv) {
179219
::testing::InitGoogleTest(&argc, argv);
180220
::gflags::ParseCommandLineFlags(&argc, &argv, true);

test/test_execution_engine.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ struct ATenCompilationUnitTest : public ::testing::Test {
4545
}
4646
};
4747

48-
TEST_F(ATenCompilationUnitTest, DISABLED_Concat) {
48+
TEST_F(ATenCompilationUnitTest, Concat) {
4949
at::Tensor a = at::CUDA(at::kFloat).rand({32, 16});
5050
at::Tensor b = at::CUDA(at::kFloat).rand({32, 16});
5151
std::vector<at::Tensor> inputs = {a, b};
5252
std::vector<at::Tensor> outputs;
5353

5454
Check(
5555
R"(
56-
def concat(float(M, N) A, float(M, N) B) -> (O1, O2) {
56+
def concat(float(M, N) A, float(M, N) B) -> (O1) {
5757
O1(n, i, m) = i == 0 ? A(m, n) : B(m, n) where i in 0:2
5858
}
5959
)",

0 commit comments

Comments
 (0)