From be097b6f1994dbf1f157330d2ea9f5350bda6099 Mon Sep 17 00:00:00 2001
From: Benson Ma
Date: Mon, 14 Jul 2025 14:18:51 -0700
Subject: [PATCH] Update the cache row dim calculation in TBE SSD (#4480)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1537

- The current cache row dim calculation in TBE SSD assumes that optimizer state sizes are fixed and independent of the table's embedding dimension. This change updates the cache row dim calculation to account for optimizers whose state sizes depend on the row length, such as Partial Rowwise Adam.

Reviewed By: emlin, jiawenliu64

Differential Revision: D77321062
---
 .../fbgemm_gpu/split_embedding_configs.py     | 36 ++++++----
 fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py     | 69 ++++++++++++-------
 .../tbe/ssd/ssd_split_tbe_training_test.py    |  9 ++-
 3 files changed, 75 insertions(+), 39 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
index 9d2a5a871b..8cdfafe599 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -8,7 +8,6 @@
 # pyre-strict
 
 import enum
-import math
 from typing import Any, Dict  # noqa: F401
 
 import torch
@@ -41,22 +40,33 @@ class EmbOptimType(enum.Enum):
     def __str__(self) -> str:
         return self.value
 
-    def state_size(self) -> int:
+    def _extract_dtype(
+        self, optimizer_state_dtypes: Dict[str, "SparseType"], name: str
+    ) -> torch.dtype:
+        if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
+            return torch.float32
+        return optimizer_state_dtypes[name].as_dtype()
+
+    def state_size_nbytes(
+        self, D: int, optimizer_state_dtypes: Dict[str, "SparseType"] = {}  # noqa: B006
+    ) -> int:
         """
         Returns the size of the data (in bytes) required to hold the optimizer
-        state (per table row), or 0 if none needed
+        state (per table row)
         """
-        return {
-            # Only holds the momentum float value per row
-            EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32.itemsize,
-        }.get(self, 0)
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+            # Store one value for momentum per row
+            return momentum1_dtype.itemsize
 
-    def state_size_dim(self, dtype: torch.dtype) -> int:
-        """
-        Returns the size of the data (in units of elements of dtype) rquired to
-        hold optimizer information (per table row)
-        """
-        return int(math.ceil(self.state_size() / dtype.itemsize))
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+            momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+            # Store one value for momentum2 plus D values for momentum1 per row
+            return momentum2_dtype.itemsize + (D * momentum1_dtype.itemsize)
+
+        else:
+            return 0
 
     def dtype(self) -> torch.dtype:
         """
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
index 7ef539e036..179b35c011 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -12,6 +12,7 @@
 import functools
 import itertools
 import logging
+import math
 import os
 import tempfile
 import threading
@@ -172,6 +173,7 @@ def __init__(
         res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
         flushing_block_size: int = 2_000_000_000,  # 2GB
         table_names: Optional[List[str]] = None,
+        optimizer_state_dtypes: Dict[str, SparseType] = {},  # noqa: B006
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
@@ -185,6 +187,7 @@ def __init__(
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
+        self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes
 
         # Zero collision TBE configurations
         self.kv_zch_params = kv_zch_params
@@ -987,13 +990,24 @@ def cache_row_dim(self) -> int:
         """
         if self.enable_optimizer_offloading:
             return self.max_D + pad4(
-                # Compute the number of elements of cache_dtype needed to store the
-                # optimizer state
-                self.optimizer.state_size_dim(self.weights_precision.as_dtype())
+                # Compute the number of elements of cache_dtype needed to store
+                # the optimizer state
+                self.optimizer_state_dim
             )
         else:
             return self.max_D
 
+    @cached_property
+    def optimizer_state_dim(self) -> int:
+        return int(
+            math.ceil(
+                self.optimizer.state_size_nbytes(
+                    self.max_D, self.optimizer_state_dtypes
+                )
+                / self.weights_precision.as_dtype().itemsize
+            )
+        )
+
     @property
     # pyre-ignore
     def ssd_db(self):
@@ -2285,9 +2299,8 @@ def split_optimizer_states(
 
         table_offset = 0
         dtype = self.weights_precision.as_dtype()
-        optimizer_dim = self.optimizer.state_size_dim(dtype)
         logging.info(
-            f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+            f"split_optimizer_states: {self.optimizer_state_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
         )
 
         for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -2345,7 +2358,7 @@ def split_optimizer_states(
                         and sorted_id_tensor[t].size(0) > 0
                         else emb_height
                     ),
-                    optimizer_dim,
+                    self.optimizer_state_dim,
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
@@ -2373,7 +2386,7 @@ def split_optimizer_states(
                 # backend will return both weight and optimizer in one tensor, read the whole tensor
                 # out could OOM CPU memory.
                 tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                    shape=[emb_height, optimizer_dim],
+                    shape=[emb_height, self.optimizer_state_dim],
                     dtype=dtype,
                     row_offset=row_offset,
                     snapshot_handle=snapshot_handle,
@@ -2652,11 +2665,7 @@ def split_embedding_weights(
                     (
                         metaheader_dim  # metaheader is already padded
                         + pad4(emb_dim)
-                        + pad4(
-                            self.optimizer.state_size_dim(
-                                self.weights_precision.as_dtype()
-                            )
-                        )
+                        + pad4(self.optimizer_state_dim)
                     )
                     if self.backend_return_whole_row
                     else emb_dim
@@ -2802,8 +2811,7 @@ def streaming_write_weight_and_id_per_table(
         # TODO: make chunk_size configurable or dynamic
         chunk_size = 10000
         row = weight_state.size(0)
-        optimizer_dim = self.optimizer.state_size_dim(dtype)
-        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        opt_state_2d = opt_state.view(dtype).view(-1, self.optimizer_state_dim)
         for i in range(0, row, chunk_size):
             length = min(chunk_size, row - i)
             chunk_buffer = torch.empty(
@@ -2813,9 +2821,9 @@ def streaming_write_weight_and_id_per_table(
                 device="cpu",
             )
             chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
-            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
-                i : i + length, :
-            ]
+            chunk_buffer[:, D_rounded : D_rounded + self.optimizer_state_dim] = (
+                opt_state_2d[i : i + length, :]
+            )
             kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
 
     @torch.jit.ignore
@@ -3454,20 +3462,35 @@ def fetch_from_l1_sp_w_row_ids(
         Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
         @return: updated_weights/optimizer_states, mask of which rows are filled
         """
+        if not self.enable_optimizer_offloading and only_get_optimizer_states:
+            raise RuntimeError(
+                "Optimizer states are not offloaded, while only_get_optimizer_states is True"
+            )
+
+        # NOTE: Remove this once there is support for fetching multiple
+        # optimizer states in fetch_from_l1_sp_w_row_ids
+        if self.optimizer != OptimType.EXACT_ROWWISE_ADAGRAD:
+            raise RuntimeError(
+                "Only rowwise adagrad is supported in fetch_from_l1_sp_w_row_ids at the moment"
+            )
+
         with torch.no_grad():
             weights_dtype = self.weights_precision.as_dtype()
             step = self.step
-            if not self.enable_optimizer_offloading and only_get_optimizer_states:
-                raise RuntimeError(
-                    "Optimizer states are not offloaded, while only_get_optimizer_states is True"
-                )
+
             if only_get_optimizer_states:
                 start_pos = pad4(self.max_D)
-                row_dim = self.optimizer.state_size_dim(weights_dtype)
-                result_dtype = self.optimizer.dtype()
+                # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working
+                # until it is upgraded to support optimizers with multiple states
+                # and dtypes
+                row_dim = int(
+                    math.ceil(torch.float32.itemsize / weights_dtype.itemsize)
+                )
+                result_dtype = torch.float32
                 result_dim = int(
                     ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
                 )
+
             else:
                 start_pos = 0
                 # get the whole row
diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
index 881900d321..b5e8dc4570 100644
--- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
+++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -2678,7 +2678,7 @@ def test_raw_embedding_streaming_prefetch_pipeline(
 
     @given(**default_st)
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_ssd_fetch_from_l1_sp_w_row_ids(
+    def test_ssd_fetch_from_l1_sp_w_row_ids_weight(
         self,
         T: int,
         D: int,
@@ -2928,7 +2928,10 @@ def test_ssd_fetch_from_l1_sp_w_row_ids_opt_only(
             indices.numel(),
             1,
             device=emb.current_device,
-            dtype=emb.optimizer.dtype(),
+            # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids unit test
+            # working until it is upgraded to support optimizers with multiple
+            # states and dtypes
+            dtype=torch.float32,
         )
         linearized_indices = []
         for f, idxes in enumerate(indices_list):
@@ -3034,7 +3037,7 @@ def copy_opt_states_hook(
 
             torch.testing.assert_close(
                 split_optimizer_states[t][indices].float(),
-                opt_states_per_tb.cpu(),
+                opt_states_per_tb.cpu().float(),
                 atol=tolerance,
                 rtol=tolerance,
            )
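
Sizing example (illustrative sketch only, not part of the committed change): the
plain-Python walkthrough below shows the per-row math that the new
state_size_nbytes() / optimizer_state_dim path computes for
PARTIAL_ROWWISE_ADAM. The concrete values (D = 128, FP16 weights, FP32 momentum
dtypes) are hypothetical, and pad4() mirrors the rounding helper used in
training.py.

    import math

    D = 128                  # hypothetical max embedding dimension (max_D)
    weight_itemsize = 2      # FP16 weights -> 2 bytes per element
    momentum1_itemsize = 4   # FP32 momentum1 (default when no dtype override is given)
    momentum2_itemsize = 4   # FP32 momentum2

    def pad4(x: int) -> int:
        # Round up to a multiple of 4, mirroring pad4() in training.py
        return int(math.ceil(x / 4)) * 4

    # PARTIAL_ROWWISE_ADAM stores one momentum2 scalar plus D momentum1 values per row
    state_size_nbytes = momentum2_itemsize + D * momentum1_itemsize       # 4 + 512 = 516

    # optimizer_state_dim: number of weight-dtype elements needed to hold that state
    optimizer_state_dim = math.ceil(state_size_nbytes / weight_itemsize)  # 258

    # cache_row_dim with optimizer offloading enabled
    cache_row_dim = D + pad4(optimizer_state_dim)                         # 128 + 260 = 388

Under the old state_size_dim() API, the optimizer footprint was a constant
number of bytes regardless of D, which undercounts the cache row for Partial
Rowwise Adam; state_size_nbytes(D, ...) makes the D-dependence explicit.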