This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit c20658a

Add min_distances.py
This commit adds examples provided by @mdouze where the argmin over a reduced sum is required. These examples are now functional thanks to the previous commit, but extra work is needed to make some of the variants perform reasonably:
1. for the fused kernel to parallelize properly across blocks we need grid synchronization; this may be a nice concrete use case for @math-fehr.
2. for the 1-stage fissioned implementation we need device-wide synchronization, otherwise we will always be limited by running on a single SM.
3. the 2-stage fissioned implementations can give us performance today after tuning.

Without tuning, the results on the larger size (1e7, 32, 16) are shown [here](https://gist.github.com/nicolasvasilache/8a0addfb6831a831b2dca45c612f9c2d). `mindis_16_32_10000000` is the totally fused kernel and performs very poorly. The following 5 kernels correspond to the final use case of interest.
1 parent aa2fe9e commit c20658a
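
For context, every variant below computes an argmin over a reduced sum: S(n) = sum_m L(m, Codes(n, m)), where each column m has its own 256-entry lookup table, then v = min_n S(n) and min_idx = argmin_n S(n). A minimal NumPy sketch of that semantics (illustrative only; the shapes mirror the example file):

```python
import numpy as np

N, M = 1000, 32
luts = np.random.randn(M, 256).astype('float32')               # one 256-entry LUT per column m
codes = np.random.randint(0, 256, size=(N, M)).astype('uint8')  # uint8 codes to look up

S = luts[np.arange(M), codes].sum(axis=1)  # S[n] = sum_m luts[m, codes[n, m]]
v, min_idx = S.min(), S.argmin()           # argmin over the reduced sum
```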

File tree: 2 files changed (+194, -0)

.jenkins/build.sh

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ WITH_CAFFE2=ON CUDA_TOOLKIT_ROOT_DIR=/usr/local/cuda CLANG_PREFIX=$(${CONDA_PREF
 python setup.py install
 ./test_python/run_test.sh
 
+for f in $(find ./python/examples -name "*.py"); do
+    python $f
+done
+
 FILTER_OUT="benchmark_MLP_model benchmark_kronecker" ./test.sh
 # 2LUT can OOM on smaller Maxwells on our CI machines
 ./build/tc/benchmarks/benchmark_MLP_model --gtest_filter=-*2LUT*

python/examples/min_distance.py

Lines changed: 190 additions & 0 deletions
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
import tensor_comprehensions as tc
from tensor_comprehensions.tc import set_logtostderr
from tensor_comprehensions.tc import set_debug_tc_mapper
from tensor_comprehensions.tc import set_debug_cuda

import numpy as np
import torch

#
## Example submitted by @mdouze, originally related to uint8 type support
#

debug = False
set_logtostderr(debug)
set_debug_tc_mapper(debug)
set_debug_cuda(debug)

N = 1000
M = 32

codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

lang = """
# mindis as a single kernel will require grid synchronization to run efficiently
def mindis(float(M, 256) L, uint8(N, M) Codes) -> (S, v, min_idx) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
    v min=! S(r_n)
    min_idx min=! (S(r_n) == v) ? r_n : N
}

# Even when splitting in 3 kernels, global device reduction will be needed to
# run efficiently
# don't try to run it with large sizes for now
def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
}
def min_2d(float(N) S) -> (v) {
    v min=! S(r_n)
}
def argmin_2d(float(N) S, float v) -> (min_idx) {
    min_idx min=! (S(r_n) == v) ? r_n : N
}
"""

mindis = tc.define(lang, name="mindis")
S, v, min_idx = mindis(luts_t, codes_t)
print("minval: {} minidx: {}".format(v, min_idx))

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
argmin_2d = tc.define(lang, name="argmin_2d")

S = reduce_codes(luts_t, codes_t)
v = min_2d(S)
min_idx = argmin_2d(S, v)

print("minval: {} minidx: {}".format(v, min_idx))

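# As a readability aid, a reference sketch of the same computation with plain
# PyTorch ops (gather + sum + min); purely illustrative, the TC kernels above
# remain the versions of interest.
idx = codes_t.long().t()              # (M, N) gather indices
S_ref = luts_t.gather(1, idx).sum(0)  # S_ref[n] = sum_m L[m, Codes[n, m]]
v_ref, min_idx_ref = S_ref.min(0)
print("reference minval: {} minidx: {}".format(v_ref, min_idx_ref))
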
################################################################################
# Each reduction is probably easier to optimize with a 2-staged TC where we
# artificially increase parallelism and finish the reduction in a second kernel.
# Properly choosing D such that N = D * (N / D) should result in a good version
# with 5 kernels total.
################################################################################
N = 10 ** 5  # bump to 10**7 when ready for primetime
D = 1000
assert N % D == 0, "D={} must divide N={}".format(D, N)
M = 32

lang = """
def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
}
def min_2d(float(D, NBYD) S) -> (V) {
    V(d) min=! S(d, r_nbyd)
}
def min_1d(float(D) V) -> (v) {
    v min=! V(r_d)
}
def argmin_2d(float(D, NBYD) S, float v) -> (MinIdx) {
    MinIdx(d) min=! (S(d, r_nbyd) == v) ? d * NBYD + r_nbyd : N
}
def argmin_1d(float(N) S, int32(D) MinIdx) -> (min_idx) {
    min_idx min=! (MinIdx(r_d) < N) ? r_d : N
}
"""

codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
min_1d = tc.define(lang, name="min_1d")
argmin_2d = tc.define(lang, name="argmin_2d")
argmin_1d = tc.define(lang, name="argmin_1d")

S = reduce_codes(luts_t, codes_t)
V = min_2d(S.view((D, N // D)))
v = min_1d(V)
MinIdx = argmin_2d(S.view((D, N // D)), v)
min_idx = argmin_1d(S, MinIdx)
print("minval: {} minidx: {}".format(v, min_idx))

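# Illustrative sketch of the same D-blocking with plain PyTorch ops: reduce each
# of the D blocks of size N // D, then finish over the D partial results. This
# mirrors the 2-staged TC above and is only a reference, not part of the pipeline.
S_blocked = S.view(D, N // D)
V_ref, block_arg = S_blocked.min(1)   # per-block min and argmin
v_ref, d_ref = V_ref.min(0)           # global min over the D blocks
min_idx_ref = d_ref * (N // D) + block_arg[d_ref]
print("blocked reference minval: {} minidx: {}".format(v_ref, min_idx_ref))
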
################################################################################
# The longer form version has an extra k dimension we could use for parallelism.
# Unfortunately it is a small dimension (16) so it won't saturate Pascal/Volta.
# So we may want to split in 5 to run efficiently.
################################################################################
N = 10 ** 7
D = 1000
assert N % D == 0, "D={} must divide N={}".format(D, N)
M = 32
K = 16
codes = np.random.randint(1<<32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(K, M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

lang = """
def mindis(float(K, M, 256) L, uint8(N, M) Codes) -> (S, V, MinIdx) {
    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
    V(k) min=! S(k, r_n)
    MinIdx(k) min=! (S(k, r_n) == V(k)) ? r_n : N
}
"""

debug = False
set_logtostderr(debug)
set_debug_tc_mapper(debug)
set_debug_cuda(debug)

mindis = tc.define(lang, name="mindis")
S, V, MinIdx = mindis(luts_t, codes_t)
print("minvals: {}\nminidxs: {}".format(V, MinIdx))

lang = """
def reduce_codes(float(K, M, 256) L, uint8(N, M) Codes) -> (S) {
    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
}
def min_2d(float(K, D, NBYD) S) -> (V2) {
    V2(k, d) min=! S(k, d, r_nbyd)
}
def min_1d(float(K, D) V2) -> (V) {
    V(k) min=! V2(k, r_d)
}
def argmin_2d(float(K, D, NBYD) S, float(K) V) -> (MinIdx2) {
    MinIdx2(k, d) min=! (S(k, d, r_nbyd) == V(k)) ? d * NBYD + r_nbyd : N
}
def argmin_1d(float(K, N) S, int32(K, D) MinIdx2) -> (MinIdx) {
    MinIdx(k) min=! (MinIdx2(k, r_d) < N) ? r_d : N
}
"""

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
min_1d = tc.define(lang, name="min_1d")
argmin_2d = tc.define(lang, name="argmin_2d")
argmin_1d = tc.define(lang, name="argmin_1d")

S = reduce_codes(luts_t, codes_t)
V2 = min_2d(S.view((K, D, N // D)))
V = min_1d(V2)
MinIdx2 = argmin_2d(S.view((K, D, N // D)), V)
MinIdx = argmin_1d(S, MinIdx2)
print("minvals: {}\nminidxs: {}".format(V, MinIdx))
