@@ -17,7 +17,7 @@
 import threading
 import time
 from functools import cached_property
-from math import floor, log2
+from math import ceil, floor, log2
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import torch  # usort:skip
 
@@ -98,6 +98,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     weights_offsets: Tensor
     _local_instance_index: int = -1
     res_params: RESParams
+    table_names: List[str]
 
     def __init__(
         self,
@@ -169,6 +170,7 @@ def __init__(
         enable_raw_embedding_streaming: bool = False,  # whether enable raw embedding streaming
         res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
         flushing_block_size: int = 2_000_000_000,  # 2GB
+        table_names: Optional[List[str]] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
@@ -200,6 +202,7 @@ def __init__(
         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.embedding_specs = embedding_specs
+        self.table_names = table_names if table_names is not None else []
         (rows, dims) = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
         assert T_ > 0
@@ -3315,3 +3318,141 @@ def _recording_to_timer(
             return timer.recording(**kwargs)
         # No-Op context manager
         return contextlib.nullcontext()
+
+    def fetch_from_l1_sp_w_row_ids(
+        self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Fetch the optimizer states and/or weights from L1 and SP for the given linearized row_ids.
+        @return: updated_weights/optimizer_states, mask of which rows are filled
+        """
+        with torch.no_grad():
+            weights_dtype = self.weights_precision.as_dtype()
+            step = self.step
+            if not self.enable_optimizer_offloading and only_get_optimizer_states:
+                raise RuntimeError(
+                    "Optimizer states are not offloaded, while only_get_optimizer_states is True"
+                )
+            if only_get_optimizer_states:
+                start_pos = pad4(self.max_D)
+                row_dim = self.optimizer.state_size_dim(weights_dtype)
+                result_dtype = self.optimizer.dtype()
+                result_dim = int(
+                    ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
+                )
+            else:
+                start_pos = 0
+                # get the whole row
+                row_dim = self.cache_row_dim
+                result_dim = row_dim
+                result_dtype = weights_dtype
+
+            with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
+                lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
+                    row_ids,
+                    self.lxu_cache_state,
+                    self.total_hash_size,
+                )
+                updated_weights = torch.empty(
+                    row_ids.numel(),
+                    result_dim,
+                    device=self.current_device,
+                    dtype=result_dtype,
+                )
+
+                # D2D copy cache
+                cache_location_mask = lxu_cache_locations >= 0
+                updated_weights[cache_location_mask] = self.lxu_cache_weights[
+                    lxu_cache_locations[cache_location_mask],
+                    start_pos : start_pos + row_dim,
+                ].view(result_dtype)
+
+            with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
+                if len(self.ssd_scratch_pad_eviction_data) > 0:
+                    sp = self.ssd_scratch_pad_eviction_data[0][0]
+                    sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
+                        self.current_device
+                    )
+                    actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
+                    if actions_count_gpu.item() == 0:
+                        # no action to take
+                        return (updated_weights, cache_location_mask)
+
+                    sp_idx = sp_idx[:actions_count_gpu]
+
+                    # -1 in lxu_cache_locations means the row is not in the L1 cache and may be in SP.
+                    # Fill the row_ids that are in L1 with -2; the remaining (non-negative) ids are looked up in SP.
+                    # e.g. updated_row_ids_in_sp = [1, 100, 1, 2, -2, 3, 4, 5, 10]
+                    updated_row_ids_in_sp = row_ids.masked_fill(
+                        lxu_cache_locations != -1, -2
+                    )
+                    # Sort sp_idx for the binary search (it should already be sorted).
+                    # sp_idx_inverse_indices are the pre-sort indices, which are the same as the row locations in SP.
+                    # e.g. sp_idx = [4, 2, 1, 3, 10]
+                    # e.g. sorted_sp_idx = [1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
+                    sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
+                    # Search the row ids against the sorted SP indices to find the location of each row in SP.
+                    # e.g. updated_ids_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
+                    # e.g. 5 is OOB
+                    updated_ids_in_sp_idx = torch.searchsorted(
+                        sorted_sp_idx, updated_row_ids_in_sp
+                    )
+                    # Ids not found in SP land out of bound.
+                    oob_sp_idx = updated_ids_in_sp_idx >= sp_idx.numel()
+                    # Clamp the OOB items back in bound.
+                    # e.g. updated_ids_in_sp_idx = [0, 0, 0, 1, 0, 2, 3, 4, 4]
+                    updated_ids_in_sp_idx[oob_sp_idx] = 0
+
+                    # -1 locations will be filtered out in masked_index_select.
+                    sp_locations_in_updated_weights = torch.full_like(
+                        updated_row_ids_in_sp, -1
+                    )
+                    # torch.searchsorted does not require an exact match;
+                    # we keep only exactly matched rows, i.e. ids actually found in SP.
+                    # e.g. 5 in updated_row_ids_in_sp is not in sp_idx, yet it gets position 4 in updated_ids_in_sp_idx.
+                    # e.g. sorted_sp_idx[updated_ids_in_sp_idx] = [1, 1, 1, 2, 1, 3, 4, 10, 10]
+                    # e.g. exact_match_mask = [True, False, True, True, False, True, True, False, True]
+                    exact_match_mask = (
+                        sorted_sp_idx[updated_ids_in_sp_idx] == updated_row_ids_in_sp
+                    )
+                    # Get the locations of the row ids found in SP.
+                    # e.g. sp_locations_found = [2, 2, 1, 3, 0, 4]
+                    sp_locations_found = sp_idx_inverse_indices[
+                        updated_ids_in_sp_idx[exact_match_mask]
+                    ]
+                    # e.g. sp_locations_in_updated_weights = [2, -1, 2, 1, -1, 3, 0, -1, 4]
+                    sp_locations_in_updated_weights[exact_match_mask] = (
+                        sp_locations_found
+                    )
+
+                    # D2D copy SP
+                    updated_weights[exact_match_mask] = sp[
+                        sp_locations_found, start_pos : start_pos + row_dim
+                    ].view(result_dtype)
+                    # cache_location_mask is the mask of rows in L1;
+                    # exact_match_mask is the mask of rows in SP.
+                    cache_location_mask = torch.logical_or(
+                        cache_location_mask, exact_match_mask
+                    )
+
+            return (updated_weights, cache_location_mask)
+
+    def register_backward_hook_before_eviction(
+        self, backward_hook: Callable[[torch.Tensor], None]
+    ) -> None:
+        """
+        Register a backward hook on the TBE module, making sure it runs
+        before the SP eviction hook.
+        """
+        # make sure this hook is the first one to be executed
+        hooks = []
+        backward_hooks = self.placeholder_autograd_tensor._backward_hooks
+        if backward_hooks is not None:
+            for _handle_id, hook in backward_hooks.items():
+                hooks.append(hook)
+            backward_hooks.clear()
+
+        self.placeholder_autograd_tensor.register_hook(backward_hook)
+        for hook in hooks:
+            self.placeholder_autograd_tensor.register_hook(hook)
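register_backward_hook_before_eviction relies on tensor hooks firing in registration order: it drains the hooks already on placeholder_autograd_tensor, registers the new hook, then re-registers the old ones behind it. A minimal sketch of the same reordering on a plain tensor (tensor and hook names are made up for illustration):

import torch

t = torch.ones(2, requires_grad=True)
order = []

# A pre-existing hook, standing in for the SP eviction hook.
t.register_hook(lambda g: order.append("eviction"))

def register_first(tensor: torch.Tensor, hook) -> None:
    # Drain existing hooks, register the new one, then re-register
    # the rest, mirroring the method in the diff above.
    existing = list((tensor._backward_hooks or {}).values())
    if tensor._backward_hooks is not None:
        tensor._backward_hooks.clear()
    tensor.register_hook(hook)
    for h in existing:
        tensor.register_hook(h)

register_first(t, lambda g: order.append("first"))
t.sum().backward()
print(order)  # ['first', 'eviction']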