Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 8f32ae3

Browse files
authored
Merge pull request #270 from facebookresearch/pr/minmax
Parser rules for min/max
2 parents e236bd2 + 159908a commit 8f32ae3

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

include/tc/lang/parser.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ struct Parser {
5555
auto value = parseExp();
5656
L.expect(')');
5757
return Cast::create(type->range(), value, type);
58+
} else if (L.cur().kind == TK_MIN || L.cur().kind == TK_MAX) {
59+
// min/max are treated as operators later in the compilation pipeline.
60+
// so we ensure they have precisely two arguments here so they can
61+
// use the same pathways as other operators like + where argument
62+
// count is ensured by parsing
63+
auto range = L.cur().range;
64+
auto tok = L.next().kind;
65+
L.expect('(');
66+
auto a = parseExp();
67+
L.expect(',');
68+
auto b = parseExp();
69+
L.expect(')');
70+
prefix = c(tok, range, {a, b});
5871
} else {
5972
prefix = parseIdent();
6073
auto range = L.cur().range;

test/cuda/test_corner_cases.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,48 @@ TEST(TestCornerCases, E20) {
243243
{F(1)});
244244
}
245245

246+
TEST(TestCornerCases, E21) {
247+
auto a = F(1);
248+
auto b = F(1);
249+
auto c = F(1);
250+
Succeed(
251+
"def f(float(1) a, float(1) b) -> (c) { c(i) = max(a(i), b(i)) }",
252+
{a, b},
253+
{c});
254+
CHECK_EQ(
255+
fmaxf(at::Scalar(a[0]).toFloat(), at::Scalar(b[0]).toFloat()),
256+
at::Scalar(c[0]).toFloat());
257+
}
258+
259+
TEST(TestCornerCases, E22) {
260+
auto a = F(1);
261+
auto b = F(1);
262+
auto c = F(1);
263+
Succeed(
264+
"def f(float(1) a, float(1) b) -> (c) { c(i) = min(a(i), b(i)) }",
265+
{a, b},
266+
{c});
267+
CHECK_EQ(
268+
fminf(at::Scalar(a[0]).toFloat(), at::Scalar(b[0]).toFloat()),
269+
at::Scalar(c[0]).toFloat());
270+
}
271+
272+
TEST(TestCornerCases, E23) {
273+
auto a = F(1);
274+
auto b = F(1);
275+
auto c = F(1);
276+
auto d = F(1);
277+
Succeed(
278+
"def f(float(1) a, float(1) b, float(1) c) -> (d) { d(i) = min(a(i), max(b(i), c(i))) }",
279+
{a, b, c},
280+
{d});
281+
CHECK_EQ(
282+
fminf(
283+
at::Scalar(a[0]).toFloat(),
284+
fmaxf(at::Scalar(b[0]).toFloat(), at::Scalar(c[0]).toFloat())),
285+
at::Scalar(d[0]).toFloat());
286+
}
287+
246288
int main(int argc, char** argv) {
247289
::testing::InitGoogleTest(&argc, argv);
248290
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)