
Commit c580216

chouxi authored and facebook-github-bot committed
Add ability in SSDTBE to fetch weights from L1 and SP from outside of the module (#4450)
Summary:
X-link: pytorch/torchrec#3166
Pull Request resolved: #4450
X-link: facebookresearch/FBGEMM#1513

Given row ids, fetch the updated weights from L1 and SP.

`register_backward_hook_before_eviction` will be called by the delta tracker:
- It registers a hook to be executed in the backward pass before the SP eviction.
- The existing hooks are reordered so that the hook registered through this function executes first.

`fetch_from_l1_sp_w_row_ids` will be called inside the backward hook registered with the function above:
- It returns the updated weights for the given row ids, which already have the SSDTBE table offsets added.
- It fetches from `lxu_cache_weights` (L1) and the rest from `inserted_rows` (SP).

`table_names` is added to make the delta tracker aware of the table-to-TBE mapping:
- The delta tracker stores the updated ids in the forward pass's lookup module, as an 'fqn' -> 'ids' map.
- In the backward pass, each TBE module needs to know which ids to fetch updated weights for, so the table names are needed to infer the fqn.

Reviewed By: q10, duduyi2013

Differential Revision: D72188513

fbshipit-source-id: e35b86695877088829d359b5601561fcc51d2dcd
1 parent b60e109 commit c580216
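
A rough sketch of the intended call pattern (the delta-tracker side is illustrative: `tracker`, `get_updated_row_ids`, and `store_updated_weights` are hypothetical names, not part of this diff; only the two TBE methods added here are real):

```python
import torch

def make_fetch_hook(tbe, tracker, table_name: str):
    # Hypothetical glue between a delta tracker and one SSDTableBatchedEmbeddingBags
    # instance (`tbe`); the tracker API below is assumed for illustration only.
    def _hook(grad: torch.Tensor) -> None:
        # Row ids recorded by the tracker in the forward pass, assumed to be
        # linearized with this TBE's table offsets already applied.
        row_ids = tracker.get_updated_row_ids(table_name)          # hypothetical
        weights, filled = tbe.fetch_from_l1_sp_w_row_ids(row_ids)  # added in this commit
        tracker.store_updated_weights(                             # hypothetical
            table_name, row_ids[filled], weights[filled]
        )
    return _hook

# Register so the fetch runs in backward *before* the scratch-pad eviction hook,
# while the updated rows are still resident in L1 / SP.
tbe.register_backward_hook_before_eviction(make_fetch_hook(tbe, tracker, "my_table"))
```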

File tree

2 files changed: +506 −2 lines changed


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 142 additions & 1 deletion
@@ -17,7 +17,7 @@
 import threading
 import time
 from functools import cached_property
-from math import floor, log2
+from math import ceil, floor, log2
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 import torch  # usort:skip

@@ -98,6 +98,7 @@ class SSDTableBatchedEmbeddingBags(nn.Module):
     weights_offsets: Tensor
     _local_instance_index: int = -1
     res_params: RESParams
+    table_names: List[str]

     def __init__(
         self,
@@ -169,6 +170,7 @@ def __init__(
         enable_raw_embedding_streaming: bool = False,  # whether enable raw embedding streaming
         res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
         flushing_block_size: int = 2_000_000_000,  # 2GB
+        table_names: Optional[List[str]] = None,
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -200,6 +202,7 @@ def __init__(
         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
         self.embedding_specs = embedding_specs
+        self.table_names = table_names if table_names is not None else []
         (rows, dims) = zip(*embedding_specs)
         T_ = len(self.embedding_specs)
         assert T_ > 0
@@ -3315,3 +3318,141 @@ def _recording_to_timer(
             return timer.recording(**kwargs)
         # No-Op context manager
         return contextlib.nullcontext()
+
+    def fetch_from_l1_sp_w_row_ids(
+        self, row_ids: torch.Tensor, only_get_optimizer_states: bool = False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Fetch the optimizer states and/or weights from L1 and SP for the given linearized row_ids.
+        @return: updated_weights/optimizer_states, mask of which rows are filled
+        """
+        with torch.no_grad():
+            weights_dtype = self.weights_precision.as_dtype()
+            step = self.step
+            if not self.enable_optimizer_offloading and only_get_optimizer_states:
+                raise RuntimeError(
+                    "Optimizer states are not offloaded, while only_get_optimizer_states is True"
+                )
+            if only_get_optimizer_states:
+                start_pos = pad4(self.max_D)
+                row_dim = self.optimizer.state_size_dim(weights_dtype)
+                result_dtype = self.optimizer.dtype()
+                result_dim = int(
+                    ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
+                )
+            else:
+                start_pos = 0
+                # get the whole row
+                row_dim = self.cache_row_dim
+                result_dim = row_dim
+                result_dtype = weights_dtype
+
+            with record_function(f"## fetch_from_l1_{step}_{self.tbe_unique_id} ##"):
+                lxu_cache_locations: torch.Tensor = torch.ops.fbgemm.lxu_cache_lookup(
+                    row_ids,
+                    self.lxu_cache_state,
+                    self.total_hash_size,
+                )
+                updated_weights = torch.empty(
+                    row_ids.numel(),
+                    result_dim,
+                    device=self.current_device,
+                    dtype=result_dtype,
+                )
+
+                # D2D copy cache
+                cache_location_mask = lxu_cache_locations >= 0
+                updated_weights[cache_location_mask] = self.lxu_cache_weights[
+                    lxu_cache_locations[cache_location_mask],
+                    start_pos : start_pos + row_dim,
+                ].view(result_dtype)
+
+            with record_function(f"## fetch_from_sp_{step}_{self.tbe_unique_id} ##"):
+                if len(self.ssd_scratch_pad_eviction_data) > 0:
+                    sp = self.ssd_scratch_pad_eviction_data[0][0]
+                    sp_idx = self.ssd_scratch_pad_eviction_data[0][1].to(
+                        self.current_device
+                    )
+                    actions_count_gpu = self.ssd_scratch_pad_eviction_data[0][2][0]
+                    if actions_count_gpu.item() == 0:
+                        # no action to take
+                        return (updated_weights, cache_location_mask)
+
+                    sp_idx = sp_idx[:actions_count_gpu]
+
+                    # -1 in lxu_cache_locations means the row is not in the L1 cache but in SP.
+                    # Fill the row_ids that are in L1 with -2; the remaining values are in SP.
+                    # e.g. updated_row_ids_in_sp = [1, 100, 1, 2, -2, 3, 4, 5, 10]
+                    updated_row_ids_in_sp = row_ids.masked_fill(
+                        lxu_cache_locations != -1, -2
+                    )
+                    # Sort sp_idx for binary search (it should already be sorted).
+                    # sp_idx_inverse_indices are the pre-sort indices, which equal the row locations in SP.
+                    # e.g. sp_idx = [4, 2, 1, 3, 10]
+                    # e.g. sorted_sp_idx = [1, 2, 3, 4, 10] and sp_idx_inverse_indices = [2, 1, 3, 0, 4]
+                    sorted_sp_idx, sp_idx_inverse_indices = torch.sort(sp_idx)
+                    # Search the row ids against the SP indexes to find the rows' locations in SP.
+                    # e.g. updated_ids_in_sp_idx = [0, 5, 0, 1, 0, 2, 3, 4, 4]
+                    # e.g. 5 is out of bounds
+                    updated_ids_in_sp_idx = torch.searchsorted(
+                        sorted_sp_idx, updated_row_ids_in_sp
+                    )
+                    # Rows not found in SP end up out of bounds.
+                    oob_sp_idx = updated_ids_in_sp_idx >= sp_idx.numel()
+                    # Clamp the out-of-bounds items back in bounds.
+                    # e.g. updated_ids_in_sp_idx = [0, 0, 0, 1, 0, 2, 3, 4, 4]
+                    updated_ids_in_sp_idx[oob_sp_idx] = 0
+
+                    # -1 locations will be filtered out in masked_index_select.
+                    sp_locations_in_updated_weights = torch.full_like(
+                        updated_row_ids_in_sp, -1
+                    )
+                    # torch.searchsorted does not require an exact match;
+                    # we only take exactly matched rows, where the id is found in SP.
+                    # e.g. 5 in updated_row_ids_in_sp is not in sp_idx, yet maps to 4 in updated_ids_in_sp_idx
+                    # e.g. sorted_sp_idx[updated_ids_in_sp_idx] = [1, 1, 1, 2, 1, 3, 4, 10, 10]
+                    # e.g. exact_match_mask = [True, False, True, True, False, True, True, False, True]
+                    exact_match_mask = (
+                        sorted_sp_idx[updated_ids_in_sp_idx] == updated_row_ids_in_sp
+                    )
+                    # Get the locations of the row ids found in SP.
+                    # e.g. sp_locations_found = [2, 2, 1, 3, 0, 4]
+                    sp_locations_found = sp_idx_inverse_indices[
+                        updated_ids_in_sp_idx[exact_match_mask]
+                    ]
+                    # e.g. sp_locations_in_updated_weights = [2, -1, 2, 1, -1, 3, 0, -1, 4]
+                    sp_locations_in_updated_weights[exact_match_mask] = (
+                        sp_locations_found
+                    )
+
+                    # D2D copy SP
+                    updated_weights[exact_match_mask] = sp[
+                        sp_locations_found, start_pos : start_pos + row_dim
+                    ].view(result_dtype)
+                    # cache_location_mask is the mask of rows in L1
+                    # exact_match_mask is the mask of rows in SP
+                    cache_location_mask = torch.logical_or(
+                        cache_location_mask, exact_match_mask
+                    )
+
+            return (updated_weights, cache_location_mask)
+
+    def register_backward_hook_before_eviction(
+        self, backward_hook: Callable[[torch.Tensor], None]
+    ) -> None:
+        """
+        Register a backward hook on the TBE module and make sure it is
+        called before the SP eviction hook.
+        """
+        # make sure this hook is the first one to be executed
+        hooks = []
+        backward_hooks = self.placeholder_autograd_tensor._backward_hooks
+        if backward_hooks is not None:
+            for _handle_id, hook in backward_hooks.items():
+                hooks.append(hook)
+            backward_hooks.clear()
+
+        self.placeholder_autograd_tensor.register_hook(backward_hook)
+        for hook in hooks:
+            self.placeholder_autograd_tensor.register_hook(hook)
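
The scratch-pad lookup above boils down to a sort / searchsorted / exact-match idiom; this standalone sketch replays it with the example values from the comments in the diff:

```python
import torch

# Example values taken from the comments above.
sp_idx = torch.tensor([4, 2, 1, 3, 10])                    # ids resident in the scratch pad (SP)
query_ids = torch.tensor([1, 100, 1, 2, -2, 3, 4, 5, 10])  # ids to locate (-2 = already served from L1)

sorted_sp_idx, inverse_indices = torch.sort(sp_idx)  # inverse_indices are the SP row locations
pos = torch.searchsorted(sorted_sp_idx, query_ids)   # insertion points, not exact matches
pos[pos >= sp_idx.numel()] = 0                       # clamp out-of-bounds positions back in range
exact = sorted_sp_idx[pos] == query_ids              # keep only ids actually present in SP
sp_rows = inverse_indices[pos[exact]]                # SP rows holding the matched ids

print(exact.tolist())    # [True, False, True, True, False, True, True, False, True]
print(sp_rows.tolist())  # [2, 2, 1, 3, 0, 4]
```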
