From 0a4bf8b7d47f7f2e5e0c25fd9b7e43f5a9960b21 Mon Sep 17 00:00:00 2001
From: nicolasvasilache
Date: Tue, 12 Jun 2018 16:07:26 -0600
Subject: [PATCH 1/3] Add more type support

This commit adds the missing type support to the lexer, semantic analyzer,
and CUDA implementation. It also adds language and end-to-end functional
tests for all the types we support.

An annoying type issue comes from the inability to include system
dependencies with NVRTC, so half support is explicitly disabled. Maybe it
is time to think about moving to NVCC in a separate process, which we have
been discussing for some time.

The following commit adds an example submitted by @mdouze which now runs
properly thanks to this commit.
---
 tc/core/libraries.h               | 11 +++++++++
 tc/core/tc2halide.cc              |  7 ++++++
 tc/lang/lexer.h                   | 31 ++++++++++++++++----------
 tc/lang/sema.h                    | 37 ++++++++++++++++++-------------
 test/cuda/test_compile_and_run.cc | 31 ++++++++++++++++++++++++++
 test/test_tc2halide.cc            | 22 ++++++++++++++++++
 6 files changed, 112 insertions(+), 27 deletions(-)

diff --git a/tc/core/libraries.h b/tc/core/libraries.h
index c86eb565d..c2fdb1de8 100644
--- a/tc/core/libraries.h
+++ b/tc/core/libraries.h
@@ -31,9 +31,20 @@ namespace code {
 namespace c {
 
 constexpr auto types = R"C(
+// Can't include system dependencies with NVRTC
+// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
+// #include <cuda_fp16.h>
+
 // Halide type handling
+typedef char int8;
+typedef short int16;
 typedef int int32;
 typedef long int64;
+typedef unsigned char uint8;
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned long uint64;
+// typedef half float16;
 typedef float float32;
 typedef double float64;
 )C";
diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc
index 01da89edf..379add821 100644
--- a/tc/core/tc2halide.cc
+++ b/tc/core/tc2halide.cc
@@ -53,10 +53,17 @@ Type translateScalarType(int tcType) {
       return Int(32);
     case lang::TK_INT64:
       return Int(64);
+    case lang::TK_FLOAT16:
+      return Float(16);
+    case lang::TK_FLOAT32:
+      return Float(32);
+    case lang::TK_FLOAT64:
+      return Float(64);
     case lang::TK_FLOAT:
       return Float(32);
     case lang::TK_DOUBLE:
       return Float(64);
+
     default:
       LOG(FATAL) << "Unhandled TC scalar type: " << tcType << '\n';
       return Type();
diff --git a/tc/lang/lexer.h b/tc/lang/lexer.h
index 9e40a3092..2db8bdc79 100644
--- a/tc/lang/lexer.h
+++ b/tc/lang/lexer.h
@@ -41,8 +41,6 @@ namespace lang {
   _(TK_MIN, "min", "min") \
   _(TK_MAX, "max", "max") \
   _(TK_WHERE, "where", "where") \
-  _(TK_FLOAT, "float", "float") \
-  _(TK_DOUBLE, "double", "double") \
   _(TK_DEF, "def", "def") \
   _(TK_ARROW, "arrow", "->") \
   _(TK_EQUIVALENT, "equivalent", "<=>") \
@@ -67,15 +65,21 @@ namespace lang {
   _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \
   _(TK_MIN_EQ_B, "min_eq_b", "min=!") \
   _(TK_MAX_EQ_B, "max_eq_b", "max=!") \
-  _(TK_INT8, "int8", "int8") \
-  _(TK_INT16, "int16", "int16") \
-  _(TK_INT32, "int32", "int32") \
-  _(TK_INT64, "int64", "int64") \
+  \
+  _(TK_BOOL, "bool", "bool") \
   _(TK_UINT8, "uint8", "uint8") \
   _(TK_UINT16, "uint16", "uint16") \
   _(TK_UINT32, "uint32", "uint32") \
   _(TK_UINT64, "uint64", "uint64") \
-  _(TK_BOOL, "bool", "bool") \
+  _(TK_INT8, "int8", "int8") \
+  _(TK_INT16, "int16", "int16") \
+  _(TK_INT32, "int32", "int32") \
+  _(TK_INT64, "int64", "int64") \
+  _(TK_FLOAT16, "float16", "float16") \
+  _(TK_FLOAT32, "float32", "float32") \
+  _(TK_FLOAT64, "float64", "float64") \
+  _(TK_FLOAT, "float", "float") \
+  _(TK_DOUBLE, "double", "double") \
   _(TK_CAST, "cast", "") \
   _(TK_IN, "in", "in") \
   _(TK_GE, "ge", 
">=") \ @@ -271,15 +275,18 @@ struct SharedParserData { } bool isScalarType(int kind) { switch (kind) { - case TK_INT8: - case TK_INT16: - case TK_INT32: - case TK_INT64: + case TK_BOOL: case TK_UINT8: case TK_UINT16: case TK_UINT32: case TK_UINT64: - case TK_BOOL: + case TK_INT8: + case TK_INT16: + case TK_INT32: + case TK_INT64: + case TK_FLOAT16: + case TK_FLOAT32: + case TK_FLOAT64: case TK_FLOAT: case TK_DOUBLE: return true; diff --git a/tc/lang/sema.h b/tc/lang/sema.h index 03c7a9385..563e7dff9 100644 --- a/tc/lang/sema.h +++ b/tc/lang/sema.h @@ -32,22 +32,26 @@ struct TypeInfo { TypeInfo(Code code_, uint8_t bits_) : code_(code_), bits_(bits_) {} TypeInfo(TreeRef scalar_type) { switch (scalar_type->kind()) { -#define TYPE_INFO_OPTION(tok, c, b) \ - case tok: \ - code_ = c; \ - bits_ = b; \ +#define TYPE_INFO_OPTION(tok, c, b, l) \ + case tok: \ + code_ = c; \ + bits_ = b; \ break; - TYPE_INFO_OPTION(TK_BOOL, UInt, 1) - TYPE_INFO_OPTION(TK_UINT8, UInt, 8) - TYPE_INFO_OPTION(TK_UINT16, UInt, 16) - TYPE_INFO_OPTION(TK_UINT32, UInt, 32) - TYPE_INFO_OPTION(TK_UINT64, UInt, 64) - TYPE_INFO_OPTION(TK_INT8, Int, 8) - TYPE_INFO_OPTION(TK_INT16, Int, 16) - TYPE_INFO_OPTION(TK_INT32, Int, 32) - TYPE_INFO_OPTION(TK_INT64, Int, 64) - TYPE_INFO_OPTION(TK_FLOAT, Float, 32) - TYPE_INFO_OPTION(TK_DOUBLE, Float, 64) + TYPE_INFO_OPTION(TK_BOOL, UInt, 1, 1) + TYPE_INFO_OPTION(TK_UINT8, UInt, 8, 1) + TYPE_INFO_OPTION(TK_UINT16, UInt, 16, 1) + TYPE_INFO_OPTION(TK_UINT32, UInt, 32, 1) + TYPE_INFO_OPTION(TK_UINT64, UInt, 64, 1) + TYPE_INFO_OPTION(TK_INT8, Int, 8, 1) + TYPE_INFO_OPTION(TK_INT16, Int, 16, 1) + TYPE_INFO_OPTION(TK_INT32, Int, 32, 1) + TYPE_INFO_OPTION(TK_INT64, Int, 64, 1) + TYPE_INFO_OPTION(TK_FLOAT16, Float, 16, 1) + TYPE_INFO_OPTION(TK_FLOAT32, Float, 32, 1) + TYPE_INFO_OPTION(TK_FLOAT64, Float, 64, 1) + TYPE_INFO_OPTION(TK_FLOAT, Float, 32, 1) + TYPE_INFO_OPTION(TK_DOUBLE, Float, 64, 1) + #undef TYPE_INFO_OPTION default: throw ErrorReport(scalar_type) @@ -82,12 +86,15 @@ struct TypeInfo { } case Float: switch (bits()) { + case 16: + return TK_FLOAT16; case 32: return TK_FLOAT; case 64: return TK_DOUBLE; } } + throw std::runtime_error("Unknown type info?"); } Code code() const { diff --git a/test/cuda/test_compile_and_run.cc b/test/cuda/test_compile_and_run.cc index 9dc2bde61..272da3e3f 100644 --- a/test/cuda/test_compile_and_run.cc +++ b/test/cuda/test_compile_and_run.cc @@ -275,6 +275,37 @@ def cast(float(M,N) A, int32 four) -> (int32(M,N) output) { TC_CHECK_EQ(r, 0); } +TEST_F(CompilationTest, Types) { + struct TypeMatch { + std::string s; + at::ScalarType a; + }; + for (auto type : + {// TypeMatch{"bool", at::ScalarType::Bool}, // no aten version + TypeMatch{"uint8", at::ScalarType::Byte}, + // TypeMatch{"uint16", at::ScalarType::Short}, // no aten version + // TypeMatch{"uint32", at::ScalarType::Int}, // no aten version + // TypeMatch{"uint64", at::ScalarType::Long}, // no aten version + TypeMatch{"int8", at::ScalarType::Char}, + TypeMatch{"int16", at::ScalarType::Short}, + TypeMatch{"int32", at::ScalarType::Int}, + TypeMatch{"int64", at::ScalarType::Long}, + // NVRTC include transitive dependencies issue + // TypeMatch{"float16", at::ScalarType::Half}, + TypeMatch{"float32", at::ScalarType::Float}, + TypeMatch{"float64", at::ScalarType::Double}, + TypeMatch{"float", at::ScalarType::Float}, + TypeMatch{"double", at::ScalarType::Double}}) { + std::string tc = std::string("def test_type(") + std::string(type.s) + + std::string("(N) A) -> (B) { B(k) +=! 
A(i) where k in 0:1 }");
+    std::vector<at::Tensor> outputs = Check(
+        tc,
+        "test_type",
+        tc::CudaMappingOptions::makeNaiveMappingOptions(),
+        {at::CUDA(type.a).ones({100})});
+  }
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);
diff --git a/test/test_tc2halide.cc b/test/test_tc2halide.cc
index 128afacca..b0b81220f 100644
--- a/test/test_tc2halide.cc
+++ b/test/test_tc2halide.cc
@@ -197,6 +197,28 @@ def foo(float(N) A) -> (B) {
 )TC";
   EXPECT_THROW(Check(tc), ::lang::ErrorReport);
 }
+
+TEST_F(TC2Isl, Types) {
+  for (auto type : {"bool",
+                    "uint8",
+                    "uint16",
+                    "uint32",
+                    "uint64",
+                    "int8",
+                    "int16",
+                    "int32",
+                    "int64",
+                    "float16",
+                    "float32",
+                    "float64",
+                    "float",
+                    "double"}) {
+    string tc = string("def test_type(") + string(type) +
+        string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
+    Check(tc);
+  }
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);

From 189f44c29d1bd2c88db88c3eaf73dec4c4d9debd Mon Sep 17 00:00:00 2001
From: nicolasvasilache
Date: Tue, 12 Jun 2018 16:37:27 -0600
Subject: [PATCH 2/3] Add min_distance.py

This commit adds examples provided by @mdouze where the argmin over a
reduced sum is required. These examples are now functional thanks to the
previous commit, but extra work is needed to make some of the variants
perform reasonably:
1. for the fused kernel to parallelize properly across blocks we need
   grid synchronization; this may be a nice concrete use case for
   @math-fehr
2. for the 1-stage fissioned implementation we need device-wide
   synchronization, otherwise we will always be limited by running on a
   single SM
3. the 2-stage fissioned implementations can give us performance today,
   after tuning

Without tuning, the results on the larger size (1e7, 32, 16) are shown
[here](https://gist.github.com/nicolasvasilache/8a0addfb6831a831b2dca45c612f9c2d).
`mindis_16_32_10000000` is the totally fused kernel and performs very
poorly. The following 5 kernels correspond to the final use case of
interest.
---
 .jenkins/build.sh               |   4 +
 python/examples/min_distance.py | 190 ++++++++++++++++++++++++++++++++
 2 files changed, 194 insertions(+)
 create mode 100644 python/examples/min_distance.py

diff --git a/.jenkins/build.sh b/.jenkins/build.sh
index f9c816a28..ae673d6cf 100755
--- a/.jenkins/build.sh
+++ b/.jenkins/build.sh
@@ -69,6 +69,10 @@ WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREF
 python setup.py install
 ./test_python/run_test.sh
 
+for f in $(find ./python/examples -name "*.py"); do
+  python $f
+done
+
 FILTER_OUT="benchmark_MLP_model benchmark_kronecker" ./test.sh
 # 2LUT can OOM on smaller Maxwells on our CI machines
 ./build/tc/benchmarks/benchmark_MLP_model --gtest_filter=-*2LUT*
diff --git a/python/examples/min_distance.py b/python/examples/min_distance.py
new file mode 100644
index 000000000..d2e725c30
--- /dev/null
+++ b/python/examples/min_distance.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+import tensor_comprehensions as tc
+from tensor_comprehensions.tc import set_logtostderr
+from tensor_comprehensions.tc import set_debug_tc_mapper
+from tensor_comprehensions.tc import set_debug_cuda
+
+import numpy as np
+import torch
+
+#
+## Example submitted by @mdouze, originally related to uint8 type support
+#
+
+debug = False
+set_logtostderr(debug)
+set_debug_tc_mapper(debug)
+set_debug_cuda(debug)
+
+N = 1000
+M = 32
+
+codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
+codes = codes.view('uint8')
+luts = np.random.randn(M, 256).astype('float32')
+
+codes_t = torch.from_numpy(codes).cuda()
+luts_t = torch.from_numpy(luts).cuda()
+
+lang = """
+# mindis as a single kernel will require grid synchronization to run efficiently
+def mindis(float(M, 256) L, uint8(N, M) Codes) -> (S, v, min_idx) {
+    S(n) +=! L(r_m, int32(Codes(n, r_m)))
+    v min=! S(r_n)
+    min_idx min=! (S(r_n) == v) ? r_n : N
+}
+
+# Even when splitting in 3 kernels, global device reduction will be needed to
+# run efficiently
+# don't try to run it with large sizes for now
+def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
+    S(n) +=! L(r_m, int32(Codes(n, r_m)))
+}
+def min_2d(float(N) S) -> (v) {
+    v min=! S(r_n)
+}
+def argmin_2d(float(N) S, float v) -> (min_idx) {
+    min_idx min=! (S(r_n) == v) ? r_n : N
+}
+"""
+
+mindis = tc.define(lang, name="mindis")
+S, v, min_idx = mindis(luts_t, codes_t)
+print("minval: {} minidx: {}".format(v, min_idx))
+
+reduce_codes = tc.define(lang, name="reduce_codes")
+min_2d = tc.define(lang, name="min_2d")
+argmin_2d = tc.define(lang, name="argmin_2d")
+
+S = reduce_codes(luts_t, codes_t)
+v = min_2d(S)
+min_idx = argmin_2d(S, v)
+
+print("minval: {} minidx: {}".format(v, min_idx))
+
+################################################################################
+# Each reduction is probably easier to optimize with a 2-staged TC where we
+# artificially increase parallelism and finish the reduction in a second kernel.
+# Properly choosing D such that N = D * (N / D) should result in a good version
+# with 5 kernels total.
+################################################################################
+N = 10 ** 5  # bump to 10**7 when ready for primetime
+D = 1000
+assert N % D == 0, "D={} must divide N={}".format(D, N)
+M = 32
+
+lang = """
+def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
+    S(n) +=! L(r_m, int32(Codes(n, r_m)))
+}
+def min_2d(float(D, NBYD) S) -> (V) {
+    V(d) min=! S(d, r_nbyd)
+}
+def min_1d(float(D) V) -> (v) {
+    v min=! V(r_d)
+}
+def argmin_2d(float(D, NBYD) S, float v) -> (MinIdx) {
+    MinIdx(d) min=! (S(d, r_nbyd) == v) ? d * NBYD + r_nbyd : N
+}
+def argmin_1d(float(N) S, int32(D) MinIdx) -> (min_idx) {
+    min_idx min=! (MinIdx(r_d) < N) ? 
r_d : N
+}
+"""
+
+codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
+codes = codes.view('uint8')
+luts = np.random.randn(M, 256).astype('float32')
+
+codes_t = torch.from_numpy(codes).cuda()
+luts_t = torch.from_numpy(luts).cuda()
+
+reduce_codes = tc.define(lang, name="reduce_codes")
+min_2d = tc.define(lang, name="min_2d")
+min_1d = tc.define(lang, name="min_1d")
+argmin_2d = tc.define(lang, name="argmin_2d")
+argmin_1d = tc.define(lang, name="argmin_1d")
+
+S = reduce_codes(luts_t, codes_t)
+V = min_2d(S.view((D, N // D)))
+v = min_1d(V)
+MinIdx = argmin_2d(S.view((D, N // D)), v)
+min_idx = argmin_1d(S, MinIdx)
+print("minval: {} minidx: {}".format(v, min_idx))
+
+################################################################################
+# The longer form version has an extra k dimension we could use for parallelism.
+# Unfortunately it is a small dimension (16) so it won't saturate Pascal/Volta.
+# So we may want to split in 5 to run efficiently.
+################################################################################
+N = 10 ** 7
+D = 1000
+assert N % D == 0, "D={} must divide N={}".format(D, N)
+M = 32
+K = 16
+codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
+codes = codes.view('uint8')
+luts = np.random.randn(K, M, 256).astype('float32')
+
+codes_t = torch.from_numpy(codes).cuda()
+luts_t = torch.from_numpy(luts).cuda()
+
+lang = """
+def mindis(float(K, M, 256) L, uint8(N, M) Codes) -> (S, V, MinIdx) {
+    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
+    V(k) min=! S(k, r_n)
+    MinIdx(k) min=! (S(k, r_n) == V(k)) ? r_n : N
+}
+"""
+
+debug = False
+set_logtostderr(debug)
+set_debug_tc_mapper(debug)
+set_debug_cuda(debug)
+
+mindis = tc.define(lang, name="mindis")
+S, V, MinIdx = mindis(luts_t, codes_t)
+print("minvals: {}\nminidxs: {}".format(V, MinIdx))
+
+lang = """
+def reduce_codes(float(K, M, 256) L, uint8(N, M) Codes) -> (S) {
+    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
+}
+def min_2d(float(K, D, NBYD) S) -> (V2) {
+    V2(k, d) min=! S(k, d, r_nbyd)
+}
+def min_1d(float(K, D) V2) -> (V) {
+    V(k) min=! V2(k, r_d)
+}
+def argmin_2d(float(K, D, NBYD) S, float(K) V) -> (MinIdx2) {
+    MinIdx2(k, d) min=! (S(k, d, r_nbyd) == V(k)) ? d * NBYD + r_nbyd : N
+}
+def argmin_1d(float(K, N) S, int32(K, D) MinIdx2) -> (MinIdx) {
+    MinIdx(k) min=! (MinIdx2(k, r_d) < N) ? r_d : N
+}
+"""
+
+reduce_codes = tc.define(lang, name="reduce_codes")
+min_2d = tc.define(lang, name="min_2d")
+min_1d = tc.define(lang, name="min_1d")
+argmin_2d = tc.define(lang, name="argmin_2d")
+argmin_1d = tc.define(lang, name="argmin_1d")
+
+S = reduce_codes(luts_t, codes_t)
+V2 = min_2d(S.view((K, D, N // D)))
+V = min_1d(V2)
+MinIdx2 = argmin_2d(S.view((K, D, N // D)), V)
+MinIdx = argmin_1d(S, MinIdx2)
+print("minvals: {}\nminidxs: {}".format(V, MinIdx))

From a70e90059fd2ce3df7cd056a7c5c0b082f9ddd98 Mon Sep 17 00:00:00 2001
From: nicolasvasilache
Date: Wed, 13 Jun 2018 11:48:30 -0600
Subject: [PATCH 3/3] [WIP][DO NOT MERGE] Experiment with vector types

This commit adds experimental vector type support to the lexer, semantic
analyzer, and CUDA implementation. It also adds language and end-to-end
functional tests for all types.

This is limited by the fact that ATen doesn't allow such types. Therefore
this commit adds some striding genuflexions to bypass ATen issues.
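
For reference, the shape of TC this enables is what the new end-to-end
tests exercise, e.g. for a 4-lane float input (a sketch, mirroring the
test added to test_compile_and_run.cc below):

    def test_type(float32x4(N) A) -> (B) {
        B(k) +=! A(i) where k in 0:1
    }

Since ATen has no multi-lane scalar types, such a float32x4(N) tensor is
exposed to ATen as a plain float tensor over a storage of N * 4 scalars,
viewed with sizes {N} and strides {4} via Tensor::set_ (see the changes
to aten_compiler.cc and to the tests).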
---
 tc/aten/aten_compiler.cc          |  19 ++-
 tc/core/libraries.h               |  70 ++++++++++
 tc/core/tc2halide.cc              |  74 ++++++++--
 tc/lang/lexer.h                   | 171 +++++++++++++++--------
 tc/lang/sema.h                    | 217 +++++++++++++++++++++++-------
 test/cuda/test_compile_and_run.cc |  65 ++++++---
 test/test_tc2halide.cc            |  60 +++++++--
 7 files changed, 535 insertions(+), 141 deletions(-)

diff --git a/tc/aten/aten_compiler.cc b/tc/aten/aten_compiler.cc
index b34fbd316..7b9f7b699 100644
--- a/tc/aten/aten_compiler.cc
+++ b/tc/aten/aten_compiler.cc
@@ -59,9 +59,22 @@ std::vector<at::Tensor> prepareOutputs(
   auto atenBackend = inputs[0].type().backend();
   for (size_t i = 0; i < outTensorInfo.size(); ++i) {
     TensorInfo info(outTensorInfo[i]);
-    auto stype = at::toScalarType(info.dtype);
-    outputs.push_back(
-        at::getType(atenBackend, stype).tensor(at::IntList(info.shape)));
+    if (info.dtype.lanes == 1) {
+      auto stype = at::toScalarType(info.dtype);
+      outputs.push_back(
+          at::getType(atenBackend, stype).tensor(at::IntList(info.shape)));
+    } else {
+      // "Cast" to a larger strided tensor with 1 lane
+      auto lanes = info.dtype.lanes;
+      TC_CHECK(lanes == 2 || lanes == 4);
+      info.dtype.lanes = 1;
+      info.shape[info.shape.size() - 1] *= lanes;
+      auto stype = at::toScalarType(info.dtype);
+      auto T = at::getType(atenBackend, stype).tensor(at::IntList(info.shape));
+
+      info.shape[info.shape.size() - 1] /= lanes;
+      outputs.push_back(T.set_(*T.storage(), 0, info.shape));
+    }
   }
   return outputs;
 }
diff --git a/tc/core/libraries.h b/tc/core/libraries.h
index c2fdb1de8..196823b6c 100644
--- a/tc/core/libraries.h
+++ b/tc/core/libraries.h
@@ -47,6 +47,76 @@ typedef unsigned long uint64;
 // typedef half float16;
 typedef float float32;
 typedef double float64;
+
+template <typename T>
+struct array2 {
+  T x, y;
+  array2(T t) : x(t), y(t) {}
+  array2(T x, T y) : x(x), y(y) {}
+  array2 operator+(const array2& a) const {
+    return array2{
+      static_cast<T>(x + a.x),
+      static_cast<T>(y + a.y)
+    };
+  }
+  array2& operator+=(const array2& a) {
+    x += a.x;
+    y += a.y;
+    return *this;
+  }
+};
+template <typename Type>
+array2<Type> x2(Type t) {
+  return array2<Type>(t);
+}
+typedef array2<int8> int8x2;
+typedef array2<int16> int16x2;
+typedef array2<int32> int32x2;
+typedef array2<int64> int64x2;
+typedef array2<uint8> uint8x2;
+typedef array2<uint16> uint16x2;
+typedef array2<uint32> uint32x2;
+typedef array2<uint64> uint64x2;
+// typedef array2<float16> float16x2;
+typedef array2<float32> float32x2;
+typedef array2<float64> float64x2;
+
+template <typename T>
+struct array4 {
+  T x, y, z, w;
+  array4(T t) : x(t), y(t), z(t), w(t) {}
+  array4(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
+  array4 operator+(const array4& a) const {
+    return array4{
+      static_cast<T>(x + a.x),
+      static_cast<T>(y + a.y),
+      static_cast<T>(z + a.z),
+      static_cast<T>(w + a.w)
+    };
+  }
+  array4& operator+=(const array4& a) {
+    x += a.x;
+    y += a.y;
+    z += a.z;
+    w += a.w;
+    return *this;
+  }
+};
+template <typename Type>
+array4<Type> x4(Type t) {
+  return array4<Type>(t);
+}
+typedef array4<int8> int8x4;
+typedef array4<int16> int16x4;
+typedef array4<int32> int32x4;
+typedef array4<int64> int64x4;
+typedef array4<uint8> uint8x4;
+typedef array4<uint16> uint16x4;
+typedef array4<uint32> uint32x4;
+typedef array4<uint64> uint64x4;
+// typedef array4<float16> float16x4;
+typedef array4<float32> float32x4;
+typedef array4<float64> float64x4;
 )C";
 
 constexpr auto defines = R"C(
diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc
index 379add821..82df69d4a 100644
--- a/tc/core/tc2halide.cc
+++ b/tc/core/tc2halide.cc
@@ -37,14 +37,6 @@ Type translateScalarType(int tcType) {
   switch (tcType) {
     case lang::TK_BOOL:
       return UInt(1);
-    case lang::TK_UINT8:
-      return UInt(8);
-    case lang::TK_UINT16:
-      return UInt(16);
-    case lang::TK_UINT32:
-      return UInt(32);
-    case 
lang::TK_UINT64: - return UInt(64); case lang::TK_INT8: return Int(8); case lang::TK_INT16: @@ -53,6 +45,14 @@ Type translateScalarType(int tcType) { return Int(32); case lang::TK_INT64: return Int(64); + case lang::TK_UINT8: + return UInt(8); + case lang::TK_UINT16: + return UInt(16); + case lang::TK_UINT32: + return UInt(32); + case lang::TK_UINT64: + return UInt(64); case lang::TK_FLOAT16: return Float(16); case lang::TK_FLOAT32: @@ -64,6 +64,64 @@ Type translateScalarType(int tcType) { case lang::TK_DOUBLE: return Float(64); + case lang::TK_VECTOR2_BOOL: + return UInt(1, 2); + case lang::TK_VECTOR2_INT8: + return Int(8, 2); + case lang::TK_VECTOR2_INT16: + return Int(16, 2); + case lang::TK_VECTOR2_INT32: + return Int(32, 2); + case lang::TK_VECTOR2_INT64: + return Int(64, 2); + case lang::TK_VECTOR2_UINT8: + return UInt(8, 2); + case lang::TK_VECTOR2_UINT16: + return UInt(16, 2); + case lang::TK_VECTOR2_UINT32: + return UInt(32, 2); + case lang::TK_VECTOR2_UINT64: + return UInt(64, 2); + case lang::TK_VECTOR2_FLOAT16: + return Float(16, 2); + case lang::TK_VECTOR2_FLOAT32: + return Float(32, 2); + case lang::TK_VECTOR2_FLOAT64: + return Float(64, 2); + case lang::TK_VECTOR2_FLOAT: + return Float(32, 2); + case lang::TK_VECTOR2_DOUBLE: + return Float(64, 2); + + case lang::TK_VECTOR4_BOOL: + return UInt(1, 4); + case lang::TK_VECTOR4_INT8: + return Int(8, 4); + case lang::TK_VECTOR4_INT16: + return Int(16, 4); + case lang::TK_VECTOR4_INT32: + return Int(32, 4); + case lang::TK_VECTOR4_INT64: + return Int(64, 4); + case lang::TK_VECTOR4_UINT8: + return UInt(8, 4); + case lang::TK_VECTOR4_UINT16: + return UInt(16, 4); + case lang::TK_VECTOR4_UINT32: + return UInt(32, 4); + case lang::TK_VECTOR4_UINT64: + return UInt(64, 4); + case lang::TK_VECTOR4_FLOAT16: + return Float(16, 4); + case lang::TK_VECTOR4_FLOAT32: + return Float(32, 4); + case lang::TK_VECTOR4_FLOAT64: + return Float(64, 4); + case lang::TK_VECTOR4_FLOAT: + return Float(32, 4); + case lang::TK_VECTOR4_DOUBLE: + return Float(64, 4); + default: LOG(FATAL) << "Unhandled TC scalar type: " << tcType << '\n'; return Type(); diff --git a/tc/lang/lexer.h b/tc/lang/lexer.h index 2db8bdc79..1fe35e439 100644 --- a/tc/lang/lexer.h +++ b/tc/lang/lexer.h @@ -34,61 +34,92 @@ namespace lang { // Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the // lexer. 
-#define TC_FORALL_TOKEN_KINDS(_) \ - _(TK_EOF, "eof", "") \ - _(TK_NUMBER, "number", "") \ - _(TK_BOOL_VALUE, "bool_value", "") \ - _(TK_MIN, "min", "min") \ - _(TK_MAX, "max", "max") \ - _(TK_WHERE, "where", "where") \ - _(TK_DEF, "def", "def") \ - _(TK_ARROW, "arrow", "->") \ - _(TK_EQUIVALENT, "equivalent", "<=>") \ - _(TK_IDENT, "ident", "") \ - _(TK_STRING, "string", "") \ - _(TK_CONST, "const", "") \ - _(TK_LIST, "list", "") \ - _(TK_OPTION, "option", "") \ - _(TK_APPLY, "apply", "") \ - _(TK_COMPREHENSION, "comprehension", "") \ - _(TK_TENSOR_TYPE, "tensor_type", "") \ - _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ - _(TK_PARAM, "param", "") \ - _(TK_INFERRED, "inferred", "") \ - _(TK_ACCESS, "access", "") \ - _(TK_BUILT_IN, "built-in", "") \ - _(TK_PLUS_EQ, "plus_eq", "+=") \ - _(TK_TIMES_EQ, "times_eq", "*=") \ - _(TK_MIN_EQ, "min_eq", "min=") \ - _(TK_MAX_EQ, "max_eq", "max=") \ - _(TK_PLUS_EQ_B, "plus_eq_b", "+=!") \ - _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \ - _(TK_MIN_EQ_B, "min_eq_b", "min=!") \ - _(TK_MAX_EQ_B, "max_eq_b", "max=!") \ - \ - _(TK_BOOL, "bool", "bool") \ - _(TK_UINT8, "uint8", "uint8") \ - _(TK_UINT16, "uint16", "uint16") \ - _(TK_UINT32, "uint32", "uint32") \ - _(TK_UINT64, "uint64", "uint64") \ - _(TK_INT8, "int8", "int8") \ - _(TK_INT16, "int16", "int16") \ - _(TK_INT32, "int32", "int32") \ - _(TK_INT64, "int64", "int64") \ - _(TK_FLOAT16, "float16", "float16") \ - _(TK_FLOAT32, "float32", "float32") \ - _(TK_FLOAT64, "float64", "float64") \ - _(TK_FLOAT, "float", "float") \ - _(TK_DOUBLE, "double", "double") \ - _(TK_CAST, "cast", "") \ - _(TK_IN, "in", "in") \ - _(TK_GE, "ge", ">=") \ - _(TK_LE, "le", "<=") \ - _(TK_EQ, "eq", "==") \ - _(TK_NE, "neq", "!=") \ - _(TK_AND, "and", "&&") \ - _(TK_OR, "or", "||") \ - _(TK_LET, "let", "") \ +#define TC_FORALL_TOKEN_KINDS(_) \ + _(TK_EOF, "eof", "") \ + _(TK_NUMBER, "number", "") \ + _(TK_BOOL_VALUE, "bool_value", "") \ + _(TK_MIN, "min", "min") \ + _(TK_MAX, "max", "max") \ + _(TK_WHERE, "where", "where") \ + _(TK_DEF, "def", "def") \ + _(TK_ARROW, "arrow", "->") \ + _(TK_EQUIVALENT, "equivalent", "<=>") \ + _(TK_IDENT, "ident", "") \ + _(TK_STRING, "string", "") \ + _(TK_CONST, "const", "") \ + _(TK_LIST, "list", "") \ + _(TK_OPTION, "option", "") \ + _(TK_APPLY, "apply", "") \ + _(TK_COMPREHENSION, "comprehension", "") \ + _(TK_TENSOR_TYPE, "tensor_type", "") \ + _(TK_RANGE_CONSTRAINT, "range_constraint", "") \ + _(TK_PARAM, "param", "") \ + _(TK_INFERRED, "inferred", "") \ + _(TK_ACCESS, "access", "") \ + _(TK_BUILT_IN, "built-in", "") \ + _(TK_PLUS_EQ, "plus_eq", "+=") \ + _(TK_TIMES_EQ, "times_eq", "*=") \ + _(TK_MIN_EQ, "min_eq", "min=") \ + _(TK_MAX_EQ, "max_eq", "max=") \ + _(TK_PLUS_EQ_B, "plus_eq_b", "+=!") \ + _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \ + _(TK_MIN_EQ_B, "min_eq_b", "min=!") \ + _(TK_MAX_EQ_B, "max_eq_b", "max=!") \ + \ + _(TK_BOOL, "bool", "bool") \ + _(TK_UINT8, "uint8", "uint8") \ + _(TK_UINT16, "uint16", "uint16") \ + _(TK_UINT32, "uint32", "uint32") \ + _(TK_UINT64, "uint64", "uint64") \ + _(TK_INT8, "int8", "int8") \ + _(TK_INT16, "int16", "int16") \ + _(TK_INT32, "int32", "int32") \ + _(TK_INT64, "int64", "int64") \ + _(TK_FLOAT16, "float16", "float16") \ + _(TK_FLOAT32, "float32", "float32") \ + _(TK_FLOAT64, "float64", "float64") \ + _(TK_FLOAT, "float", "float") \ + _(TK_DOUBLE, "double", "double") \ + \ + _(TK_VECTOR2_BOOL, "boolx2", "boolx2") \ + _(TK_VECTOR2_UINT8, "uint8x2", "uint8x2") \ + _(TK_VECTOR2_UINT16, "uint16x2", "uint16x2") \ + _(TK_VECTOR2_UINT32, "uint32x2", 
"uint32x2") \ + _(TK_VECTOR2_UINT64, "uint64x2", "uint64x2") \ + _(TK_VECTOR2_INT8, "int8x2", "int8x2") \ + _(TK_VECTOR2_INT16, "int16x2", "int16x2") \ + _(TK_VECTOR2_INT32, "int32x2", "int32x2") \ + _(TK_VECTOR2_INT64, "int64x2", "int64x2") \ + _(TK_VECTOR2_FLOAT16, "float16x2", "float16x2") \ + _(TK_VECTOR2_FLOAT32, "float32x2", "float32x2") \ + _(TK_VECTOR2_FLOAT64, "float64x2", "float64x2") \ + _(TK_VECTOR2_FLOAT, "floatx2", "floatx2") \ + _(TK_VECTOR2_DOUBLE, "doublex2", "doublex2") \ + \ + _(TK_VECTOR4_BOOL, "boolx4", "boolx4") \ + _(TK_VECTOR4_UINT8, "uint8x4", "uint8x4") \ + _(TK_VECTOR4_UINT16, "uint16x4", "uint16x4") \ + _(TK_VECTOR4_UINT32, "uint32x4", "uint32x4") \ + _(TK_VECTOR4_UINT64, "uint64x4", "uint64x4") \ + _(TK_VECTOR4_INT8, "int8x4", "int8x4") \ + _(TK_VECTOR4_INT16, "int16x4", "int16x4") \ + _(TK_VECTOR4_INT32, "int32x4", "int32x4") \ + _(TK_VECTOR4_INT64, "int64x4", "int64x4") \ + _(TK_VECTOR4_FLOAT16, "float16x4", "float16x4") \ + _(TK_VECTOR4_FLOAT32, "float32x4", "float32x4") \ + _(TK_VECTOR4_FLOAT64, "float64x4", "float64x4") \ + _(TK_VECTOR4_FLOAT, "floatx4", "floatx4") \ + _(TK_VECTOR4_DOUBLE, "doublex4", "doublex4") \ + \ + _(TK_CAST, "cast", "") \ + _(TK_IN, "in", "in") \ + _(TK_GE, "ge", ">=") \ + _(TK_LE, "le", "<=") \ + _(TK_EQ, "eq", "==") \ + _(TK_NE, "neq", "!=") \ + _(TK_AND, "and", "&&") \ + _(TK_OR, "or", "||") \ + _(TK_LET, "let", "") \ _(TK_EXISTS, "exists", "exists") static const char* valid_single_char_tokens = "+-*/()[]?:,={}>kind()) { #define TYPE_INFO_OPTION(tok, c, b, l) \ case tok: \ code_ = c; \ bits_ = b; \ + lanes_ = l; \ break; TYPE_INFO_OPTION(TK_BOOL, UInt, 1, 1) TYPE_INFO_OPTION(TK_UINT8, UInt, 8, 1) @@ -52,6 +54,36 @@ struct TypeInfo { TYPE_INFO_OPTION(TK_FLOAT, Float, 32, 1) TYPE_INFO_OPTION(TK_DOUBLE, Float, 64, 1) + TYPE_INFO_OPTION(TK_VECTOR2_BOOL, UInt, 1, 2) + TYPE_INFO_OPTION(TK_VECTOR2_UINT8, UInt, 8, 2) + TYPE_INFO_OPTION(TK_VECTOR2_UINT16, UInt, 16, 2) + TYPE_INFO_OPTION(TK_VECTOR2_UINT32, UInt, 32, 2) + TYPE_INFO_OPTION(TK_VECTOR2_UINT64, UInt, 64, 2) + TYPE_INFO_OPTION(TK_VECTOR2_INT8, Int, 8, 2) + TYPE_INFO_OPTION(TK_VECTOR2_INT16, Int, 16, 2) + TYPE_INFO_OPTION(TK_VECTOR2_INT32, Int, 32, 2) + TYPE_INFO_OPTION(TK_VECTOR2_INT64, Int, 64, 2) + TYPE_INFO_OPTION(TK_VECTOR2_FLOAT16, Float, 16, 2) + TYPE_INFO_OPTION(TK_VECTOR2_FLOAT32, Float, 32, 2) + TYPE_INFO_OPTION(TK_VECTOR2_FLOAT64, Float, 64, 2) + TYPE_INFO_OPTION(TK_VECTOR2_FLOAT, Float, 32, 2) + TYPE_INFO_OPTION(TK_VECTOR2_DOUBLE, Float, 64, 2) + + TYPE_INFO_OPTION(TK_VECTOR4_BOOL, UInt, 1, 4) + TYPE_INFO_OPTION(TK_VECTOR4_UINT8, UInt, 8, 4) + TYPE_INFO_OPTION(TK_VECTOR4_UINT16, UInt, 16, 4) + TYPE_INFO_OPTION(TK_VECTOR4_UINT32, UInt, 32, 4) + TYPE_INFO_OPTION(TK_VECTOR4_UINT64, UInt, 64, 4) + TYPE_INFO_OPTION(TK_VECTOR4_INT8, Int, 8, 4) + TYPE_INFO_OPTION(TK_VECTOR4_INT16, Int, 16, 4) + TYPE_INFO_OPTION(TK_VECTOR4_INT32, Int, 32, 4) + TYPE_INFO_OPTION(TK_VECTOR4_INT64, Int, 64, 4) + TYPE_INFO_OPTION(TK_VECTOR4_FLOAT16, Float, 16, 4) + TYPE_INFO_OPTION(TK_VECTOR4_FLOAT32, Float, 32, 4) + TYPE_INFO_OPTION(TK_VECTOR4_FLOAT64, Float, 64, 4) + TYPE_INFO_OPTION(TK_VECTOR4_FLOAT, Float, 32, 4) + TYPE_INFO_OPTION(TK_VECTOR4_DOUBLE, Float, 64, 4) + #undef TYPE_INFO_OPTION default: throw ErrorReport(scalar_type) @@ -59,40 +91,114 @@ struct TypeInfo { } } int toScalarToken() const { - switch (code()) { - case UInt: - switch (bits()) { - case 1: - return TK_BOOL; - case 8: - return TK_UINT8; - case 16: - return TK_UINT16; - case 32: - return TK_UINT32; - case 64: - return 
TK_UINT64; - } - case Int: - switch (bits()) { - case 8: - return TK_INT8; - case 16: - return TK_INT16; - case 32: - return TK_INT32; - case 64: - return TK_INT64; - } - case Float: - switch (bits()) { - case 16: - return TK_FLOAT16; - case 32: - return TK_FLOAT; - case 64: - return TK_DOUBLE; - } + if (lanes() == 1) { + switch (code()) { + case UInt: + switch (bits()) { + case 1: + return TK_BOOL; + case 8: + return TK_UINT8; + case 16: + return TK_UINT16; + case 32: + return TK_UINT32; + case 64: + return TK_UINT64; + } + case Int: + switch (bits()) { + case 8: + return TK_INT8; + case 16: + return TK_INT16; + case 32: + return TK_INT32; + case 64: + return TK_INT64; + } + case Float: + switch (bits()) { + case 16: + return TK_FLOAT16; + case 32: + return TK_FLOAT; + case 64: + return TK_DOUBLE; + } + } + } else if (lanes() == 2) { + switch (code()) { + case UInt: + switch (bits()) { + case 1: + return TK_VECTOR2_BOOL; + case 8: + return TK_VECTOR2_UINT8; + case 16: + return TK_VECTOR2_UINT16; + case 32: + return TK_VECTOR2_UINT32; + case 64: + return TK_VECTOR2_UINT64; + } + case Int: + switch (bits()) { + case 8: + return TK_VECTOR2_INT8; + case 16: + return TK_VECTOR2_INT16; + case 32: + return TK_VECTOR2_INT32; + case 64: + return TK_VECTOR2_INT64; + } + case Float: + switch (bits()) { + case 16: + return TK_VECTOR2_FLOAT16; + case 32: + return TK_VECTOR2_FLOAT; + case 64: + return TK_VECTOR2_DOUBLE; + } + } + } else if (lanes() == 4) { + switch (code()) { + case UInt: + switch (bits()) { + case 1: + return TK_VECTOR4_BOOL; + case 8: + return TK_VECTOR4_UINT8; + case 16: + return TK_VECTOR4_UINT16; + case 32: + return TK_VECTOR4_UINT32; + case 64: + return TK_VECTOR4_UINT64; + } + case Int: + switch (bits()) { + case 8: + return TK_VECTOR4_INT8; + case 16: + return TK_VECTOR4_INT16; + case 32: + return TK_VECTOR4_INT32; + case 64: + return TK_VECTOR4_INT64; + } + case Float: + switch (bits()) { + case 16: + return TK_VECTOR4_FLOAT16; + case 32: + return TK_VECTOR4_FLOAT; + case 64: + return TK_VECTOR4_DOUBLE; + } + } } throw std::runtime_error("Unknown type info?"); @@ -103,6 +209,9 @@ struct TypeInfo { uint8_t bits() const { return bits_; } + uint8_t lanes() const { + return lanes_; + } bool is_float() const { return code_ == Float; } @@ -113,10 +222,11 @@ struct TypeInfo { private: Code code_; uint8_t bits_; + uint8_t lanes_; }; static inline bool operator==(TypeInfo a, TypeInfo b) { - return a.bits() == b.bits() && a.code() == b.code(); + return a.bits() == b.bits() && a.code() == b.code() && a.lanes() == b.lanes(); } static inline TreeRef match_types(TreeRef a, TreeRef b) { @@ -125,34 +235,49 @@ static inline TreeRef match_types(TreeRef a, TreeRef b) { if (ta == tb) return a; - if (!ta.is_float() && tb.is_float()) { +#define NO_MATCH() \ + throw ErrorReport(b) << "Could not match types: " \ + << kindToString(ta.toScalarToken()) << ", " \ + << kindToString(tb.toScalarToken()); + + if (!ta.is_float() && ta.lanes() == 1 && tb.is_float()) { // int(a) * float(b) -> float(b) // uint(a) * float(b) -> float(b) return b; - } else if (ta.is_float() && !tb.is_float()) { + } else if (ta.is_float() && !tb.is_float() && tb.lanes() == 1) { return a; } else if (ta.is_float() && tb.is_float()) { // float(a) * float(b) -> float(max(a, b)) - if (ta.bits() > tb.bits()) + if (ta.bits() > tb.bits() && tb.lanes() == 1) return a; - else + else if (ta.lanes() == 1) return b; - } else if (ta.is_uint() && tb.is_uint()) { + else { + NO_MATCH(); + } + } else if (ta.is_uint() && tb.is_uint() && tb.lanes() == 
1) { // uint(a) * uint(b) -> uint(max(a, b)) - if (ta.bits() > tb.bits()) + if (ta.bits() > tb.bits() && tb.lanes() == 1) return a; - else + else if (ta.lanes() == 1) return b; + else { + NO_MATCH(); + } } else if (!ta.is_float() && !tb.is_float()) { // int(a) * (u)int(b) -> int(max(a, b)) int bits = std::max(ta.bits(), tb.bits()); + if ((bits == ta.bits() && tb.lanes() != 1) || + (bits == tb.bits() && ta.lanes() != 1)) { + NO_MATCH(); + } return Compound::create( TypeInfo(TypeInfo::Int, bits).toScalarToken(), a->range(), {}); } else { - throw ErrorReport(b) << "Could not match types: " - << kindToString(ta.toScalarToken()) << ", " - << kindToString(tb.toScalarToken()); + NO_MATCH(); } + +#undef NO_MATCH } /// Semantic analysis transforms the raw AST into a diff --git a/test/cuda/test_compile_and_run.cc b/test/cuda/test_compile_and_run.cc index 272da3e3f..1113c0b7f 100644 --- a/test/cuda/test_compile_and_run.cc +++ b/test/cuda/test_compile_and_run.cc @@ -279,30 +279,65 @@ TEST_F(CompilationTest, Types) { struct TypeMatch { std::string s; at::ScalarType a; + uint8_t lanes; }; for (auto type : - {// TypeMatch{"bool", at::ScalarType::Bool}, // no aten version - TypeMatch{"uint8", at::ScalarType::Byte}, - // TypeMatch{"uint16", at::ScalarType::Short}, // no aten version - // TypeMatch{"uint32", at::ScalarType::Int}, // no aten version - // TypeMatch{"uint64", at::ScalarType::Long}, // no aten version - TypeMatch{"int8", at::ScalarType::Char}, - TypeMatch{"int16", at::ScalarType::Short}, - TypeMatch{"int32", at::ScalarType::Int}, - TypeMatch{"int64", at::ScalarType::Long}, + {// TypeMatch{"bool", at::ScalarType::Bool, 1}, // no aten version + TypeMatch{"uint8", at::ScalarType::Byte, 1}, + // TypeMatch{"uint16", at::ScalarType::Short, 1}, // no aten version + // TypeMatch{"uint32", at::ScalarType::Int, 1}, // no aten version + // TypeMatch{"uint64", at::ScalarType::Long, 1}, // no aten version + TypeMatch{"int8", at::ScalarType::Char, 1}, + TypeMatch{"int16", at::ScalarType::Short, 1}, + TypeMatch{"int32", at::ScalarType::Int, 1}, + TypeMatch{"int64", at::ScalarType::Long, 1}, // NVRTC include transitive dependencies issue - // TypeMatch{"float16", at::ScalarType::Half}, - TypeMatch{"float32", at::ScalarType::Float}, - TypeMatch{"float64", at::ScalarType::Double}, - TypeMatch{"float", at::ScalarType::Float}, - TypeMatch{"double", at::ScalarType::Double}}) { + // TypeMatch{"float16", at::ScalarType::Half, 1}, + TypeMatch{"float32", at::ScalarType::Float, 1}, + TypeMatch{"float64", at::ScalarType::Double, 1}, + TypeMatch{"float", at::ScalarType::Float, 1}, + TypeMatch{"double", at::ScalarType::Double, 1}, + + // TypeMatch{"boolx2", at::ScalarType::Bool, 2}, // no aten version + TypeMatch{"uint8x2", at::ScalarType::Byte, 2}, + // TypeMatch{"uint16x2", at::ScalarType::Short, 2}, // no aten version + // TypeMatch{"uint32x2", at::ScalarType::Int, 2}, // no aten version + // TypeMatch{"uint64x2", at::ScalarType::Long, 2}, // no aten version + TypeMatch{"int8x2", at::ScalarType::Char, 2}, + TypeMatch{"int16x2", at::ScalarType::Short, 2}, + TypeMatch{"int32x2", at::ScalarType::Int, 2}, + TypeMatch{"int64x2", at::ScalarType::Long, 2}, + // NVRTC include transitive dependencies issue + // TypeMatch{"float16x2", at::ScalarType::Half, 2}, + TypeMatch{"float32x2", at::ScalarType::Float, 2}, + TypeMatch{"float64x2", at::ScalarType::Double, 2}, + TypeMatch{"floatx2", at::ScalarType::Float, 2}, + TypeMatch{"doublex2", at::ScalarType::Double, 2}, + + // TypeMatch{"boolx4", at::ScalarType::Bool, 4}, // no aten 
version
+        TypeMatch{"uint8x4", at::ScalarType::Byte, 4},
+        // TypeMatch{"uint16x4", at::ScalarType::Short, 4}, // no aten version
+        // TypeMatch{"uint32x4", at::ScalarType::Int, 4}, // no aten version
+        // TypeMatch{"uint64x4", at::ScalarType::Long, 4}, // no aten version
+        TypeMatch{"int8x4", at::ScalarType::Char, 4},
+        TypeMatch{"int16x4", at::ScalarType::Short, 4},
+        TypeMatch{"int32x4", at::ScalarType::Int, 4},
+        TypeMatch{"int64x4", at::ScalarType::Long, 4},
+        // NVRTC include transitive dependencies issue
+        // TypeMatch{"float16x4", at::ScalarType::Half, 4},
+        TypeMatch{"float32x4", at::ScalarType::Float, 4},
+        TypeMatch{"float64x4", at::ScalarType::Double, 4},
+        TypeMatch{"floatx4", at::ScalarType::Float, 4},
+        TypeMatch{"doublex4", at::ScalarType::Double, 4}}) {
     std::string tc = std::string("def test_type(") + std::string(type.s) +
         std::string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
+
+    auto T = at::CUDA(type.a).ones({100 * type.lanes});
     std::vector<at::Tensor> outputs = Check(
         tc,
         "test_type",
         tc::CudaMappingOptions::makeNaiveMappingOptions(),
-        {at::CUDA(type.a).ones({100})});
+        {T.set_(*T.storage(), {100}, {type.lanes})});
   }
 }
diff --git a/test/test_tc2halide.cc b/test/test_tc2halide.cc
index b0b81220f..a36b2a85b 100644
--- a/test/test_tc2halide.cc
+++ b/test/test_tc2halide.cc
@@ -199,20 +199,52 @@ def foo(float(N) A) -> (B) {
 }
 
 TEST_F(TC2Isl, Types) {
-  for (auto type : {"bool",
-                    "uint8",
-                    "uint16",
-                    "uint32",
-                    "uint64",
-                    "int8",
-                    "int16",
-                    "int32",
-                    "int64",
-                    "float16",
-                    "float32",
-                    "float64",
-                    "float",
-                    "double"}) {
+  for (auto type : {
+           "bool",
+           "uint8",
+           "uint16",
+           "uint32",
+           "uint64",
+           "int8",
+           "int16",
+           "int32",
+           "int64",
+           "float16",
+           "float32",
+           "float64",
+           "float",
+           "double",
+           //
+           "boolx2",
+           "uint8x2",
+           "uint16x2",
+           "uint32x2",
+           "uint64x2",
+           "int8x2",
+           "int16x2",
+           "int32x2",
+           "int64x2",
+           "float16x2",
+           "float32x2",
+           "float64x2",
+           "floatx2",
+           "doublex2",
+           //
+           "boolx4",
+           "uint8x4",
+           "uint16x4",
+           "uint32x4",
+           "uint64x4",
+           "int8x4",
+           "int16x4",
+           "int32x4",
+           "int64x4",
+           "float16x4",
+           "float32x4",
+           "float64x4",
+           "floatx4",
+           "doublex4",
+       }) {
     string tc = string("def test_type(") + string(type) +
         string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
     Check(tc);