# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
import tensor_comprehensions as tc
from tensor_comprehensions.tc import set_logtostderr
from tensor_comprehensions.tc import set_debug_tc_mapper
from tensor_comprehensions.tc import set_debug_cuda

import numpy as np
import torch

#
## Example submitted by @mdouze, originally related to uint8 type support
#

debug = False
set_logtostderr(debug)
set_debug_tc_mapper(debug)
set_debug_cuda(debug)

N = 1000
M = 32

codes = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
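# Note (added for clarity, not in the original example): reinterpreting the
# uint32 buffer as uint8 splits every 4-byte word into 4 bytes, so the shape
# grows from (N, M // 4) to (N, M) without copying any data.
assert codes.shape == (N, M)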
luts = np.random.randn(M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

lang = """
# mindis as a single kernel will require grid synchronization to run efficiently
def mindis(float(M, 256) L, uint8(N, M) Codes) -> (S, v, min_idx) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
    v min=! S(r_n)
    min_idx min=! (S(r_n) == v) ? r_n : N
}

# Even when split into 3 kernels, a global device reduction will be needed to
# run efficiently.
# Don't try to run it with large sizes for now.
def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
}
def min_2d(float(N) S) -> (v) {
    v min=! S(r_n)
}
def argmin_2d(float(N) S, float v) -> (min_idx) {
    min_idx min=! (S(r_n) == v) ? r_n : N
}
"""

mindis = tc.define(lang, name="mindis")
S, v, min_idx = mindis(luts_t, codes_t)
print("minval: {} minidx: {}".format(v, min_idx))
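
# Sanity check (added, not part of the original example): recompute the same
# table lookup + reduction with NumPy on the host and print it next to the TC
# result for a by-eye comparison. Ties in the argmin may legitimately resolve
# to a different index than the TC kernel picks.
S_ref = luts[np.arange(M), codes.astype(np.int64)].sum(axis=1)
print("numpy reference: minval: {} minidx: {}".format(S_ref.min(), S_ref.argmin()))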

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
argmin_2d = tc.define(lang, name="argmin_2d")

S = reduce_codes(luts_t, codes_t)
v = min_2d(S)
min_idx = argmin_2d(S, v)

print("minval: {} minidx: {}".format(v, min_idx))

################################################################################
# Each reduction is probably easier to optimize with a two-stage TC where we
# artificially increase parallelism and finish the reduction in a second kernel.
# Properly choosing D such that N = D * (N / D) should result in a good version
# with 5 kernels total.
################################################################################
N = 10 ** 5  # bump to 10**7 when ready for primetime
D = 1000
assert N % D == 0, "D={} must divide N={}".format(D, N)
M = 32

lang = """
def reduce_codes(float(M, 256) L, uint8(N, M) Codes) -> (S) {
    S(n) +=! L(r_m, int32(Codes(n, r_m)))
}
def min_2d(float(D, NBYD) S) -> (V) {
    V(d) min=! S(d, r_nbyd)
}
def min_1d(float(D) V) -> (v) {
    v min=! V(r_d)
}
def argmin_2d(float(D, NBYD) S, float v) -> (MinIdx) {
    MinIdx(d) min=! (S(d, r_nbyd) == v) ? d * NBYD + r_nbyd : N
}
def argmin_1d(float(N) S, int32(D) MinIdx) -> (min_idx) {
    min_idx min=! (MinIdx(r_d) < N) ? MinIdx(r_d) : N
}
"""

codes = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
min_1d = tc.define(lang, name="min_1d")
argmin_2d = tc.define(lang, name="argmin_2d")
argmin_1d = tc.define(lang, name="argmin_1d")

S = reduce_codes(luts_t, codes_t)
V = min_2d(S.view((D, N // D)))
v = min_1d(V)
MinIdx = argmin_2d(S.view((D, N // D)), v)
min_idx = argmin_1d(S, MinIdx)
print("minval: {} minidx: {}".format(v, min_idx))
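
# Illustration (added, not in the original example): the blocked two-stage
# reduction above only re-associates the min, i.e.
#   min_n S(n) == min_d min_j S(d * (N // D) + j),
# so it must agree exactly with a flat host-side min over the same values.
S_host = S.cpu().numpy().reshape(D, N // D)
assert S_host.min(axis=1).min() == S_host.min()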

################################################################################
# The longer-form version has an extra k dimension we could use for parallelism.
# Unfortunately it is a small dimension (16), so it won't saturate Pascal/Volta
# and we may want to split into 5 kernels to run efficiently.
################################################################################
N = 10 ** 7
D = 1000
assert N % D == 0, "D={} must divide N={}".format(D, N)
M = 32
K = 16
codes = np.random.randint(1 << 32, size=(N, M // 4)).astype('uint32')
codes = codes.view('uint8')
luts = np.random.randn(K, M, 256).astype('float32')

codes_t = torch.from_numpy(codes).cuda()
luts_t = torch.from_numpy(luts).cuda()

lang = """
def mindis(float(K, M, 256) L, uint8(N, M) Codes) -> (S, V, MinIdx) {
    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
    V(k) min=! S(k, r_n)
    MinIdx(k) min=! (S(k, r_n) == V(k)) ? r_n : N
}
"""

debug = False
set_logtostderr(debug)
set_debug_tc_mapper(debug)
set_debug_cuda(debug)

mindis = tc.define(lang, name="mindis")
S, V, MinIdx = mindis(luts_t, codes_t)
print("minvals: {}\nminidxs: {}".format(V, MinIdx))
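
# Sanity check (added, not part of the original example): verify a small slice
# of the fused K-way kernel against NumPy. The full (K, N, M) gather would need
# tens of GB of host memory at N = 10**7, so only the first few thousand codes
# are compared; the 4096 below is an arbitrary choice.
n_check = 4096
S_ref = luts[:, np.arange(M), codes[:n_check].astype(np.int64)].sum(axis=-1)
assert np.allclose(S_ref, S[:, :n_check].cpu().numpy(), atol=1e-3)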

lang = """
def reduce_codes(float(K, M, 256) L, uint8(N, M) Codes) -> (S) {
    S(k, n) +=! L(k, r_m, int32(Codes(n, r_m)))
}
def min_2d(float(K, D, NBYD) S) -> (V2) {
    V2(k, d) min=! S(k, d, r_nbyd)
}
def min_1d(float(K, D) V2) -> (V) {
    V(k) min=! V2(k, r_d)
}
def argmin_2d(float(K, D, NBYD) S, float(K) V) -> (MinIdx2) {
    MinIdx2(k, d) min=! (S(k, d, r_nbyd) == V(k)) ? d * NBYD + r_nbyd : N
}
def argmin_1d(float(K, N) S, int32(K, D) MinIdx2) -> (MinIdx) {
    MinIdx(k) min=! (MinIdx2(k, r_d) < N) ? MinIdx2(k, r_d) : N
}
"""

reduce_codes = tc.define(lang, name="reduce_codes")
min_2d = tc.define(lang, name="min_2d")
min_1d = tc.define(lang, name="min_1d")
argmin_2d = tc.define(lang, name="argmin_2d")
argmin_1d = tc.define(lang, name="argmin_1d")

S = reduce_codes(luts_t, codes_t)
V2 = min_2d(S.view((K, D, N // D)))
V = min_1d(V2)
MinIdx2 = argmin_2d(S.view((K, D, N // D)), V)
MinIdx = argmin_1d(S, MinIdx2)
print("minvals: {}\nminidxs: {}".format(V, MinIdx))
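
# Consistency check (added, not part of the original example; assumes a torch
# version that provides torch.allclose): the per-k minima from the split
# pipeline must match a single-pass torch reduction over the same S produced by
# reduce_codes. Indices are not compared because ties may resolve differently.
assert torch.allclose(V, S.min(dim=1)[0])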