@@ -9,19 +9,11 @@
 
 import dataclasses
 import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional
 
-import numpy as np
 import torch
 
-from fbgemm_gpu.tbe.utils.common import get_device, round_up
-from fbgemm_gpu.tbe.utils.requests import (
-    generate_batch_sizes_from_stats,
-    generate_pooling_factors_from_stats,
-    get_table_batched_offsets_from_dense,
-    maybe_to_dtype,
-    TBERequest,
-)
+from fbgemm_gpu.tbe.utils.common import get_device
 
 from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
 
@@ -104,175 +96,3 @@ def variable_L(self) -> bool:
     def _new_weights(self, size: int) -> Optional[torch.Tensor]:
         # Per-sample weights will always be FP32
         return None if not self.weighted else torch.randn(size, device=get_device())
-
-    def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
-        if self.variable_B():
-            assert (
-                self.batch_params.vbe_num_ranks is not None
-            ), "vbe_num_ranks must be set for varaible batch size generation"
-            return generate_batch_sizes_from_stats(
-                self.batch_params.B,
-                self.T,
-                # pyre-ignore [6]
-                self.batch_params.sigma_B,
-                self.batch_params.vbe_num_ranks,
-                # pyre-ignore [6]
-                self.batch_params.vbe_distribution,
-            )
-
-        else:
-            return ([self.batch_params.B] * self.T, None)
-
-    def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor:
-        if self.variable_L():
-            # Generate L from stats
-            _, L_offsets = generate_pooling_factors_from_stats(
-                iters,
-                Bs,
-                self.pooling_params.L,
-                # pyre-ignore [6]
-                self.pooling_params.sigma_L,
-                # pyre-ignore [6]
-                self.pooling_params.length_distribution,
-            )
-
-        else:
-            Ls = [self.pooling_params.L] * (sum(Bs) * iters)
-            L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
-
-        return L_offsets
-
-    def _generate_indices(
-        self,
-        iters: int,
-        Bs: List[int],
-        L_offsets: torch.Tensor,
-    ) -> torch.Tensor:
-        total_B = sum(Bs)
-        L_offsets_list = L_offsets.tolist()
-        indices_list = []
-        for it in range(iters):
-            # L_offsets is defined over the entire set of batches for a single iteration
-            start_offset = L_offsets_list[it * total_B]
-            end_offset = L_offsets_list[(it + 1) * total_B]
-
-            indices_list.append(
-                torch.ops.fbgemm.tbe_generate_indices_from_distribution(
-                    self.indices_params.heavy_hitters,
-                    self.indices_params.zipf_q,
-                    self.indices_params.zipf_s,
-                    # max_index = dimensions of the embedding table
-                    self.E,
-                    # num_indices = number of indices to generate
-                    end_offset - start_offset,
-                )
-            )
-
-        return torch.cat(indices_list)
-
-    def _build_requests_jagged(
-        self,
-        iters: int,
-        Bs: List[int],
-        Bs_feature_rank: Optional[List[List[int]]],
-        L_offsets: torch.Tensor,
-        all_indices: torch.Tensor,
-    ) -> List[TBERequest]:
-        total_B = sum(Bs)
-        all_indices = all_indices.flatten()
-        requests = []
-        for it in range(iters):
-            start_offset = L_offsets[it * total_B]
-            it_L_offsets = torch.concat(
-                [
-                    torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
-                    L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
-                ]
-            )
-            requests.append(
-                TBERequest(
-                    maybe_to_dtype(
-                        all_indices[start_offset : L_offsets[(it + 1) * total_B]],
-                        self.indices_params.index_dtype,
-                    ),
-                    maybe_to_dtype(
-                        it_L_offsets.to(get_device()), self.indices_params.offset_dtype
-                    ),
-                    self._new_weights(int(it_L_offsets[-1].item())),
-                    Bs_feature_rank if self.variable_B() else None,
-                )
-            )
-        return requests
-
-    def _build_requests_dense(
-        self, iters: int, all_indices: torch.Tensor
-    ) -> List[TBERequest]:
-        # NOTE: We're using existing code from requests.py to build the
-        # requests, and since the existing code requires 2D view of all_indices,
-        # the existing all_indices must be reshaped
-        all_indices = all_indices.reshape(iters, -1)
-
-        requests = []
-        for it in range(iters):
-            indices, offsets = get_table_batched_offsets_from_dense(
-                all_indices[it].view(
-                    self.T, self.batch_params.B, self.pooling_params.L
-                ),
-                use_cpu=self.use_cpu,
-            )
-            requests.append(
-                TBERequest(
-                    maybe_to_dtype(indices, self.indices_params.index_dtype),
-                    maybe_to_dtype(offsets, self.indices_params.offset_dtype),
-                    self._new_weights(
-                        self.T * self.batch_params.B * self.pooling_params.L
-                    ),
-                )
-            )
-        return requests
-
-    def generate_requests(
-        self,
-        iters: int = 1,
-    ) -> List[TBERequest]:
-        # Generate batch sizes
-        Bs, Bs_feature_rank = self._generate_batch_sizes()
-
-        # Generate pooling info
-        L_offsets = self._generate_pooling_info(iters, Bs)
-
-        # Generate indices
-        all_indices = self._generate_indices(iters, Bs, L_offsets)
-
-        # Build TBE requests
-        if self.variable_B() or self.variable_L():
-            return self._build_requests_jagged(
-                iters, Bs, Bs_feature_rank, L_offsets, all_indices
-            )
-        else:
-            return self._build_requests_dense(iters, all_indices)
-
-    def generate_embedding_dims(self) -> Tuple[int, List[int]]:
-        if self.mixed_dim:
-            Ds = [
-                round_up(
-                    np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4
-                )
-                for _ in range(self.T)
-            ]
-            return (int(np.average(Ds)), Ds)
-        else:
-            return (self.D, [self.D] * self.T)
-
-    def generate_feature_requires_grad(self, size: int) -> torch.Tensor:
-        assert size <= self.T, "size of feature_requires_grad must be less than T"
-        weighted_requires_grad_tables = np.random.choice(
-            self.T, replace=False, size=(size,)
-        ).tolist()
-        return (
-            torch.tensor(
-                [1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)]
-            )
-            .to(get_device())
-            .int()
-        )
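
For context on what this diff removes: generate_requests composed four steps (generate per-table batch sizes, build pooling-factor offsets, draw indices from the configured distribution, and package everything into TBERequest objects). The sketch below is a minimal standalone illustration of the two offset manipulations at the heart of that flow; fixed_pooling_offsets and per_iteration_offsets are hypothetical names written for this note, not fbgemm_gpu APIs.

import torch

# Mirrors the fixed-L branch of the removed _generate_pooling_info():
# one pooling factor per (iteration, table, sample). In the dense case
# Bs == [B] * T, so there are T * B * iters entries in total.
def fixed_pooling_offsets(iters: int, T: int, B: int, L: int) -> torch.Tensor:
    Ls = [L] * (T * B * iters)
    # Exclusive prefix sum: offsets[i] marks where sample i's indices begin,
    # and offsets[-1] is the total number of indices to generate.
    return torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)

# Mirrors the per-iteration re-basing in the removed _build_requests_jagged():
# slice out one iteration's offsets and shift them to start at zero.
def per_iteration_offsets(L_offsets: torch.Tensor, it: int, total_B: int) -> torch.Tensor:
    start = L_offsets[it * total_B]
    return torch.concat(
        [
            torch.zeros(1, dtype=L_offsets.dtype),
            L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start,
        ]
    )

offsets = fixed_pooling_offsets(iters=2, T=3, B=4, L=5)
assert offsets[-1].item() == 2 * 3 * 4 * 5  # 120 indices across both iterations
it1 = per_iteration_offsets(offsets, it=1, total_B=3 * 4)
assert it1[0].item() == 0 and it1[-1].item() == 3 * 4 * 5  # 60 indices in iteration 1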