|
| 1 | +# Copyright (c) 2017-present, Facebook, Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +############################################################################## |
| 15 | +import tensor_comprehensions as tc |
| 16 | + |
| 17 | +import argparse |
| 18 | +import numpy as np |
| 19 | +import torch |
| 20 | + |
def GetArgumentParser():
    """Build the command-line parser for the Lengths Cosine Coherence benchmark.

    The flags fall in two groups: the problem definition (number and length
    of segments, embedding width, positive/negative window sizes) and the
    autotuner settings (threads, generations, population, elites, devices
    and cache file).

    Returns:
        argparse.ArgumentParser: the configured parser; nothing is parsed
        here.
    """
    parser = argparse.ArgumentParser(
        description='Lengths Cosine Coherence benchmark.'
    )

    # Problem definition.
    parser.add_argument(
        '--num_segs', type=int, default=4, help='The number of segments.'
    )
    parser.add_argument(
        '--seg_length', type=int, default=100,
        help='The length of each segment.'
    )
    parser.add_argument(
        '--num_of_channels', type=int, default=128,
        help='The dimension of embeddings.'
    )
    parser.add_argument(
        '--pos_dist', type=int, default=0, help='The positive window size.'
    )
    parser.add_argument(
        '--neg_dist', type=int, default=0, help='The negative window size.'
    )

    # Autotuner settings.
    parser.add_argument(
        '--tuner_threads', type=int, default=16,
        help='Number of CPU tuning threads.'
    )
    parser.add_argument(
        '--tuner_generations', type=int, default=25,
        help='Number of tuning generations.'
    )
    parser.add_argument(
        '--tuner_pop_size', type=int, default=100,
        help='Number of candidates per tuning generation.'
    )
    parser.add_argument(
        '--tuner_number_elites', type=int, default=5,
        help='Number of best tuning candidates that survive each generation.'
    )
    parser.add_argument(
        '--tuner_devices', type=str, default='0',
        help='Comma separated list of tuning devices.'
    )
    parser.add_argument(
        '--tuner_cache_file',
        type=str,
        default='/tmp/cache_tum',
        help='File to store tuned mapping options',
    )
    return parser
| 63 | + |
| 64 | + |
# Parse the flags at import time: the NumPy reference and the TC template
# further down both read ``args``. ``parse_known_args`` keeps unrecognised
# flags in ``extra_args`` instead of aborting on them.
parser = GetArgumentParser()
args, extra_args = parser.parse_known_args()
| 67 | + |
###############################################################################
# Reference python impl
###############################################################################
def reference(D, L, pos_dist=None, neg_dist=None):
    """Compute the lengths-cosine-coherence reference result with NumPy.

    The rows of ``D`` are consumed consecutively, ``L[g]`` rows per segment.
    Within a segment every ordered pair ``i < j`` contributes its cosine
    similarity to the positive average when ``pos_dist == 0`` or
    ``j - i <= pos_dist``, and to the negative average when ``neg_dist > 0``
    and ``j - i >= neg_dist``. The two tests are independent, so a pair may
    count for both. Segments of length 0 or 1 have no pairs and yield 0.

    Args:
        D: 2-D NumPy array; each row is one vector.
        L: 1-D NumPy integer array of segment lengths.
        pos_dist: positive window size; defaults to ``args.pos_dist``.
        neg_dist: negative window size; defaults to ``args.neg_dist``.

    Returns:
        list: ``[R, Normed_DATA, Norm_of_Vector, POS_C, NEG_C]`` where ``R``
        holds, per segment, the mean positive similarity minus the mean
        negative similarity, ``Normed_DATA`` the rows of ``D`` divided by
        their (epsilon-clamped) norm, ``Norm_of_Vector`` the squared norm of
        each row, and ``POS_C`` / ``NEG_C`` the number of positive / negative
        pairs per segment.
    """
    # Fall back to the command-line values only when the caller did not
    # pass explicit windows (0 is a meaningful value, hence the None test).
    if pos_dist is None:
        pos_dist = args.pos_dist
    if neg_dist is None:
        neg_dist = args.neg_dist

    num_segs = L.size
    R = np.zeros(shape=(num_segs,), dtype=D.dtype)
    Normed_DATA = D * 0
    Norm_of_Vector = np.zeros(shape=(D.shape[0],), dtype=D.dtype)
    # np.long is not available on every NumPy release; int64 is explicit.
    POS_C = np.zeros(shape=(num_segs,), dtype=np.int64)
    NEG_C = np.zeros(shape=(num_segs,), dtype=np.int64)
    kEps = 1e-12  # lower clamp on squared norms to avoid dividing by zero

    def dot(a, b):
        return np.sum(a * b)

    for i in range(D.shape[0]):
        Norm_of_Vector[i] = dot(D[i], D[i])
        Normed_DATA[i] = D[i] / np.sqrt(max(Norm_of_Vector[i], kEps))

    line = 0  # index in D of the first row of the current segment
    for g in range(num_segs):
        seg_len = int(L[g])
        if seg_len > 1:
            pos_res = 0
            neg_res = 0
            for i in range(seg_len - 1):
                for j in range(i + 1, seg_len):
                    sqrt_norm = np.sqrt(
                        max(Norm_of_Vector[line + i], kEps)
                        * max(Norm_of_Vector[line + j], kEps)
                    )
                    if pos_dist == 0 or j - i <= pos_dist:
                        pos_res += dot(D[line + i], D[line + j]) / sqrt_norm
                        POS_C[g] += 1
                    if neg_dist > 0 and j - i >= neg_dist:
                        neg_res += dot(D[line + i], D[line + j]) / sqrt_norm
                        NEG_C[g] += 1
            # Average each side, guarding against an empty set of pairs.
            pos_res = 0 if POS_C[g] < 1 else pos_res / POS_C[g]
            neg_res = 0 if NEG_C[g] < 1 else neg_res / NEG_C[g]
            R[g] = pos_res - neg_res
        line += seg_len
    return [R, Normed_DATA, Norm_of_Vector, POS_C, NEG_C]
| 110 | + |
| 111 | +############################################################################### |
| 112 | +# TC equivalent converting control-flow to data dependencies |
| 113 | +############################################################################### |
| 114 | +LENGTHS_COSINE_COHERENCE = ''' |
| 115 | +# TODO: this is just a scan but currently implemented as K reductions |
| 116 | +def make_idx(int32(K) Segments) -> (Idx) { |
| 117 | + Idx(k) +=! (r_k < k) ? Segments(r_k) : 0 where k in 0:K+1 |
| 118 | +} |
| 119 | +def make_alpha(int32(KP1) Idx, int32(MAX_L) SegmentsMetaData) -> (Alpha) { |
| 120 | + # Triangular compute |
| 121 | + Alpha(k, max_l_1, max_l_2) = (max_l_1 >= max_l_2) ? 0.0 : |
| 122 | + # This computes an approximation using the maximal segment length |
| 123 | + ((<pos_dist> == 0 || fabs(float(max_l_1 - max_l_2)) <= float(<pos_dist>)) ? 1.0 : |
| 124 | + (<neg_dist> == 0 && fabs(float(max_l_1 - max_l_2)) >= float(<neg_dist>)) ? -1.0 : 0.0) |
| 125 | + * |
| 126 | + # Filter against the true value of Idx |
| 127 | + ((Idx(k) + max_l_1 < Idx(k + 1) && Idx(k) + max_l_2 < Idx(k + 1)) |
| 128 | + ? 1.0 : 0.0) |
| 129 | + where k in 0:KP1-1, max_l_1 in 0:MAX_L, max_l_2 in 0:MAX_L |
| 130 | +} |
| 131 | +def make_counts(float(K, MAX_L, MAX_L) Alpha, int32(MAX_L) SegmentsMetaData) |
| 132 | + -> (PosCount, NegCount, tmpPos, tmpNeg) |
| 133 | +{ |
| 134 | + # Triangular compute |
| 135 | + # tmp is necessary for reasonable performance with the current mapper |
| 136 | + # because we don't yet support 2-D reductions |
| 137 | + # Note that in practice, tmp also gives strictly more parallelism and |
| 138 | + # allows exploiting 2 levels or thread parallelism (doall and reduction) or |
| 139 | + # (in the future when block syncrhonization is supported) 1 level of block |
| 140 | + # parallelism without cross-block reductions. |
| 141 | + tmpPos(k, max_l_1) +=! (max_l_1 >= r_max_l_2) ? 0.0 : |
| 142 | + (Alpha(k, max_l_1, r_max_l_2) > 0.0) ? 1.0 : 0.0 |
| 143 | + # TODO: annotation should not be needed |
| 144 | + where k in 0:K, max_l_1 in 0:MAX_L, r_max_l_2 in 0:MAX_L |
| 145 | + PosCount(k) +=! tmpPos(k, max_l_1) |
| 146 | + # TODO: annotation should not be needed |
| 147 | + where k in 0:K, max_l_1 in 0:MAX_L |
| 148 | +
|
| 149 | + # Triangular compute |
| 150 | + # tmp is necessary because we don't yet support 2-D reductions |
| 151 | + # But in practice, tmp also gives strictly more parallelism and allows |
| 152 | + # exploiting blocks without cross-block reductions |
| 153 | + tmpNeg(k, max_l_1) +=! (max_l_1 >= r_max_l_2) ? 0.0 : |
| 154 | + (Alpha(k, max_l_1, r_max_l_2) < 0.0) ? 1.0 : 0.0 |
| 155 | + # TODO: annotation should not be needed |
| 156 | + where k in 0:K, max_l_1 in 0:MAX_L, r_max_l_2 in 0:MAX_L |
| 157 | + NegCount(k) +=! tmpNeg(k, max_l_1) |
| 158 | + # TODO: annotation should not be needed |
| 159 | + where k in 0:K, max_l_1 in 0:MAX_L |
| 160 | +} |
| 161 | +def make_beta(float(K, MAX_L, MAX_L) Alpha, |
| 162 | + float(K) PosCount, |
| 163 | + float(K) NegCount, |
| 164 | + int32(MAX_L) SegmentsMetaData) -> (Beta) |
| 165 | +{ |
| 166 | + Beta(k, max_l_1, max_l_2) = (max_l_1 >= max_l_2) ? 0.0 : |
| 167 | + (Alpha(k, max_l_1, max_l_2) == 1.0) ? 1.0 / float(PosCount(k)) : |
| 168 | + (Alpha(k, max_l_1, max_l_2) == -1.0) ? -1.0 / float(NegCount(k)) : |
| 169 | + 0.0 |
| 170 | + # TODO: annotation should not be needed |
| 171 | + where k in 0:K, max_l_1 in 0:MAX_L, max_l_2 in 0:MAX_L |
| 172 | +} |
| 173 | +
|
| 174 | +def normalize(float(N, C) Input) -> (NormData, Square) { |
| 175 | + Square(n) +=! Input(n, r_c) * Input(n, r_c) |
| 176 | + NormData(n, c) = Input(n, c) / sqrt(Square(n) + 1e-12) |
| 177 | +} |
| 178 | +def dots(float(N, C) NormData, int32(KP1) Idx, float(K, MAX_L, MAX_L) Beta) -> (Dots) { |
| 179 | + # Triangular compute |
| 180 | + Dots(k, max_l_1, max_l_2) +=! (max_l_1 >= max_l_2) ? 0.0 : |
| 181 | + # Avoid out of bounds Idx computations |
| 182 | + ((Idx(k) + max_l_1 >= Idx(k + 1) || Idx(k) + max_l_2 >= Idx(k + 1)) ? |
| 183 | + 0.0 : |
| 184 | + NormData(Idx(k) + max_l_1, r_c) * NormData(Idx(k) + max_l_2, r_c)) |
| 185 | + # TODO: annotation should not be needed |
| 186 | + where k in 0:K, max_l_1 in 0:MAX_L, max_l_2 in 0:MAX_L |
| 187 | +} |
| 188 | +def result(float(K, MAX_L, MAX_L) Beta, float(K, MAX_L, MAX_L) Dots) -> (O, tmpO) { |
| 189 | + # Triangular compute |
| 190 | + # tmp is necessary because we don't yet support 2-D reductions |
| 191 | + # But in practice, tmp also gives strictly more parallelism and allows |
| 192 | + # exploiting blocks without cross-block reductions |
| 193 | + tmpO(k, max_l_1) +=! (max_l_1 >= r_max_l_2) ? 0.0 : |
| 194 | + Dots(k, max_l_1, r_max_l_2) * Beta(k, max_l_1, r_max_l_2) |
| 195 | + O(k) +=! tmpO(k, max_l_1) |
| 196 | +} |
| 197 | +''' |
| 198 | + |
###############################################################################
# Implicit compilation and tuning behavior
###############################################################################
# Autotuner settings, taken one-to-one from the --tuner_* command-line flags.
tuner_config = (
    tc.TunerConfig()
    .threads(args.tuner_threads)
    .generations(args.tuner_generations)
    .pop_size(args.tuner_pop_size)
    .number_elites(args.tuner_number_elites)
    .devices(args.tuner_devices)
)
| 209 | + |
# This function is used for reinforcing tuning
#   1. make_idx is small and does not get tuned or saved, just using naive
#      options on it is fine;
#   2. if we find an option in the cache, use it either as is or as starting
#      point for reinforcement, depending on whether the entry_point is in the
#      reinforcement list;
#   3. dots will benefit from being reinforced a few times (reaching 90us on
#      P100)
# Entry points listed here are re-tuned even when the cache already has
# options for them; '*' selects every entry point.
reinforce_list = ['']


def generate_options(tc_str: str,
                     entry_point: str,
                     *inputs: torch.Tensor) -> tc.MappingOptions:
    """Choose the mapping options used to compile one TC entry point.

    As written, the function always returns naive options because of the
    early return kept for CI. With that line disabled the policy is:
    ``make_idx`` gets naive options; otherwise options are loaded from
    ``args.tuner_cache_file`` and returned as is, unless nothing was found or
    the entry point is selected by ``reinforce_list``, in which case the
    autotuner runs (starting from the loaded options, or from 'naive') and
    stores its result back to the cache file.

    Args:
        tc_str: the TC source containing ``entry_point``.
        entry_point: name of the TC function to get options for.
        *inputs: the tensors the entry point will be called with.

    Returns:
        The mapping options produced by the selected options factory.
    """
    # TODO: comment out the line below, which only serves the purpose of not
    # blowing up CI time; everything after it is unreachable until then.
    return tc.make_naive_options_factory()(tc_str, entry_point, *inputs)

    if entry_point == 'make_idx':
        return tc.make_naive_options_factory()(tc_str, entry_point, *inputs)

    loaded = tc.make_load_from_cache_options_factory(args.tuner_cache_file)(
        tc_str, entry_point, *inputs)

    if loaded is None or entry_point in reinforce_list or '*' in reinforce_list:
        # Tune from scratch on a cache miss, otherwise refine the cached
        # options.
        start = loaded if loaded is not None else 'naive'
        return tc.make_autotuned_options_factory(
            starting_options=start,
            tuner_config=tuner_config,
            cache_filename=args.tuner_cache_file,
            store_to_cache=True,)(tc_str, entry_point, *inputs)

    assert loaded is not None, 'None found'

    return loaded
| 244 | + |
| 245 | + |
###############################################################################
# Define the TC for LENGTHS_COSINE_COHERENCE, using generate_options to
# supply the mapping options of each entry point
###############################################################################
TC = tc.define(
    # Bake the window sizes into the TC source as integer literals.
    (LENGTHS_COSINE_COHERENCE
     .replace('<pos_dist>', str(args.pos_dist))
     .replace('<neg_dist>', str(args.neg_dist))),
    generate_options,
)
| 255 | + |
###############################################################################
# Run with implicit compilation and tuning
###############################################################################
# Input(N x C) random floats is partitioned into K buckets each of length L(K)
# We then sum within each bucket (with a positive-pair / negative-pair twist)
# This first impl uses the max bucket length and makes the computation dense
InputData = torch.randn(
    args.num_segs * args.seg_length, args.num_of_channels, device='cuda')
# Assume all segments of same length for now
Segments = torch.ones(
    args.num_segs, dtype=torch.int, device='cuda').fill_(args.seg_length)

Idx = TC.make_idx(Segments)
# torch.max without a dim returns a 0-dim tensor, which cannot be indexed
# with [0]; .item() extracts the Python scalar instead.
# The TC bodies never read the values of SegmentsMetaData; it appears to be
# passed only so that its length binds MAX_L.
SegmentsMetaData = torch.ones(
    (int(torch.max(Segments).item()),), dtype=torch.int, device='cuda')
Alpha = TC.make_alpha(Idx, SegmentsMetaData)
PosCount, NegCount, _1, _2 = TC.make_counts(Alpha, SegmentsMetaData)
Beta = TC.make_beta(Alpha, PosCount, NegCount, SegmentsMetaData)
NormData, Square = TC.normalize(InputData)
Dots = TC.dots(NormData, Idx, Beta)
Output, _ = TC.result(Beta, Dots)

# Same computation on the host with NumPy, used as ground truth.
R, Normed_DATA, Norm_of_Vector, POS_C, NEG_C = (
    reference(InputData.cpu().numpy(), Segments.cpu().numpy()))
| 278 | + |
###############################################################################
# Check
###############################################################################
# Pair counts are compared with precision=0. The expected tensor is passed a
# second time in the position of the extra positional arguments.
tc.assert_almost_equal(
    PosCount.cpu(),
    torch.from_numpy(POS_C).float(),
    torch.from_numpy(POS_C).float(),
    precision=0)
tc.assert_almost_equal(
    NegCount.cpu(),
    torch.from_numpy(NEG_C).float(),
    torch.from_numpy(NEG_C).float(),
    precision=0)
# The final output is compared against the reference with Dots and Beta as
# the contributing inputs; `operations` is MAX_L * (MAX_L + 1) / 2.
tc.assert_almost_equal(
    Output.cpu(),
    torch.from_numpy(R),
    Dots.cpu(),
    Beta.cpu(),
    operations=SegmentsMetaData.size(0) * (SegmentsMetaData.size(0) + 1) // 2,
)

print('SUCCESS, maxdiff={}'.format((Output.cpu() - torch.from_numpy(R)).abs().max()))
0 commit comments