Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 3abd114

Browse files
committed
Add bitwise operators and propagate them from parser -> halide -> isl
1 parent a57562a commit 3abd114

File tree

7 files changed

+78
-2
lines changed

7 files changed

+78
-2
lines changed

tc/core/libraries.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ template<typename T> inline __device__ T floord(T n, T d) {
145145
return n < 0 ? - (-n + d - 1)/d : n / d;
146146
}
147147
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))
148+
#define shift_left(a,b) ((a) << (b))
149+
#define shift_right(a,b) ((a) >> (b))
150+
#define bitwise_and(a,b) ((a) & (b))
151+
#define bitwise_xor(a,b) ((a) ^ (b))
152+
#define bitwise_or(a,b) ((a) | (b))
153+
#define bitwise_not(a) (~(a))
148154
)C";
149155
} // namespace cpp
150156

tc/core/tc2halide.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,18 @@ Expr translateExpr(
190190
return t(0) && t(1);
191191
case lang::TK_OR:
192192
return t(0) || t(1);
193+
case lang::TK_LS:
194+
return t(0) << t(1);
195+
case lang::TK_RS:
196+
return t(0) >> t(1);
197+
case '|':
198+
return t(0) | t(1);
199+
case '^':
200+
return t(0) ^ t(1);
201+
case '&':
202+
return t(0) & t(1);
203+
case '~':
204+
return ~t(0);
193205
case lang::TK_BUILT_IN: {
194206
auto b = lang::BuiltIn(expr);
195207
vector<Expr> exprs;

tc/lang/lexer.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ namespace lang {
8484
_(TK_NE, "neq", "!=") \
8585
_(TK_AND, "and", "&&") \
8686
_(TK_OR, "or", "||") \
87+
_(TK_LS, "ls", "<<") \
88+
_(TK_RS, "rs", ">>") \
8789
_(TK_LET, "let", "") \
8890
_(TK_EXISTS, "exists", "exists")
8991

90-
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%";
92+
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%&^|~";
9193

9294
enum TokenKind {
9395
// we use characters to represent themselves so skip all valid characters
@@ -135,12 +137,16 @@ struct SharedParserData {
135137
{'?'},
136138
{TK_OR},
137139
{TK_AND},
140+
{'|'},
141+
{'^'},
142+
{'&'},
138143
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
144+
{TK_LS, TK_RS},
139145
{'+', '-'},
140146
{'*', '/', '%'},
141147
};
142148
std::vector<std::vector<int>> unary_ops = {
143-
{'-', '!'},
149+
{'-', '!', '~'},
144150
};
145151

146152
std::stringstream ss;

tc/lang/sema.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ struct Sema {
294294
case '*':
295295
case '/':
296296
case '%':
297+
case '~':
298+
case '|':
299+
case '^':
300+
case '&':
301+
case TK_LS:
302+
case TK_RS:
297303
case TK_MIN:
298304
case TK_MAX: {
299305
auto nexp =
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
(|
2+
(^
3+
(&
4+
(~ (const 3 (int32)))
5+
(const 4 (int32)))
6+
(const 5 (int32)))
7+
(ls
8+
(rs
9+
(const 6 (int32))
10+
(const 8 (int32)))
11+
(const 2 (int32))))

test/cuda/test_corner_cases.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,33 @@ TEST(TestCornerCases, E24) {
296296
CHECK_EQ(at::Scalar(r[0]).toInt(), 0);
297297
}
298298

299+
TEST(TestCornerCases, E25){
300+
#define GEN_BITWISE(op) \
301+
{ \
302+
auto a = 2 * I(); \
303+
auto b = 2 * I(); \
304+
auto r = I(0); \
305+
Succeed( \
306+
"def f(int32 a, int32 b) -> (c) { c(i) = int32(a " #op \
307+
" b) where i in 0:1 }", \
308+
{a, b}, \
309+
{r}); \
310+
auto e = at::Scalar(a).toInt() op at::Scalar(b).toInt(); \
311+
CHECK_EQ(e, at::Scalar(r[0]).toInt()); \
312+
}
313+
314+
GEN_BITWISE(<<) GEN_BITWISE(>>) GEN_BITWISE(&) GEN_BITWISE(|)
315+
GEN_BITWISE (^)}
316+
317+
TEST(TestCornerCases, E26) {
318+
auto a = I();
319+
auto r = I(0);
320+
Succeed(
321+
"def f(int32 a) -> (c) { c(i) = int32(~a) where i in 0:1 }", {a}, {r});
322+
auto e = ~at::Scalar(a).toInt();
323+
CHECK_EQ(at::Scalar(r[0]).toInt(), e);
324+
}
325+
299326
int main(int argc, char** argv) {
300327
::testing::InitGoogleTest(&argc, argv);
301328
::gflags::ParseCommandLineFlags(&argc, &argv, true);

test/test_lang.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,14 @@ int main(int argc, char** argv) {
244244
ss2 << p2.parseExp();
245245
assertEqual("function.expected", ss2.str());
246246
}
247+
{
248+
std::string bitOps = "~3&4^5|6>>8<<2";
249+
Parser p(bitOps);
250+
auto r = p.parseExp();
251+
std::stringstream ss;
252+
ss << r;
253+
assertEqual("bitwise.expected", ss.str());
254+
}
247255
assertParseEqual("trinary.expected", "a ? 3 : b ? 3 : 4", [&](Parser& p) {
248256
return p.parseExp();
249257
});

0 commit comments

Comments
 (0)