
Commit 7e2df7c

Merge pull request #512 from nicolasvasilache/pr/types

Types support and min_distance.py function example

2 parents: af9b3fb + c20658a
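
In short, this commit wires fixed-width scalar types through TC end to end: new `float16`/`float32`/`float64` tokens in the lexer, Halide translation for them, CUDA-side typedefs for the integer types, and an example plus a test exercising `uint8`. A minimal sketch of the new capability, mirroring the Python example and C++ test in the diffs below (it assumes a CUDA device and the Python bindings from this build):

import tensor_comprehensions as tc
import torch

# uint8 tensors can now be typed directly in a TC signature; B(0)
# accumulates all N bytes of A via the bang (+=!) reduction.
lang = """
def test_type(uint8(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }
"""
test_type = tc.define(lang, name="test_type")
B = test_type(torch.ones(100).byte().cuda())  # sums the 100 bytes into B(0)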

File tree

8 files changed: +298 additions, −12 deletions


.jenkins/build.sh

Lines changed: 4 additions & 0 deletions

@@ -69,6 +69,10 @@ WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREF
 python setup.py install
 ./test_python/run_test.sh
 
+for f in $(find ./python/examples -name "*.py"); do
+  python $f
+done
+
 FILTER_OUT="benchmark_MLP_model benchmark_kronecker" ./test.sh
 # 2LUT can OOM on smaller Maxwells on our CI machines
 ./build/tc/benchmarks/benchmark_MLP_model --gtest_filter=-*2LUT*

python/examples/min_distance.py

Lines changed: 190 additions & 0 deletions

@@ -0,0 +1,190 @@ (new file)

# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
import tensor_comprehensions as tc
from tensor_comprehensions.tc import set_logtostderr
from tensor_comprehensions.tc import set_debug_tc_mapper
from tensor_comprehensions.tc import set_debug_cuda

import numpy as np
import torch

#
# Example submitted by @mdouze, originally related to uint8 type support
#

debug = False
set_logtostderr(debug)
set_debug_tc_mapper(debug)
set_debug_cuda(debug)

N = 1000
M = 32

codes = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

lang = """
# mindis as a single kernel will require grid synchronization to run efficiently
def mindis(float(M, 256) L, uint8(N, M) Codes) -> (S, v, min_idx) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
    v min=! S(r_n)
    min_idx min=! (S(r_n) == v) ? r_n : N
}

# Even when splitting in 3 kernels, global device reduction will be needed to
# run efficiently
# don't try to run it with large sizes for now
def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
}
def min_2d(float(N) S) -> (v) {
    v min=! S(r_n)
}
def argmin_2d(float(N) S, float v) -> (min_idx) {
    min_idx min=! (S(r_n) == v) ? r_n : N
}
"""

mindis = tc.define(lang, name="mindis")
S, v, min_idx = mindis(luts_t, codes_t)
print("minval: {} minidx: {}".format(v, min_idx))

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
argmin_2d = tc.define(lang, name="argmin_2d")

S = reduce_codes(luts_t, codes_t)
v = min_2d(S)
min_idx = argmin_2d(S, v)

print("minval: {} minidx: {}".format(v, min_idx))

################################################################################
# Each reduction is probably easier to optimize with a 2-staged TC where we
# artificially increase parallelism and finish the reduction in a second kernel.
# Properly choosing D such that N = D * (N / D) should result in a good version
# with 5 kernels total.
################################################################################
N = 10 ** 5  # bump to 10**7 when ready for primetime
D = 1000
assert N % D == 0, "D={} must divide N={}".format(D, N)
M = 32

lang = """
def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
}
def min_2d(float(D, NBYD) S) -> (V) {
    V(d) min=! S(d, r_nbyd)
}
def min_1d(float(D) V) -> (v) {
    v min=! V(r_d)
}
def argmin_2d(float(D, NBYD) S, float v) -> (MinIdx) {
    MinIdx(d) min=! (S(d, r_nbyd) == v) ? d * NBYD + r_nbyd : N
}
def argmin_1d(float(N) S, int32(D) MinIdx) -> (min_idx) {
    min_idx min=! (MinIdx(r_d) < N) ? r_d : N
}
"""

codes = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
min_1d = tc.define(lang, name="min_1d")
argmin_2d = tc.define(lang, name="argmin_2d")
argmin_1d = tc.define(lang, name="argmin_1d")

S = reduce_codes(luts_t, codes_t)
V = min_2d(S.view((D, N // D)))
v = min_1d(V)
MinIdx = argmin_2d(S.view((D, N // D)), v)
min_idx = argmin_1d(S, MinIdx)
print("minval: {} minidx: {}".format(v, min_idx))

################################################################################
# The longer-form version has an extra k dimension we could use for parallelism.
# Unfortunately it is a small dimension (16) so it won't saturate Pascal/Volta.
# So we may want to split in 5 to run efficiently.
################################################################################
N = 10 ** 7
D = 1000
assert N % D == 0, "D={} must divide N={}".format(D, N)
M = 32
K = 16
codes = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(K, M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

lang = """
def mindis(float(K, M, 256) L, uint8(N, M) Codes) -> (S, V, MinIdx) {
    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
    V(k) min=! S(k, r_n)
    MinIdx(k) min=! (S(k, r_n) == V(k)) ? r_n : N
}
"""

debug = False
set_logtostderr(debug)
set_debug_tc_mapper(debug)
set_debug_cuda(debug)

mindis = tc.define(lang, name="mindis")
S, V, MinIdx = mindis(luts_t, codes_t)
print("minvals: {}\nminidxs: {}".format(V, MinIdx))

lang = """
def reduce_codes(float(K, M, 256) L, uint8(N, M) Codes) -> (S) {
    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
}
def min_2d(float(K, D, NBYD) S) -> (V2) {
    V2(k, d) min=! S(k, d, r_nbyd)
}
def min_1d(float(K, D) V2) -> (V) {
    V(k) min=! V2(k, r_d)
}
def argmin_2d(float(K, D, NBYD) S, float(K) V) -> (MinIdx2) {
    MinIdx2(k, d) min=! (S(k, d, r_nbyd) == V(k)) ? d * NBYD + r_nbyd : N
}
def argmin_1d(float(K, N) S, int32(K, D) MinIdx2) -> (MinIdx) {
    MinIdx(k) min=! (MinIdx2(k, r_d) < N) ? r_d : N
}
"""

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
min_1d = tc.define(lang, name="min_1d")
argmin_2d = tc.define(lang, name="argmin_2d")
argmin_1d = tc.define(lang, name="argmin_1d")

S = reduce_codes(luts_t, codes_t)
V2 = min_2d(S.view((K, D, N // D)))
V = min_1d(V2)
MinIdx2 = argmin_2d(S.view((K, D, N // D)), V)
MinIdx = argmin_1d(S, MinIdx2)
print("minvals: {} minidxs: {}".format(V, MinIdx))

tc/core/libraries.h

Lines changed: 11 additions & 0 deletions

@@ -31,9 +31,20 @@ namespace code {
 namespace c {
 
 constexpr auto types = R"C(
+// Can't include system dependencies with NVRTC
+// Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
+// #include <cuda_fp16.h>
+
 // Halide type handling
+typedef char int8;
+typedef short int16;
 typedef int int32;
 typedef long int64;
+typedef unsigned char uint8;
+typedef unsigned short uint16;
+typedef unsigned int uint32;
+typedef unsigned long uint64;
+// typedef half float16;
 typedef float float32;
 typedef double float64;
)C";

tc/core/tc2halide.cc

Lines changed: 7 additions & 0 deletions

@@ -53,10 +53,17 @@ Type translateScalarType(int tcType) {
       return Int(32);
     case lang::TK_INT64:
       return Int(64);
+    case lang::TK_FLOAT16:
+      return Float(16);
+    case lang::TK_FLOAT32:
+      return Float(32);
+    case lang::TK_FLOAT64:
+      return Float(64);
     case lang::TK_FLOAT:
       return Float(32);
     case lang::TK_DOUBLE:
       return Float(64);
+
     default:
       LOG(FATAL) << "Unhandled TC scalar type: " << tcType << '\n';
       return Type();
tc/lang/lexer.h

Lines changed: 19 additions & 12 deletions

@@ -41,8 +41,6 @@ namespace lang {
   _(TK_MIN, "min", "min") \
   _(TK_MAX, "max", "max") \
   _(TK_WHERE, "where", "where") \
-  _(TK_FLOAT, "float", "float") \
-  _(TK_DOUBLE, "double", "double") \
   _(TK_DEF, "def", "def") \
   _(TK_ARROW, "arrow", "->") \
   _(TK_EQUIVALENT, "equivalent", "<=>") \
@@ -67,15 +65,21 @@ namespace lang {
   _(TK_TIMES_EQ_B, "times_eq_b", "*=!") \
   _(TK_MIN_EQ_B, "min_eq_b", "min=!") \
   _(TK_MAX_EQ_B, "max_eq_b", "max=!") \
-  _(TK_INT8, "int8", "int8") \
-  _(TK_INT16, "int16", "int16") \
-  _(TK_INT32, "int32", "int32") \
-  _(TK_INT64, "int64", "int64") \
+  \
+  _(TK_BOOL, "bool", "bool") \
   _(TK_UINT8, "uint8", "uint8") \
   _(TK_UINT16, "uint16", "uint16") \
   _(TK_UINT32, "uint32", "uint32") \
   _(TK_UINT64, "uint64", "uint64") \
-  _(TK_BOOL, "bool", "bool") \
+  _(TK_INT8, "int8", "int8") \
+  _(TK_INT16, "int16", "int16") \
+  _(TK_INT32, "int32", "int32") \
+  _(TK_INT64, "int64", "int64") \
+  _(TK_FLOAT16, "float16", "float16") \
+  _(TK_FLOAT32, "float32", "float32") \
+  _(TK_FLOAT64, "float64", "float64") \
+  _(TK_FLOAT, "float", "float") \
+  _(TK_DOUBLE, "double", "double") \
   _(TK_CAST, "cast", "") \
   _(TK_IN, "in", "in") \
   _(TK_GE, "ge", ">=") \
@@ -271,15 +275,18 @@ struct SharedParserData {
   }
   bool isScalarType(int kind) {
     switch (kind) {
-      case TK_INT8:
-      case TK_INT16:
-      case TK_INT32:
-      case TK_INT64:
+      case TK_BOOL:
       case TK_UINT8:
       case TK_UINT16:
       case TK_UINT32:
       case TK_UINT64:
-      case TK_BOOL:
+      case TK_INT8:
+      case TK_INT16:
+      case TK_INT32:
+      case TK_INT64:
+      case TK_FLOAT16:
+      case TK_FLOAT32:
+      case TK_FLOAT64:
       case TK_FLOAT:
       case TK_DOUBLE:
         return true;
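
The reordering above is presumably not cosmetic: if keyword candidates are tried in list order with prefix-style matching, `float` must come after `float16`/`float32`/`float64`, or the shorter keyword wins and leaves a stray `16` behind. A toy illustration of that pitfall (this is not the TC lexer, just the matching constraint the list order has to respect):

# Toy illustration: first-match-wins keyword lookup over an ordered list.
KEYWORDS = ["float16", "float32", "float64", "float", "double"]

def match_keyword(src):
    for kw in KEYWORDS:  # order matters: longer names must precede "float"
        if src.startswith(kw):
            return kw
    return None

print(match_keyword("float16(N) A"))  # "float16"; with "float" first it would stop early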

tc/lang/sema.h

Lines changed: 13 additions & 0 deletions

@@ -46,13 +46,23 @@ struct TypeInfo {
       TYPE_INFO_OPTION(TK_INT16, Int, 16)
       TYPE_INFO_OPTION(TK_INT32, Int, 32)
       TYPE_INFO_OPTION(TK_INT64, Int, 64)
+      TYPE_INFO_OPTION(TK_FLOAT16, Float, 16)
+      TYPE_INFO_OPTION(TK_FLOAT32, Float, 32)
+      TYPE_INFO_OPTION(TK_FLOAT64, Float, 64)
       TYPE_INFO_OPTION(TK_FLOAT, Float, 32)
       TYPE_INFO_OPTION(TK_DOUBLE, Float, 64)
+
 #undef TYPE_INFO_OPTION
       default:
         throw ErrorReport(scalar_type)
             << "Unhandled TC scalar type: " << scalar_type;
     }
+
+    if (code_ == Code::Float && bits_ == 16) {
+      throw ErrorReport(scalar_type)
+          << "Half precision floating point not supported "
+          << "until we can make NVRTC include system headers";
+    }
   }
   int toScalarToken() const {
     switch (code()) {
@@ -82,12 +92,15 @@ struct TypeInfo {
       }
       case Float:
         switch (bits()) {
+          case 16:
+            return TK_FLOAT16;
          case 32:
            return TK_FLOAT;
          case 64:
            return TK_DOUBLE;
        }
    }
+
    throw std::runtime_error("Unknown type info?");
  }
  Code code() const {
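
Net effect: `float16` now lexes, parses, and carries a Float(16) TypeInfo, but the TypeInfo constructor rejects it immediately, so users get a clear error instead of an opaque NVRTC failure. A hedged sketch of what a caller would observe from Python (assuming the ErrorReport surfaces as an ordinary exception when the TC is defined or first compiled; the exact exception type is not shown in this diff):

import tensor_comprehensions as tc

lang = """
def copy16(float16(N) A) -> (B) { B(i) = A(i) }
"""
try:
    copy16 = tc.define(lang, name="copy16")
except Exception as e:
    # Expected message, per sema.h above: "Half precision floating point
    # not supported until we can make NVRTC include system headers"
    print(e)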

test/cuda/test_compile_and_run.cc

Lines changed: 31 additions & 0 deletions

@@ -275,6 +275,37 @@ def cast(float(M,N) A, int32 four) -> (int32(M,N) output) {
   TC_CHECK_EQ(r, 0);
 }
 
+TEST_F(CompilationTest, Types) {
+  struct TypeMatch {
+    std::string s;
+    at::ScalarType a;
+  };
+  for (auto type :
+       {// TypeMatch{"bool", at::ScalarType::Bool}, // no aten version
+        TypeMatch{"uint8", at::ScalarType::Byte},
+        // TypeMatch{"uint16", at::ScalarType::Short}, // no aten version
+        // TypeMatch{"uint32", at::ScalarType::Int}, // no aten version
+        // TypeMatch{"uint64", at::ScalarType::Long}, // no aten version
+        TypeMatch{"int8", at::ScalarType::Char},
+        TypeMatch{"int16", at::ScalarType::Short},
+        TypeMatch{"int32", at::ScalarType::Int},
+        TypeMatch{"int64", at::ScalarType::Long},
+        // NVRTC include transitive dependencies issue
+        // TypeMatch{"float16", at::ScalarType::Half},
+        TypeMatch{"float32", at::ScalarType::Float},
+        TypeMatch{"float64", at::ScalarType::Double},
+        TypeMatch{"float", at::ScalarType::Float},
+        TypeMatch{"double", at::ScalarType::Double}}) {
+    std::string tc = std::string("def test_type(") + std::string(type.s) +
+        std::string("(N) A) -> (B) { B(k) +=! A(i) where k in 0:1 }");
+    std::vector<at::Tensor> outputs = Check(
+        tc,
+        "test_type",
+        tc::CudaMappingOptions::makeNaiveMappingOptions(),
+        {at::CUDA(type.a).ones({100})});
+  }
+}
+
 int main(int argc, char** argv) {
   ::testing::InitGoogleTest(&argc, argv);
   ::gflags::ParseCommandLineFlags(&argc, &argv, true);