Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit a521a7a

Browse files
prigoyalftynse
authored andcommitted
add support for modulo operator
support % operator: propagate it from parser to halide to isl and add unit tests
1 parent c8788eb commit a521a7a

File tree

7 files changed

+54
-4
lines changed

7 files changed

+54
-4
lines changed

tc/core/tc2halide.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ Expr translateExpr(
158158
return t(0) * t(1);
159159
case '/':
160160
return t(0) / t(1);
161+
case '%':
162+
return t(0) % t(1);
161163
case lang::TK_MIN:
162164
return min(t(0), t(1));
163165
case lang::TK_MAX:

tc/lang/lexer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ namespace lang {
8787
_(TK_LET, "let", "") \
8888
_(TK_EXISTS, "exists", "exists")
8989

90-
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
90+
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%";
9191

9292
enum TokenKind {
9393
// we use characters to represent themselves so skip all valid characters
@@ -137,7 +137,7 @@ struct SharedParserData {
137137
{TK_AND},
138138
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
139139
{'+', '-'},
140-
{'*', '/'},
140+
{'*', '/', '%'},
141141
};
142142
std::vector<std::vector<int>> unary_ops = {
143143
{'-', '!'},

tc/lang/sema.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ struct Sema {
293293
case '-':
294294
case '*':
295295
case '/':
296+
case '%':
296297
case TK_MIN:
297298
case TK_MAX: {
298299
auto nexp =

tc/lang/test_expected/math.expected

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
(-
22
(+
33
(+
4-
(- (const 3 (int32)))
4+
(%
5+
(- (const 3 (int32)))
6+
(const 2 (int32)))
57
(*
68
(const 4 (int32))
79
(const 5 (int32))))

test/cuda/test_corner_cases.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,18 @@ i))
320320
at::Scalar(d[0]).toFloat());
321321
}
322322

323+
TEST(TestCornerCases, E25) {
324+
auto a = I();
325+
auto b = I();
326+
auto r = I(1);
327+
Succeed(
328+
"def f(int32 a, int32 b) -> (c) { c(i) = int32(a % b) where i in 0:1 }",
329+
{a, b},
330+
{r});
331+
auto e = at::Scalar(a).toInt() % at::Scalar(b).toInt();
332+
CHECK_EQ(at::Scalar(r[0]).toInt(), e);
333+
}
334+
323335
int main(int argc, char** argv) {
324336
::testing::InitGoogleTest(&argc, argv);
325337
::gflags::ParseCommandLineFlags(&argc, &argv, true);

test/test_cuda_mapper.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,39 @@ TEST_F(PolyhedralMapperTest, EmptyMappingFilter) {
11181118
mscop->codegen(specializedName);
11191119
}
11201120

1121+
TEST_F(PolyhedralMapperTest, ModulusConstantRHS) {
1122+
string tc = R"TC(
1123+
def fun(float(N) a) -> (b) { b(i) = a(i % 3) where i in 0:N }
1124+
)TC";
1125+
// This triggers tc2halide conversion and should not throw.
1126+
auto scop = Prepare(tc);
1127+
for (auto r : scop->reads.wrap().get_set_list()) {
1128+
auto read = r.unwrap();
1129+
// skip irrelevant reads, if any
1130+
if (read.range().get_tuple_name() != std::string("a")) {
1131+
continue;
1132+
}
1133+
EXPECT_EQ(r.get_stride(0), 3);
1134+
}
1135+
}
1136+
1137+
TEST_F(PolyhedralMapperTest, ModulusVariableRHS) {
1138+
string tc = R"TC(
1139+
def local_sparse_convolution(float(N, C, H, W) I, float(O, KC, KH, KW) W1) -> (O1) {
1140+
O1(n, o, h, w) +=! I(n, kc % c, h + kh, w + kw) * W1(o, kc, kh, kw) where c in 1:C
1141+
}
1142+
)TC";
1143+
// This triggers tc2halide conversion and should not throw.
1144+
auto scop = Prepare(tc);
1145+
for (auto r : scop->reads.range().get_set_list()) {
1146+
// skip irrelevant reads, if any
1147+
if (r.get_tuple_name() != std::string("I")) {
1148+
continue;
1149+
}
1150+
EXPECT_TRUE(r.plain_is_universe());
1151+
}
1152+
}
1153+
11211154
int main(int argc, char** argv) {
11221155
::testing::InitGoogleTest(&argc, argv);
11231156
::gflags::ParseCommandLineFlags(&argc, &argv, true);

test/test_lang.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ int main(int argc, char** argv) {
232232
ASSERT(s->tree(0)->stringValue() == "min");
233233
}
234234
{
235-
std::string stuff = "-3+4*5+7-a";
235+
std::string stuff = "-3%2+4*5+7-a";
236236
Parser p(stuff);
237237
auto r = p.parseExp();
238238
std::stringstream ss;

0 commit comments

Comments
 (0)