Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit ef2d882

Browse files
zdevitoftynse
authored andcommitted
Parser rules for min/max
1 parent e7cd425 commit ef2d882

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

include/tc/lang/parser.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ struct Parser {
5555
auto value = parseExp();
5656
L.expect(')');
5757
return Cast::create(type->range(), value, type);
58+
} else if (L.cur().kind == TK_MIN || L.cur().kind == TK_MAX) {
59+
// min/max are treated as operators later in the compilation pipeline.
60+
// so we ensure they have precisely two arguments here so they can
61+
// use the same pathways as other operators like + where argument
62+
// count is ensured by parsing
63+
auto range = L.cur().range;
64+
auto tok = L.next().kind;
65+
L.expect('(');
66+
auto a = parseExp();
67+
L.expect(',');
68+
auto b = parseExp();
69+
L.expect(')');
70+
prefix = c(tok, range, {a, b});
5871
} else {
5972
prefix = parseIdent();
6073
auto range = L.cur().range;

test/cuda/test_corner_cases.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,32 @@ TEST(TestCornerCases, E20) {
243243
{F(1)});
244244
}
245245

246+
TEST(TestCornerCases, E21) {
247+
auto a = F(1);
248+
auto b = F(1);
249+
auto c = F(1);
250+
Succeed(
251+
"def f(float(1) a, float(1) b) -> (c) { c(i) = max(a(i), b(i)) }",
252+
{a, b},
253+
{c});
254+
CHECK_EQ(
255+
fmaxf(at::Scalar(a[0]).toFloat(), at::Scalar(b[0]).toFloat()),
256+
at::Scalar(c[0]).toFloat());
257+
}
258+
259+
TEST(TestCornerCases, E22) {
260+
auto a = F(1);
261+
auto b = F(1);
262+
auto c = F(1);
263+
Succeed(
264+
"def f(float(1) a, float(1) b) -> (c) { c(i) = min(a(i), b(i)) }",
265+
{a, b},
266+
{c});
267+
CHECK_EQ(
268+
fminf(at::Scalar(a[0]).toFloat(), at::Scalar(b[0]).toFloat()),
269+
at::Scalar(c[0]).toFloat());
270+
}
271+
246272
int main(int argc, char** argv) {
247273
::testing::InitGoogleTest(&argc, argv);
248274
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)