diff --git a/.jenkins/build.sh b/.jenkins/build.sh
index f9c816a28..ae673d6cf 100755
--- a/.jenkins/build.sh
+++ b/.jenkins/build.sh
@@ -69,6 +69,10 @@ WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREF
 python setup.py install
 ./test_python/run_test.sh
 
+for f in $(find ./python/examples -name "*.py"); do
+  python $f
+done
+
 FILTER_OUT="benchmark_MLP_model benchmark_kronecker" ./test.sh
 # 2LUT can OOM on smaller Maxwells on our CI machines
 ./build/tc/benchmarks/benchmark_MLP_model --gtest_filter=-*2LUT*
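
The new example below exercises `uint8` inputs by generating packed 32-bit words and reinterpreting each word as four byte codes. A minimal numpy sketch of that zero-copy reinterpretation (editorial illustration, not part of the patch; assumes a little-endian machine):

```python
import numpy as np

N, M = 1000, 32
words = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = words.view('uint8')    # zero-copy reinterpretation, 4 codes per word
assert codes.shape == (N, M)
# On a little-endian machine the first code is the low byte of the first word.
assert codes[0, 0] == words[0, 0] & 0xFF
```
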
diff --git a/python/examples/min_distance.py b/python/examples/min_distance.py
new file mode 100644
index 000000000..d2e725c30
--- /dev/null
+++ b/python/examples/min_distance.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+import tensor_comprehensions as tc
+from tensor_comprehensions.tc import set_logtostderr
+from tensor_comprehensions.tc import set_debug_tc_mapper
+from tensor_comprehensions.tc import set_debug_cuda
+
+import numpy as np
+import torch
+
+#
+## Example submitted by @mdouze, originally related to uint8 type support
+#
+
+debug = False
+set_logtostderr(debug)
+set_debug_tc_mapper(debug)
+set_debug_cuda(debug)
+
+N = 1000
+M = 32
+
+codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
+codes = codes.view('uint8')
+luts = np.random.randn(M, 256).astype('float32')
+
+codes_t = torch.from_numpy(codes).cuda()
+luts_t = torch.from_numpy(luts).cuda()
+
+lang = """
+# mindis as a single kernel will require grid synchronization to run efficiently
+def mindis(float(M, 256) L, uint8(N, M) Codes) -> (S, v, min_idx) {
+    S(n) +=! L(r_m, int32(Codes(n, r_m)))
+    v min=! S(r_n)
+    min_idx min=! (S(r_n) == v) ? r_n : N
+}
+
+# Even when splitting in 3 kernels, a global device reduction will be needed to
+# run efficiently
+# don't try to run it with large sizes for now
+def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
+    S(n) +=! L(r_m, int32(Codes(n, r_m)))
+}
+def min_2d(float(N) S) -> (v) {
+    v min=! S(r_n)
+}
+def argmin_2d(float(N) S, float v) -> (min_idx) {
+    min_idx min=! (S(r_n) == v) ? r_n : N
+}
+"""
+
+mindis = tc.define(lang, name="mindis")
+S, v, min_idx = mindis(luts_t, codes_t)
+print("minval: {} minidx: {}".format(v, min_idx))
+
+reduce_codes = tc.define(lang, name="reduce_codes")
+min_2d = tc.define(lang, name="min_2d")
+argmin_2d = tc.define(lang, name="argmin_2d")
+
+S = reduce_codes(luts_t, codes_t)
+v = min_2d(S)
+min_idx = argmin_2d(S, v)
+
+print("minval: {} minidx: {}".format(v, min_idx))
+
+################################################################################
+# Each reduction is probably easier to optimize with a two-stage TC in which we
+# artificially increase parallelism and finish the reduction in a second kernel.
+# Properly choosing D such that N = D * (N / D) should result in a good version
+# with 5 kernels total.
+################################################################################
+N = 10 ** 5  # bump to 10**7 when ready for primetime
+D = 1000
+assert N % D == 0, "D={} must divide N={}".format(D, N)
+M = 32
+
+lang = """
+def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
+    S(n) +=! L(r_m, int32(Codes(n, r_m)))
+}
+def min_2d(float(D, NBYD) S) -> (V) {
+    V(d) min=! S(d, r_nbyd)
+}
+def min_1d(float(D) V) -> (v) {
+    v min=! V(r_d)
+}
+def argmin_2d(float(D, NBYD) S, float v) -> (MinIdx) {
+    MinIdx(d) min=! (S(d, r_nbyd) == v) ? d * NBYD + r_nbyd : N
+}
+def argmin_1d(float(N) S, int32(D) MinIdx) -> (min_idx) {
+    min_idx min=! (MinIdx(r_d) < N) ? r_d : N
+}
+"""
+
+codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
+codes = codes.view('uint8')
+luts = np.random.randn(M, 256).astype('float32')
+
+codes_t = torch.from_numpy(codes).cuda()
+luts_t = torch.from_numpy(luts).cuda()
+
+reduce_codes = tc.define(lang, name="reduce_codes")
+min_2d = tc.define(lang, name="min_2d")
+min_1d = tc.define(lang, name="min_1d")
+argmin_2d = tc.define(lang, name="argmin_2d")
+argmin_1d = tc.define(lang, name="argmin_1d")
+
+S = reduce_codes(luts_t, codes_t)
+V = min_2d(S.view((D, N // D)))
+v = min_1d(V)
+MinIdx = argmin_2d(S.view((D, N // D)), v)
+min_idx = argmin_1d(S, MinIdx)
+print("minval: {} minidx: {}".format(v, min_idx))
+
+################################################################################
+# The longer form version has an extra k dimension we could use for parallelism.
+# Unfortunately it is a small dimension (16), so it won't saturate Pascal/Volta;
+# we may want to split into 5 kernels to run efficiently.
+################################################################################
+N = 10 ** 7
+D = 1000
+assert N % D == 0, "D={} must divide N={}".format(D, N)
+M = 32
+K = 16
+codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
+codes = codes.view('uint8')
+luts = np.random.randn(K, M, 256).astype('float32')
+
+codes_t = torch.from_numpy(codes).cuda()
+luts_t = torch.from_numpy(luts).cuda()
+
+lang = """
+def mindis(float(K, M, 256) L, uint8(N, M) Codes) -> (S, V, MinIdx) {
+    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
+    V(k) min=! S(k, r_n)
+    MinIdx(k) min=! (S(k, r_n) == V(k)) ? r_n : N
+}
+"""
+
+debug = False
+set_logtostderr(debug)
+set_debug_tc_mapper(debug)
+set_debug_cuda(debug)
+
+mindis = tc.define(lang, name="mindis")
+S, V, MinIdx = mindis(luts_t, codes_t)
+print("minvals: {}\nminidxs: {}".format(V, MinIdx))
+
+lang = """
+def reduce_codes(float(K, M, 256) L, uint8(N, M) Codes) -> (S) {
+    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
+}
+def min_2d(float(K, D, NBYD) S) -> (V2) {
+    V2(k, d) min=! S(k, d, r_nbyd)
+}
+def min_1d(float(K, D) V2) -> (V) {
+    V(k) min=! V2(k, r_d)
+}
+def argmin_2d(float(K, D, NBYD) S, float(K) V) -> (MinIdx2) {
+    MinIdx2(k, d) min=! (S(k, d, r_nbyd) == V(k)) ? d * NBYD + r_nbyd : N
+}
+def argmin_1d(float(K, N) S, int32(K, D) MinIdx2) -> (MinIdx) {
+    MinIdx(k) min=! (MinIdx2(k, r_d) < N) ? r_d : N
+}
+"""
+
+reduce_codes = tc.define(lang, name="reduce_codes")
+min_2d = tc.define(lang, name="min_2d")
+min_1d = tc.define(lang, name="min_1d")
+argmin_2d = tc.define(lang, name="argmin_2d")
+argmin_1d = tc.define(lang, name="argmin_1d")
+
+S = reduce_codes(luts_t, codes_t)
+V2 = min_2d(S.view((K, D, N // D)))
+V = min_1d(V2)
+MinIdx2 = argmin_2d(S.view((K, D, N // D)), V)
+MinIdx = argmin_1d(S, MinIdx2)
+print("minvals: {} minidxs: {}".format(V, MinIdx))
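
For reference, the D x (N // D) decomposition that the example builds out of five TC kernels can be sanity-checked against plain numpy. This editorial sketch (names are illustrative; S plays the role of the `reduce_codes` output) mirrors the five stages:

```python
import numpy as np

N, D = 10 ** 5, 1000
assert N % D == 0
S = np.random.randn(N).astype('float32')        # stage 1 output (reduce_codes)

S2 = S.reshape(D, N // D)                       # expose D-way parallelism
V = S2.min(axis=1)                              # stage "min_2d": D partial minima
v = V.min()                                     # stage "min_1d": global minimum
# stage "argmin_2d": per-slice candidate index; N acts as the "no match" sentinel
I = np.where(S2 == v, np.arange(N).reshape(D, N // D), N).min(axis=1)
min_idx = I.min()                               # stage "argmin_1d": global argmin
assert S[min_idx] == v
```
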
r_d : N +} +""" + +reduce_codes = tc.define(lang, name="reduce_codes") +min_2d = tc.define(lang, name="min_2d") +min_1d = tc.define(lang, name="min_1d") +argmin_2d = tc.define(lang, name="argmin_2d") +argmin_1d = tc.define(lang, name="argmin_1d") + +S = reduce_codes(luts_t, codes_t) +V2 = min_2d(S.view((K, D, N / D))) +V = min_1d(V2) +MinIdx2 = argmin_2d(S.view((K, D, N / D)), V) +MinIdx = argmin_1d(S, MinIdx2) +print("minval: {} minidx: {}".format(V, MinIdx)) diff --git a/tc/aten/aten_compiler.cc b/tc/aten/aten_compiler.cc index b34fbd316..7b9f7b699 100644 --- a/tc/aten/aten_compiler.cc +++ b/tc/aten/aten_compiler.cc @@ -59,9 +59,22 @@ std::vector prepareOutputs( auto atenBackend = inputs[0].type().backend(); for (size_t i = 0; i < outTensorInfo.size(); ++i) { TensorInfo info(outTensorInfo[i]); - auto stype = at::toScalarType(info.dtype); - outputs.push_back( - at::getType(atenBackend, stype).tensor(at::IntList(info.shape))); + if (info.dtype.lanes == 1) { + auto stype = at::toScalarType(info.dtype); + outputs.push_back( + at::getType(atenBackend, stype).tensor(at::IntList(info.shape))); + } else { + // "Cast" to a larger strided tensor with 1 lane + auto lanes = info.dtype.lanes; + TC_CHECK(lanes == 2 || lanes == 4); + info.dtype.lanes = 1; + info.shape[info.shape.size() - 1] *= lanes; + auto stype = at::toScalarType(info.dtype); + auto T = at::getType(atenBackend, stype).tensor(at::IntList(info.shape)); + + info.shape[info.shape.size() - 1] /= lanes; + outputs.push_back(T.set_(*T.storage(), 0, info.shape)); + } } return outputs; } diff --git a/tc/core/libraries.h b/tc/core/libraries.h index c86eb565d..196823b6c 100644 --- a/tc/core/libraries.h +++ b/tc/core/libraries.h @@ -31,11 +31,92 @@ namespace code { namespace c { constexpr auto types = R"C( +// Can't include system dependencies with NVRTC +// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies +// #include + // Halide type handling +typedef char int8; +typedef short int16; typedef int int32; typedef long int64; +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long uint64; +// typedef half float16; typedef float float32; typedef double float64; + +template +struct array2 { + T x, y; + array2(T t) : x(t), y(t) {} + array2(T x, T y) : x(x), y(y) {} + array2 operator+(const array2& a) const { + return array2{ + static_cast(x + a.x), + static_cast(y + a.y) + }; + } + array2& operator+=(const array2& a) { + x += a.x; + y += a.y; + return *this; + } +}; +template +array2 x2(Type t) { + return array2(t); +} +typedef array2 int8x2; +typedef array2 int16x2; +typedef array2 int32x2; +typedef array2 int64x2; +typedef array2 uint8x2; +typedef array2 uint16x2; +typedef array2 uint32x2; +typedef array2 uint64x2; +// typedef array2 float16x2; +typedef array2 float32x2; +typedef array2 float64x2; + +template +struct array4 { + T x, y, z, w; + array4(T t) : x(t), y(t), z(t), w(t) {} + array4(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {} + array4 operator+(const array4& a) const { + return array4{ + static_cast(x + a.x), + static_cast(y + a.y), + static_cast(z + a.z), + static_cast(w + a.w) + }; + } + array4& operator+=(const array4& a) { + x += a.x; + y += a.y; + z += a.z; + w += a.w; + return *this; + } +}; +template +array4 x4(Type t) { + return array4(t); +} +typedef array4 int8x4; +typedef array4 int16x4; +typedef array4 int32x4; +typedef array4 int64x4; +typedef array4 uint8x4; +typedef array4 uint16x4; +typedef array4 uint32x4; +typedef array4 uint64x4; +// 
diff --git a/tc/core/libraries.h b/tc/core/libraries.h
index c86eb565d..196823b6c 100644
--- a/tc/core/libraries.h
+++ b/tc/core/libraries.h
@@ -31,11 +31,92 @@ namespace code {
 namespace c {
 
 constexpr auto types = R"C(
+// Can't include system dependencies with NVRTC
+// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
+// #include <cuda_fp16.h>
+
 // Halide type handling
+typedef char int8;
+typedef short int16;
 typedef int int32;
 typedef long int64;
+typedef unsigned char uint8;
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned long uint64;
+// typedef half float16;
 typedef float float32;
 typedef double float64;
+
+template <typename T>
+struct array2 {
+  T x, y;
+  array2(T t) : x(t), y(t) {}
+  array2(T x, T y) : x(x), y(y) {}
+  array2 operator+(const array2& a) const {
+    return array2{
+      static_cast<T>(x + a.x),
+      static_cast<T>(y + a.y)
+    };
+  }
+  array2& operator+=(const array2& a) {
+    x += a.x;
+    y += a.y;
+    return *this;
+  }
+};
+template <typename Type>
+array2<Type> x2(Type t) {
+  return array2<Type>(t);
+}
+typedef array2<int8> int8x2;
+typedef array2<int16> int16x2;
+typedef array2<int32> int32x2;
+typedef array2<int64> int64x2;
+typedef array2<uint8> uint8x2;
+typedef array2<uint16> uint16x2;
+typedef array2<uint32> uint32x2;
+typedef array2<uint64> uint64x2;
+// typedef array2<float16> float16x2;
+typedef array2<float32> float32x2;
+typedef array2<float64> float64x2;
+
+template <typename T>
+struct array4 {
+  T x, y, z, w;
+  array4(T t) : x(t), y(t), z(t), w(t) {}
+  array4(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
+  array4 operator+(const array4& a) const {
+    return array4{
+      static_cast<T>(x + a.x),
+      static_cast<T>(y + a.y),
+      static_cast<T>(z + a.z),
+      static_cast<T>(w + a.w)
+    };
+  }
+  array4& operator+=(const array4& a) {
+    x += a.x;
+    y += a.y;
+    z += a.z;
+    w += a.w;
+    return *this;
+  }
+};
+template <typename Type>
+array4<Type> x4(Type t) {
+  return array4<Type>(t);
+}
+typedef array4<int8> int8x4;
+typedef array4<int16> int16x4;
+typedef array4<int32> int32x4;
+typedef array4<int64> int64x4;
+typedef array4<uint8> uint8x4;
+typedef array4<uint16> uint16x4;
+typedef array4<uint32> uint32x4;
+typedef array4<uint64> uint64x4;
+// typedef array4<float16> float16x4;
+typedef array4<float32> float32x4;
+typedef array4<float64> float64x4;
 )C";
 
 constexpr auto defines = R"C(
diff --git a/tc/core/tc2halide.cc b/tc/core/tc2halide.cc
index 01da89edf..82df69d4a 100644
--- a/tc/core/tc2halide.cc
+++ b/tc/core/tc2halide.cc
@@ -37,14 +37,6 @@ Type translateScalarType(int tcType) {
   switch (tcType) {
     case lang::TK_BOOL:
       return UInt(1);
-    case lang::TK_UINT8:
-      return UInt(8);
-    case lang::TK_UINT16:
-      return UInt(16);
-    case lang::TK_UINT32:
-      return UInt(32);
-    case lang::TK_UINT64:
-      return UInt(64);
     case lang::TK_INT8:
       return Int(8);
     case lang::TK_INT16:
@@ -53,10 +45,83 @@ Type translateScalarType(int tcType) {
       return Int(32);
     case lang::TK_INT64:
       return Int(64);
+    case lang::TK_UINT8:
+      return UInt(8);
+    case lang::TK_UINT16:
+      return UInt(16);
+    case lang::TK_UINT32:
+      return UInt(32);
+    case lang::TK_UINT64:
+      return UInt(64);
+    case lang::TK_FLOAT16:
+      return Float(16);
+    case lang::TK_FLOAT32:
+      return Float(32);
+    case lang::TK_FLOAT64:
+      return Float(64);
     case lang::TK_FLOAT:
       return Float(32);
     case lang::TK_DOUBLE:
       return Float(64);
+
+    case lang::TK_VECTOR2_BOOL:
+      return UInt(1, 2);
+    case lang::TK_VECTOR2_INT8:
+      return Int(8, 2);
+    case lang::TK_VECTOR2_INT16:
+      return Int(16, 2);
+    case lang::TK_VECTOR2_INT32:
+      return Int(32, 2);
+    case lang::TK_VECTOR2_INT64:
+      return Int(64, 2);
+    case lang::TK_VECTOR2_UINT8:
+      return UInt(8, 2);
+    case lang::TK_VECTOR2_UINT16:
+      return UInt(16, 2);
+    case lang::TK_VECTOR2_UINT32:
+      return UInt(32, 2);
+    case lang::TK_VECTOR2_UINT64:
+      return UInt(64, 2);
+    case lang::TK_VECTOR2_FLOAT16:
+      return Float(16, 2);
+    case lang::TK_VECTOR2_FLOAT32:
+      return Float(32, 2);
+    case lang::TK_VECTOR2_FLOAT64:
+      return Float(64, 2);
+    case lang::TK_VECTOR2_FLOAT:
+      return Float(32, 2);
+    case lang::TK_VECTOR2_DOUBLE:
+      return Float(64, 2);
+
+    case lang::TK_VECTOR4_BOOL:
+      return UInt(1, 4);
+    case lang::TK_VECTOR4_INT8:
+      return Int(8, 4);
+    case lang::TK_VECTOR4_INT16:
+      return Int(16, 4);
+    case lang::TK_VECTOR4_INT32:
+      return Int(32, 4);
+    case lang::TK_VECTOR4_INT64:
+      return Int(64, 4);
+    case lang::TK_VECTOR4_UINT8:
+      return UInt(8, 4);
+    case lang::TK_VECTOR4_UINT16:
+      return UInt(16, 4);
+    case lang::TK_VECTOR4_UINT32:
+      return UInt(32, 4);
+    case lang::TK_VECTOR4_UINT64:
+      return UInt(64, 4);
+    case lang::TK_VECTOR4_FLOAT16:
+      return Float(16, 4);
+    case lang::TK_VECTOR4_FLOAT32:
+      return Float(32, 4);
+    case lang::TK_VECTOR4_FLOAT64:
+      return Float(64, 4);
+    case lang::TK_VECTOR4_FLOAT:
+      return Float(32, 4);
+    case lang::TK_VECTOR4_DOUBLE:
+      return Float(64, 4);
+
     default:
       LOG(FATAL) << "Unhandled TC scalar type: " << tcType << '\n';
       return Type();
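
The switch above, and the lexer and semantic analyzer below, enumerate every token explicitly, but the names all follow a strict `<base>[bits][x<lanes>]` scheme. An editorial Python sketch of that decomposition (the helper name is illustrative, not part of the patch):

```python
import re

def parse_tc_type(name):
    """Return (code, bits, lanes) for a TC type token such as 'float32x4'."""
    m = re.fullmatch(r'(bool|uint\d+|int\d+|float\d+|float|double)(?:x([24]))?', name)
    if m is None:
        raise ValueError('unhandled TC type: ' + name)
    base, lanes = m.group(1), int(m.group(2) or 1)
    # 'float'/'double'/'bool' are aliases for float32/float64/uint1.
    base = {'float': 'float32', 'double': 'float64', 'bool': 'uint1'}.get(base, base)
    code = 'Float' if base.startswith('float') else ('UInt' if base.startswith('uint') else 'Int')
    bits = int(re.search(r'\d+', base).group())
    return code, bits, lanes

assert parse_tc_type('uint8x2') == ('UInt', 8, 2)
assert parse_tc_type('floatx4') == ('Float', 32, 4)
```
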
diff --git a/tc/lang/lexer.h b/tc/lang/lexer.h
index 9e40a3092..1fe35e439 100644
--- a/tc/lang/lexer.h
+++ b/tc/lang/lexer.h
@@ -34,57 +34,92 @@ namespace lang {
 // Some kinds TK_APPLY, TK_LIST are only used in the AST and are not seen in the
 // lexer.
-#define TC_FORALL_TOKEN_KINDS(_) \
-  _(TK_EOF, "eof", "") \
-  _(TK_NUMBER, "number", "") \
-  _(TK_BOOL_VALUE, "bool_value", "") \
-  _(TK_MIN, "min", "min") \
-  _(TK_MAX, "max", "max") \
-  _(TK_WHERE, "where", "where") \
-  _(TK_FLOAT, "float", "float") \
-  _(TK_DOUBLE, "double", "double") \
-  _(TK_DEF, "def", "def") \
-  _(TK_ARROW, "arrow", "->") \
-  _(TK_EQUIVALENT, "equivalent", "<=>") \
-  _(TK_IDENT, "ident", "") \
-  _(TK_STRING, "string", "") \
-  _(TK_CONST, "const", "") \
-  _(TK_LIST, "list", "") \
-  _(TK_OPTION, "option", "") \
-  _(TK_APPLY, "apply", "") \
-  _(TK_COMPREHENSION, "comprehension", "") \
-  _(TK_TENSOR_TYPE, "tensor_type", "") \
-  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
-  _(TK_PARAM, "param", "") \
-  _(TK_INFERRED, "inferred", "") \
-  _(TK_ACCESS, "access", "") \
-  _(TK_BUILT_IN, "built-in", "") \
-  _(TK_PLUS_EQ, "plus_eq", "+=") \
-  _(TK_TIMES_EQ, "times_eq", "*=") \
-  _(TK_MIN_EQ, "min_eq", "min=") \
-  _(TK_MAX_EQ, "max_eq", "max=") \
-  _(TK_PLUS_EQ_B, "plus_eq_b", "+=!") \
-  _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \
-  _(TK_MIN_EQ_B, "min_eq_b", "min=!") \
-  _(TK_MAX_EQ_B, "max_eq_b", "max=!") \
-  _(TK_INT8, "int8", "int8") \
-  _(TK_INT16, "int16", "int16") \
-  _(TK_INT32, "int32", "int32") \
-  _(TK_INT64, "int64", "int64") \
-  _(TK_UINT8, "uint8", "uint8") \
-  _(TK_UINT16, "uint16", "uint16") \
-  _(TK_UINT32, "uint32", "uint32") \
-  _(TK_UINT64, "uint64", "uint64") \
-  _(TK_BOOL, "bool", "bool") \
-  _(TK_CAST, "cast", "") \
-  _(TK_IN, "in", "in") \
-  _(TK_GE, "ge", ">=") \
-  _(TK_LE, "le", "<=") \
-  _(TK_EQ, "eq", "==") \
-  _(TK_NE, "neq", "!=") \
-  _(TK_AND, "and", "&&") \
-  _(TK_OR, "or", "||") \
-  _(TK_LET, "let", "") \
+#define TC_FORALL_TOKEN_KINDS(_) \
+  _(TK_EOF, "eof", "") \
+  _(TK_NUMBER, "number", "") \
+  _(TK_BOOL_VALUE, "bool_value", "") \
+  _(TK_MIN, "min", "min") \
+  _(TK_MAX, "max", "max") \
+  _(TK_WHERE, "where", "where") \
+  _(TK_DEF, "def", "def") \
+  _(TK_ARROW, "arrow", "->") \
+  _(TK_EQUIVALENT, "equivalent", "<=>") \
+  _(TK_IDENT, "ident", "") \
+  _(TK_STRING, "string", "") \
+  _(TK_CONST, "const", "") \
+  _(TK_LIST, "list", "") \
+  _(TK_OPTION, "option", "") \
+  _(TK_APPLY, "apply", "") \
+  _(TK_COMPREHENSION, "comprehension", "") \
+  _(TK_TENSOR_TYPE, "tensor_type", "") \
+  _(TK_RANGE_CONSTRAINT, "range_constraint", "") \
+  _(TK_PARAM, "param", "") \
+  _(TK_INFERRED, "inferred", "") \
+  _(TK_ACCESS, "access", "") \
+  _(TK_BUILT_IN, "built-in", "") \
+  _(TK_PLUS_EQ, "plus_eq", "+=") \
+  _(TK_TIMES_EQ, "times_eq", "*=") \
+  _(TK_MIN_EQ, "min_eq", "min=") \
+  _(TK_MAX_EQ, "max_eq", "max=") \
+  _(TK_PLUS_EQ_B, "plus_eq_b", "+=!") \
+  _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \
+  _(TK_MIN_EQ_B, "min_eq_b", "min=!") \
+  _(TK_MAX_EQ_B, "max_eq_b", "max=!") \
+  \
+  _(TK_BOOL, "bool", "bool") \
+  _(TK_UINT8, "uint8", "uint8") \
+  _(TK_UINT16, "uint16", "uint16") \
+  _(TK_UINT32, "uint32", "uint32") \
+  _(TK_UINT64, "uint64", "uint64") \
+  _(TK_INT8, "int8", "int8") \
+  _(TK_INT16, "int16", "int16") \
+  _(TK_INT32, "int32", "int32") \
+  _(TK_INT64, "int64", "int64") \
+  _(TK_FLOAT16, "float16", "float16") \
+  _(TK_FLOAT32, "float32", "float32") \
+  _(TK_FLOAT64, "float64", "float64") \
+  _(TK_FLOAT, "float", "float") \
+  _(TK_DOUBLE, "double", "double") \
+  \
+  _(TK_VECTOR2_BOOL, "boolx2", "boolx2") \
+  _(TK_VECTOR2_UINT8, "uint8x2", "uint8x2") \
+  _(TK_VECTOR2_UINT16, "uint16x2", "uint16x2") \
+  _(TK_VECTOR2_UINT32, "uint32x2", "uint32x2") \
+  _(TK_VECTOR2_UINT64, "uint64x2", "uint64x2") \
+  _(TK_VECTOR2_INT8, "int8x2", "int8x2") \
+  _(TK_VECTOR2_INT16, "int16x2", "int16x2") \
+  _(TK_VECTOR2_INT32, "int32x2", "int32x2") \
+  _(TK_VECTOR2_INT64, "int64x2", "int64x2") \
+  _(TK_VECTOR2_FLOAT16, "float16x2", "float16x2") \
+  _(TK_VECTOR2_FLOAT32, "float32x2", "float32x2") \
+  _(TK_VECTOR2_FLOAT64, "float64x2", "float64x2") \
+  _(TK_VECTOR2_FLOAT, "floatx2", "floatx2") \
+  _(TK_VECTOR2_DOUBLE, "doublex2", "doublex2") \
+  \
+  _(TK_VECTOR4_BOOL, "boolx4", "boolx4") \
+  _(TK_VECTOR4_UINT8, "uint8x4", "uint8x4") \
+  _(TK_VECTOR4_UINT16, "uint16x4", "uint16x4") \
+  _(TK_VECTOR4_UINT32, "uint32x4", "uint32x4") \
+  _(TK_VECTOR4_UINT64, "uint64x4", "uint64x4") \
+  _(TK_VECTOR4_INT8, "int8x4", "int8x4") \
+  _(TK_VECTOR4_INT16, "int16x4", "int16x4") \
+  _(TK_VECTOR4_INT32, "int32x4", "int32x4") \
+  _(TK_VECTOR4_INT64, "int64x4", "int64x4") \
+  _(TK_VECTOR4_FLOAT16, "float16x4", "float16x4") \
+  _(TK_VECTOR4_FLOAT32, "float32x4", "float32x4") \
+  _(TK_VECTOR4_FLOAT64, "float64x4", "float64x4") \
+  _(TK_VECTOR4_FLOAT, "floatx4", "floatx4") \
+  _(TK_VECTOR4_DOUBLE, "doublex4", "doublex4") \
+  \
+  _(TK_CAST, "cast", "") \
+  _(TK_IN, "in", "in") \
+  _(TK_GE, "ge", ">=") \
+  _(TK_LE, "le", "<=") \
+  _(TK_EQ, "eq", "==") \
+  _(TK_NE, "neq", "!=") \
+  _(TK_AND, "and", "&&") \
+  _(TK_OR, "or", "||") \
+  _(TK_LET, "let", "") \
   _(TK_EXISTS, "exists", "exists")
 
 static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%";
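
`TypeInfo` in the next hunk also needs the reverse mapping, from `(code, bits, lanes)` back to a token name; in Python the naming rule is compact (editorial sketch mirroring `toScalarToken` below, helper name illustrative):

```python
def to_scalar_token(code, bits, lanes):
    """Inverse of parse_tc_type: canonical token name for (code, bits, lanes)."""
    base = {'UInt': 'uint', 'Int': 'int', 'Float': 'float'}[code] + str(bits)
    # Canonical spellings: uint1 -> bool, float32 -> float, float64 -> double.
    base = {'uint1': 'bool', 'float32': 'float', 'float64': 'double'}.get(base, base)
    return base if lanes == 1 else '{}x{}'.format(base, lanes)

assert to_scalar_token('Float', 32, 4) == 'floatx4'
assert to_scalar_token('UInt', 1, 2) == 'boolx2'
```
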
diff --git a/tc/lang/sema.h b/tc/lang/sema.h
--- a/tc/lang/sema.h
+++ b/tc/lang/sema.h
@@ ... @@ struct TypeInfo {
   TypeInfo(TreeRef scalar_type) {
     switch (scalar_type->kind()) {
-#define TYPE_INFO_OPTION(tok, c, b) \
-  case tok:                         \
-    code_ = c;                      \
-    bits_ = b;                      \
+#define TYPE_INFO_OPTION(tok, c, b, l) \
+  case tok:                            \
+    code_ = c;                         \
+    bits_ = b;                         \
+    lanes_ = l;                        \
     break;
-      TYPE_INFO_OPTION(TK_BOOL, UInt, 1)
-      TYPE_INFO_OPTION(TK_UINT8, UInt, 8)
-      TYPE_INFO_OPTION(TK_UINT16, UInt, 16)
-      TYPE_INFO_OPTION(TK_UINT32, UInt, 32)
-      TYPE_INFO_OPTION(TK_UINT64, UInt, 64)
-      TYPE_INFO_OPTION(TK_INT8, Int, 8)
-      TYPE_INFO_OPTION(TK_INT16, Int, 16)
-      TYPE_INFO_OPTION(TK_INT32, Int, 32)
-      TYPE_INFO_OPTION(TK_INT64, Int, 64)
-      TYPE_INFO_OPTION(TK_FLOAT, Float, 32)
-      TYPE_INFO_OPTION(TK_DOUBLE, Float, 64)
+      TYPE_INFO_OPTION(TK_BOOL, UInt, 1, 1)
+      TYPE_INFO_OPTION(TK_UINT8, UInt, 8, 1)
+      TYPE_INFO_OPTION(TK_UINT16, UInt, 16, 1)
+      TYPE_INFO_OPTION(TK_UINT32, UInt, 32, 1)
+      TYPE_INFO_OPTION(TK_UINT64, UInt, 64, 1)
+      TYPE_INFO_OPTION(TK_INT8, Int, 8, 1)
+      TYPE_INFO_OPTION(TK_INT16, Int, 16, 1)
+      TYPE_INFO_OPTION(TK_INT32, Int, 32, 1)
+      TYPE_INFO_OPTION(TK_INT64, Int, 64, 1)
+      TYPE_INFO_OPTION(TK_FLOAT16, Float, 16, 1)
+      TYPE_INFO_OPTION(TK_FLOAT32, Float, 32, 1)
+      TYPE_INFO_OPTION(TK_FLOAT64, Float, 64, 1)
+      TYPE_INFO_OPTION(TK_FLOAT, Float, 32, 1)
+      TYPE_INFO_OPTION(TK_DOUBLE, Float, 64, 1)
+
+      TYPE_INFO_OPTION(TK_VECTOR2_BOOL, UInt, 1, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_UINT8, UInt, 8, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_UINT16, UInt, 16, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_UINT32, UInt, 32, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_UINT64, UInt, 64, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_INT8, Int, 8, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_INT16, Int, 16, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_INT32, Int, 32, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_INT64, Int, 64, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_FLOAT16, Float, 16, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_FLOAT32, Float, 32, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_FLOAT64, Float, 64, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_FLOAT, Float, 32, 2)
+      TYPE_INFO_OPTION(TK_VECTOR2_DOUBLE, Float, 64, 2)
+
+      TYPE_INFO_OPTION(TK_VECTOR4_BOOL, UInt, 1, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_UINT8, UInt, 8, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_UINT16, UInt, 16, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_UINT32, UInt, 32, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_UINT64, UInt, 64, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_INT8, Int, 8, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_INT16, Int, 16, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_INT32, Int, 32, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_INT64, Int, 64, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_FLOAT16, Float, 16, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_FLOAT32, Float, 32, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_FLOAT64, Float, 64, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_FLOAT, Float, 32, 4)
+      TYPE_INFO_OPTION(TK_VECTOR4_DOUBLE, Float, 64, 4)
+
 #undef TYPE_INFO_OPTION
       default:
         throw ErrorReport(scalar_type)
@@ -55,39 +91,116 @@ struct TypeInfo {
     }
   }
   int toScalarToken() const {
-    switch (code()) {
-      case UInt:
-        switch (bits()) {
-          case 1:
-            return TK_BOOL;
-          case 8:
-            return TK_UINT8;
-          case 16:
-            return TK_UINT16;
-          case 32:
-            return TK_UINT32;
-          case 64:
-            return TK_UINT64;
-        }
-      case Int:
-        switch (bits()) {
-          case 8:
-            return TK_INT8;
-          case 16:
-            return TK_INT16;
-          case 32:
-            return TK_INT32;
-          case 64:
-            return TK_INT64;
-        }
-      case Float:
-        switch (bits()) {
-          case 32:
-            return TK_FLOAT;
-          case 64:
-            return TK_DOUBLE;
-        }
+    if (lanes() == 1) {
+      switch (code()) {
+        case UInt:
+          switch (bits()) {
+            case 1:
+              return TK_BOOL;
+            case 8:
+              return TK_UINT8;
+            case 16:
+              return TK_UINT16;
+            case 32:
+              return TK_UINT32;
+            case 64:
+              return TK_UINT64;
+          }
+        case Int:
+          switch (bits()) {
+            case 8:
+              return TK_INT8;
+            case 16:
+              return TK_INT16;
+            case 32:
+              return TK_INT32;
+            case 64:
+              return TK_INT64;
+          }
+        case Float:
+          switch (bits()) {
+            case 16:
+              return TK_FLOAT16;
+            case 32:
+              return TK_FLOAT;
+            case 64:
+              return TK_DOUBLE;
+          }
+      }
+    } else if (lanes() == 2) {
+      switch (code()) {
+        case UInt:
+          switch (bits()) {
+            case 1:
+              return TK_VECTOR2_BOOL;
+            case 8:
+              return TK_VECTOR2_UINT8;
+            case 16:
+              return TK_VECTOR2_UINT16;
+            case 32:
+              return TK_VECTOR2_UINT32;
+            case 64:
+              return TK_VECTOR2_UINT64;
+          }
+        case Int:
+          switch (bits()) {
+            case 8:
+              return TK_VECTOR2_INT8;
+            case 16:
+              return TK_VECTOR2_INT16;
+            case 32:
+              return TK_VECTOR2_INT32;
+            case 64:
+              return TK_VECTOR2_INT64;
+          }
+        case Float:
+          switch (bits()) {
+            case 16:
+              return TK_VECTOR2_FLOAT16;
+            case 32:
+              return TK_VECTOR2_FLOAT;
+            case 64:
+              return TK_VECTOR2_DOUBLE;
+          }
+      }
+    } else if (lanes() == 4) {
+      switch (code()) {
+        case UInt:
+          switch (bits()) {
+            case 1:
+              return TK_VECTOR4_BOOL;
+            case 8:
+              return TK_VECTOR4_UINT8;
+            case 16:
+              return TK_VECTOR4_UINT16;
+            case 32:
+              return TK_VECTOR4_UINT32;
+            case 64:
+              return TK_VECTOR4_UINT64;
+          }
+        case Int:
+          switch (bits()) {
+            case 8:
+              return TK_VECTOR4_INT8;
+            case 16:
+              return TK_VECTOR4_INT16;
+            case 32:
+              return TK_VECTOR4_INT32;
+            case 64:
+              return TK_VECTOR4_INT64;
+          }
+        case Float:
+          switch (bits()) {
+            case 16:
+              return TK_VECTOR4_FLOAT16;
+            case 32:
+              return TK_VECTOR4_FLOAT;
+            case 64:
+              return TK_VECTOR4_DOUBLE;
+          }
+      }
     }
+    throw std::runtime_error("Unknown type info?");
   }
 
   Code code() const {
@@ -96,6 +209,9 @@ struct TypeInfo {
   uint8_t bits() const {
     return bits_;
   }
+  uint8_t lanes() const {
+    return lanes_;
+  }
   bool is_float() const {
     return code_ == Float;
   }
@@ -106,10 +222,11 @@ struct TypeInfo {
  private:
   Code code_;
   uint8_t bits_;
+  uint8_t lanes_;
 };
 
 static inline bool operator==(TypeInfo a, TypeInfo b) {
-  return a.bits() == b.bits() && a.code() == b.code();
+  return a.bits() == b.bits() && a.code() == b.code() && a.lanes() == b.lanes();
 }
 
 static inline TreeRef match_types(TreeRef a, TreeRef b) {
@@ -118,34 +235,49 @@ static inline TreeRef match_types(TreeRef a, TreeRef b) {
   if (ta == tb)
     return a;
 
-  if (!ta.is_float() && tb.is_float()) {
+#define NO_MATCH()                                     \
+  throw ErrorReport(b) << "Could not match types: "    \
+                       << kindToString(ta.toScalarToken()) << ", " \
+                       << kindToString(tb.toScalarToken());
+
+  if (!ta.is_float() && ta.lanes() == 1 && tb.is_float()) {
     // int(a) * float(b) -> float(b)
     // uint(a) * float(b) -> float(b)
     return b;
-  } else if (ta.is_float() && !tb.is_float()) {
+  } else if (ta.is_float() && !tb.is_float() && tb.lanes() == 1) {
     return a;
   } else if (ta.is_float() && tb.is_float()) {
     // float(a) * float(b) -> float(max(a, b))
-    if (ta.bits() > tb.bits())
+    if (ta.bits() > tb.bits() && tb.lanes() == 1)
       return a;
-    else
+    else if (ta.lanes() == 1)
       return b;
-  } else if (ta.is_uint() && tb.is_uint()) {
+    else {
+      NO_MATCH();
+    }
+  } else if (ta.is_uint() && tb.is_uint() && tb.lanes() == 1) {
     // uint(a) * uint(b) -> uint(max(a, b))
-    if (ta.bits() > tb.bits())
+    if (ta.bits() > tb.bits() && tb.lanes() == 1)
       return a;
-    else
+    else if (ta.lanes() == 1)
       return b;
+    else {
+      NO_MATCH();
+    }
   } else if (!ta.is_float() && !tb.is_float()) {
     // int(a) * (u)int(b) -> int(max(a, b))
     int bits = std::max(ta.bits(), tb.bits());
+    if ((bits == ta.bits() && tb.lanes() != 1) ||
+        (bits == tb.bits() && ta.lanes() != 1)) {
+      NO_MATCH();
+    }
     return Compound::create(
         TypeInfo(TypeInfo::Int, bits).toScalarToken(), a->range(), {});
   } else {
-    throw ErrorReport(b) << "Could not match types: "
-                         << kindToString(ta.toScalarToken()) << ", "
-                         << kindToString(tb.toScalarToken());
+    NO_MATCH();
   }
+
+#undef NO_MATCH
 }
 
 /// Semantic analysis transforms the raw AST into a
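
Roughly, the lanes-aware promotion above only lets a scalar (1-lane) operand be absorbed into the other operand's type; mixed vector widths never match. An editorial Python restatement (simplified, not the exact C++ control flow; types are the `(code, bits, lanes)` triples from the earlier sketches):

```python
def match_types(a, b):
    """Rough restatement of the lanes-aware promotion rule in sema.h."""
    if a == b:
        return a
    key = lambda t: (t[0] == 'Float', t[1])   # float dominates, then bit-width
    wide, narrow = (a, b) if key(a) >= key(b) else (b, a)
    if narrow[2] != 1:                        # only a scalar may be absorbed
        raise TypeError('Could not match types: {}, {}'.format(a, b))
    if wide[0] != 'Float' and wide[0] != narrow[0]:
        return ('Int', wide[1], wide[2])      # int * uint -> int(max bits)
    return wide

assert match_types(('Float', 32, 2), ('Int', 32, 1)) == ('Float', 32, 2)
```
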
diff --git a/test/cuda/test_compile_and_run.cc b/test/cuda/test_compile_and_run.cc
index 9dc2bde61..1113c0b7f 100644
--- a/test/cuda/test_compile_and_run.cc
+++ b/test/cuda/test_compile_and_run.cc
@@ -275,6 +275,72 @@ def cast(float(M,N) A, int32 four) -> (int32(M,N) output) {
   TC_CHECK_EQ(r, 0);
 }
 
+TEST_F(CompilationTest, Types) {
+  struct TypeMatch {
+    std::string s;
+    at::ScalarType a;
+    uint8_t lanes;
+  };
+  for (auto type :
+       {// TypeMatch{"bool", at::ScalarType::Bool, 1}, // no aten version
+        TypeMatch{"uint8", at::ScalarType::Byte, 1},
+        // TypeMatch{"uint16", at::ScalarType::Short, 1}, // no aten version
+        // TypeMatch{"uint32", at::ScalarType::Int, 1}, // no aten version
+        // TypeMatch{"uint64", at::ScalarType::Long, 1}, // no aten version
+        TypeMatch{"int8", at::ScalarType::Char, 1},
+        TypeMatch{"int16", at::ScalarType::Short, 1},
+        TypeMatch{"int32", at::ScalarType::Int, 1},
+        TypeMatch{"int64", at::ScalarType::Long, 1},
+        // NVRTC transitive include dependencies issue
+        // TypeMatch{"float16", at::ScalarType::Half, 1},
+        TypeMatch{"float32", at::ScalarType::Float, 1},
+        TypeMatch{"float64", at::ScalarType::Double, 1},
+        TypeMatch{"float", at::ScalarType::Float, 1},
+        TypeMatch{"double", at::ScalarType::Double, 1},
+
+        // TypeMatch{"boolx2", at::ScalarType::Bool, 2}, // no aten version
+        TypeMatch{"uint8x2", at::ScalarType::Byte, 2},
+        // TypeMatch{"uint16x2", at::ScalarType::Short, 2}, // no aten version
+        // TypeMatch{"uint32x2", at::ScalarType::Int, 2}, // no aten version
+        // TypeMatch{"uint64x2", at::ScalarType::Long, 2}, // no aten version
+        TypeMatch{"int8x2", at::ScalarType::Char, 2},
+        TypeMatch{"int16x2", at::ScalarType::Short, 2},
+        TypeMatch{"int32x2", at::ScalarType::Int, 2},
+        TypeMatch{"int64x2", at::ScalarType::Long, 2},
+        // NVRTC transitive include dependencies issue
+        // TypeMatch{"float16x2", at::ScalarType::Half, 2},
+        TypeMatch{"float32x2", at::ScalarType::Float, 2},
+        TypeMatch{"float64x2", at::ScalarType::Double, 2},
+        TypeMatch{"floatx2", at::ScalarType::Float, 2},
+        TypeMatch{"doublex2", at::ScalarType::Double, 2},
+
+        // TypeMatch{"boolx4", at::ScalarType::Bool, 4}, // no aten version
+        TypeMatch{"uint8x4", at::ScalarType::Byte, 4},
+        // TypeMatch{"uint16x4", at::ScalarType::Short, 4}, // no aten version
+        // TypeMatch{"uint32x4", at::ScalarType::Int, 4}, // no aten version
+        // TypeMatch{"uint64x4", at::ScalarType::Long, 4}, // no aten version
+        TypeMatch{"int8x4", at::ScalarType::Char, 4},
+        TypeMatch{"int16x4", at::ScalarType::Short, 4},
+        TypeMatch{"int32x4", at::ScalarType::Int, 4},
+        TypeMatch{"int64x4", at::ScalarType::Long, 4},
+        // NVRTC transitive include dependencies issue
+        // TypeMatch{"float16x4", at::ScalarType::Half, 4},
+        TypeMatch{"float32x4", at::ScalarType::Float, 4},
+        TypeMatch{"float64x4", at::ScalarType::Double, 4},
+        TypeMatch{"floatx4", at::ScalarType::Float, 4},
+        TypeMatch{"doublex4", at::ScalarType::Double, 4}}) {
+    std::string tc = std::string("def test_type(") + std::string(type.s) +
+        std::string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
+
+    auto T = at::CUDA(type.a).ones({100 * type.lanes});
+    std::vector<at::Tensor> outputs = Check(
+        tc,
+        "test_type",
+        tc::CudaMappingOptions::makeNaiveMappingOptions(),
+        {T.set_(*T.storage(), {100}, {type.lanes})});
+  }
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);
diff --git a/test/test_tc2halide.cc b/test/test_tc2halide.cc
index 128afacca..a36b2a85b 100644
--- a/test/test_tc2halide.cc
+++ b/test/test_tc2halide.cc
@@ -197,6 +197,60 @@ def foo(float(N) A) -> (B) {
 )TC";
   EXPECT_THROW(Check(tc), ::lang::ErrorReport);
 }
+
+TEST_F(TC2Isl, Types) {
+  for (auto type : {
+           "bool",
+           "uint8",
+           "uint16",
+           "uint32",
+           "uint64",
+           "int8",
+           "int16",
+           "int32",
+           "int64",
+           "float16",
+           "float32",
+           "float64",
+           "float",
+           "double",
+           //
+           "boolx2",
+           "uint8x2",
+           "uint16x2",
+           "uint32x2",
+           "uint64x2",
+           "int8x2",
+           "int16x2",
+           "int32x2",
+           "int64x2",
+           "float16x2",
+           "float32x2",
+           "float64x2",
+           "floatx2",
+           "doublex2",
+           //
+           "boolx4",
+           "uint8x4",
+           "uint16x4",
+           "uint32x4",
+           "uint64x4",
+           "int8x4",
+           "int16x4",
+           "int32x4",
+           "int64x4",
+           "float16x4",
+           "float32x4",
+           "float64x4",
+           "floatx4",
+           "doublex4",
+       }) {
+    string tc = string("def test_type(") + string(type) +
+        string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
+    Check(tc);
+  }
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, &argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);
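
If the Python bindings shown in min_distance.py pick these types up, the C++ `Types` test would roughly correspond to the following (untested editorial sketch; assumes `tc.define` as used above, a CUDA build, and that the bindings accept strided inputs in place of the `T.set_(...)` call):

```python
import torch
import tensor_comprehensions as tc

lang = """
def test_type(float32x4(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }
"""
test_type = tc.define(lang, name="test_type")
T = torch.ones(100 * 4).cuda()
A = T.as_strided((100,), (4,))   # one logical element per 4-lane vector
B = test_type(A)
print(B)
```
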