Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit aa2fe9e

Browse files
Add more type support
This commit adds missing type support to the lexer, semantic analyzer, and cuda implementation. This also adds language and end-to-end functional tests for all types we support. An annoying type issue comes from the inability to include system dependencies with NVRTC so half support is explicitly disabled. Maybe it is time to think about moving to NVCC in a separate process, which we have been discussing for some time. The following commit adds an issue submitted by @mdouze which now runs properly thanks to this commit.
1 parent 89906d9 commit aa2fe9e

File tree

6 files changed

+104
-12
lines changed

6 files changed

+104
-12
lines changed

tc/core/libraries.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,20 @@ namespace code {
3131
namespace c {
3232

3333
constexpr auto types = R"C(
34+
// Can't include system dependencies with NVRTC
35+
// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
36+
// #include <cuda_fp16.h>
37+
3438
// Halide type handling
39+
typedef char int8;
40+
typedef short int16;
3541
typedef int int32;
3642
typedef long int64;
43+
typedef unsigned char uint8;
44+
typedef unsigned short uint16;
45+
typedef unsigned int uint32;
46+
typedef unsigned long uint64;
47+
// typedef half float16;
3748
typedef float float32;
3849
typedef double float64;
3950
)C";

tc/core/tc2halide.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,17 @@ Type translateScalarType(int tcType) {
5353
return Int(32);
5454
case lang::TK_INT64:
5555
return Int(64);
56+
case lang::TK_FLOAT16:
57+
return Float(16);
58+
case lang::TK_FLOAT32:
59+
return Float(32);
60+
case lang::TK_FLOAT64:
61+
return Float(64);
5662
case lang::TK_FLOAT:
5763
return Float(32);
5864
case lang::TK_DOUBLE:
5965
return Float(64);
66+
6067
default:
6168
LOG(FATAL) << "Unhandled TC scalar type: " << tcType << '\n';
6269
return Type();

tc/lang/lexer.h

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ namespace lang {
4141
_(TK_MIN, "min", "min") \
4242
_(TK_MAX, "max", "max") \
4343
_(TK_WHERE, "where", "where") \
44-
_(TK_FLOAT, "float", "float") \
45-
_(TK_DOUBLE, "double", "double") \
4644
_(TK_DEF, "def", "def") \
4745
_(TK_ARROW, "arrow", "->") \
4846
_(TK_EQUIVALENT, "equivalent", "<=>") \
@@ -67,15 +65,21 @@ namespace lang {
6765
_(TK_TIMES_EQ_B, "times_eq_b", "*=!") \
6866
_(TK_MIN_EQ_B, "min_eq_b", "min=!") \
6967
_(TK_MAX_EQ_B, "max_eq_b", "max=!") \
70-
_(TK_INT8, "int8", "int8") \
71-
_(TK_INT16, "int16", "int16") \
72-
_(TK_INT32, "int32", "int32") \
73-
_(TK_INT64, "int64", "int64") \
68+
\
69+
_(TK_BOOL, "bool", "bool") \
7470
_(TK_UINT8, "uint8", "uint8") \
7571
_(TK_UINT16, "uint16", "uint16") \
7672
_(TK_UINT32, "uint32", "uint32") \
7773
_(TK_UINT64, "uint64", "uint64") \
78-
_(TK_BOOL, "bool", "bool") \
74+
_(TK_INT8, "int8", "int8") \
75+
_(TK_INT16, "int16", "int16") \
76+
_(TK_INT32, "int32", "int32") \
77+
_(TK_INT64, "int64", "int64") \
78+
_(TK_FLOAT16, "float16", "float16") \
79+
_(TK_FLOAT32, "float32", "float32") \
80+
_(TK_FLOAT64, "float64", "float64") \
81+
_(TK_FLOAT, "float", "float") \
82+
_(TK_DOUBLE, "double", "double") \
7983
_(TK_CAST, "cast", "") \
8084
_(TK_IN, "in", "in") \
8185
_(TK_GE, "ge", ">=") \
@@ -271,15 +275,18 @@ struct SharedParserData {
271275
}
272276
bool isScalarType(int kind) {
273277
switch (kind) {
274-
case TK_INT8:
275-
case TK_INT16:
276-
case TK_INT32:
277-
case TK_INT64:
278+
case TK_BOOL:
278279
case TK_UINT8:
279280
case TK_UINT16:
280281
case TK_UINT32:
281282
case TK_UINT64:
282-
case TK_BOOL:
283+
case TK_INT8:
284+
case TK_INT16:
285+
case TK_INT32:
286+
case TK_INT64:
287+
case TK_FLOAT16:
288+
case TK_FLOAT32:
289+
case TK_FLOAT64:
283290
case TK_FLOAT:
284291
case TK_DOUBLE:
285292
return true;

tc/lang/sema.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,23 @@ struct TypeInfo {
4646
TYPE_INFO_OPTION(TK_INT16, Int, 16)
4747
TYPE_INFO_OPTION(TK_INT32, Int, 32)
4848
TYPE_INFO_OPTION(TK_INT64, Int, 64)
49+
TYPE_INFO_OPTION(TK_FLOAT16, Float, 16)
50+
TYPE_INFO_OPTION(TK_FLOAT32, Float, 32)
51+
TYPE_INFO_OPTION(TK_FLOAT64, Float, 64)
4952
TYPE_INFO_OPTION(TK_FLOAT, Float, 32)
5053
TYPE_INFO_OPTION(TK_DOUBLE, Float, 64)
54+
5155
#undef TYPE_INFO_OPTION
5256
default:
5357
throw ErrorReport(scalar_type)
5458
<< "Unhandled TC scalar type: " << scalar_type;
5559
}
60+
61+
if (code_ == Code::Float && bits_ == 16) {
62+
throw ErrorReport(scalar_type)
63+
<< "Half precision floating point not supported "
64+
<< "until we can make NVRTC include system headers";
65+
}
5666
}
5767
int toScalarToken() const {
5868
switch (code()) {
@@ -82,12 +92,15 @@ struct TypeInfo {
8292
}
8393
case Float:
8494
switch (bits()) {
95+
case 16:
96+
return TK_FLOAT16;
8597
case 32:
8698
return TK_FLOAT;
8799
case 64:
88100
return TK_DOUBLE;
89101
}
90102
}
103+
91104
throw std::runtime_error("Unknown type info?");
92105
}
93106
Code code() const {

test/cuda/test_compile_and_run.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,37 @@ def cast(float(M,N) A, int32 four) -> (int32(M,N) output) {
275275
TC_CHECK_EQ(r, 0);
276276
}
277277

278+
TEST_F(CompilationTest, Types) {
279+
struct TypeMatch {
280+
std::string s;
281+
at::ScalarType a;
282+
};
283+
for (auto type :
284+
{// TypeMatch{"bool", at::ScalarType::Bool}, // no aten version
285+
TypeMatch{"uint8", at::ScalarType::Byte},
286+
// TypeMatch{"uint16", at::ScalarType::Short}, // no aten version
287+
// TypeMatch{"uint32", at::ScalarType::Int}, // no aten version
288+
// TypeMatch{"uint64", at::ScalarType::Long}, // no aten version
289+
TypeMatch{"int8", at::ScalarType::Char},
290+
TypeMatch{"int16", at::ScalarType::Short},
291+
TypeMatch{"int32", at::ScalarType::Int},
292+
TypeMatch{"int64", at::ScalarType::Long},
293+
// NVRTC include transitive dependencies issue
294+
// TypeMatch{"float16", at::ScalarType::Half},
295+
TypeMatch{"float32", at::ScalarType::Float},
296+
TypeMatch{"float64", at::ScalarType::Double},
297+
TypeMatch{"float", at::ScalarType::Float},
298+
TypeMatch{"double", at::ScalarType::Double}}) {
299+
std::string tc = std::string("def test_type(") + std::string(type.s) +
300+
std::string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
301+
std::vector<at::Tensor> outputs = Check(
302+
tc,
303+
"test_type",
304+
tc::CudaMappingOptions::makeNaiveMappingOptions(),
305+
{at::CUDA(type.a).ones({100})});
306+
}
307+
}
308+
278309
int main(int argc, char** argv) {
279310
::testing::InitGoogleTest(&argc, argv);
280311
::gflags::ParseCommandLineFlags(&argc, &argv, true);

test/test_tc2halide.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,29 @@ def foo(float(N) A) -> (B) {
197197
)TC";
198198
EXPECT_THROW(Check(tc), ::lang::ErrorReport);
199199
}
200+
201+
TEST_F(TC2Isl, Types) {
202+
for (auto type : {"bool",
203+
"uint8",
204+
"uint16",
205+
"uint32",
206+
"uint64",
207+
"int8",
208+
"int16",
209+
"int32",
210+
"int64",
211+
// NVRTC include transitive dependencies issue
212+
// "float16",
213+
"float32",
214+
"float64",
215+
"float",
216+
"double"}) {
217+
string tc = string("def test_type(") + string(type) +
218+
string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
219+
Check(tc);
220+
}
221+
}
222+
200223
int main(int argc, char** argv) {
201224
::testing::InitGoogleTest(&argc, argv);
202225
::gflags::ParseCommandLineFlags(&argc, &argv, true);

0 commit comments

Comments
 (0)