diff --git a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py index 9d2a5a871b..8cdfafe599 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py +++ b/fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py @@ -8,7 +8,6 @@ # pyre-strict import enum -import math from typing import Any, Dict # noqa: F401 import torch @@ -41,22 +40,33 @@ class EmbOptimType(enum.Enum): def __str__(self) -> str: return self.value - def state_size(self) -> int: + def _extract_dtype( + self, optimizer_state_dtypes: Dict[str, "SparseType"], name: str + ) -> torch.dtype: + if optimizer_state_dtypes is None or name not in optimizer_state_dtypes: + return torch.float32 + return optimizer_state_dtypes[name].as_dtype() + + def state_size_nbytes( + self, D: int, optimizer_state_dtypes: Dict[str, "SparseType"] = {} # noqa: B006 + ) -> int: """ Returns the size of the data (in bytes) required to hold the optimizer - state (per table row), or 0 if none needed + state (per table row) """ - return { - # Only holds the momentum float value per row - EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32.itemsize, - }.get(self, 0) + if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD: + momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1") + # Store one value for momentum per row + return momentum1_dtype.itemsize - def state_size_dim(self, dtype: torch.dtype) -> int: - """ - Returns the size of the data (in units of elements of dtype) rquired to - hold optimizer information (per table row) - """ - return int(math.ceil(self.state_size() / dtype.itemsize)) + elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM: + momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1") + momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2") + # Store one value for momentum2 plus D values for momentum1 per row + return momentum2_dtype.itemsize + (D * momentum1_dtype.itemsize) + + else: + return 0 def dtype(self) -> torch.dtype: """ diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py index 7ef539e036..179b35c011 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py @@ -12,6 +12,7 @@ import functools import itertools import logging +import math import os import tempfile import threading @@ -172,6 +173,7 @@ def __init__( res_params: Optional[RESParams] = None, # raw embedding streaming sharding info flushing_block_size: int = 2_000_000_000, # 2GB table_names: Optional[List[str]] = None, + optimizer_state_dtypes: Dict[str, SparseType] = {}, # noqa: B006 ) -> None: super(SSDTableBatchedEmbeddingBags, self).__init__() @@ -185,6 +187,7 @@ def __init__( assert weights_precision in (SparseType.FP32, SparseType.FP16) self.weights_precision = weights_precision self.output_dtype: int = output_dtype.as_int() + self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes # Zero collision TBE configurations self.kv_zch_params = kv_zch_params @@ -987,13 +990,24 @@ def cache_row_dim(self) -> int: """ if self.enable_optimizer_offloading: return self.max_D + pad4( - # Compute the number of elements of cache_dtype needed to store the - # optimizer state - self.optimizer.state_size_dim(self.weights_precision.as_dtype()) + # Compute the number of elements of cache_dtype needed to store + # the optimizer state + self.optimizer_state_dim ) else: return self.max_D + @cached_property + def optimizer_state_dim(self) -> int: + return int( + math.ceil( + self.optimizer.state_size_nbytes( + self.max_D, self.optimizer_state_dtypes + ) + / self.weights_precision.as_dtype().itemsize + ) + ) + @property # pyre-ignore def ssd_db(self): @@ -2285,9 +2299,8 @@ def split_optimizer_states( table_offset = 0 dtype = self.weights_precision.as_dtype() - optimizer_dim = self.optimizer.state_size_dim(dtype) logging.info( - f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}" + f"split_optimizer_states: {self.optimizer_state_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}" ) for t, (emb_height, emb_dim) in enumerate(self.embedding_specs): @@ -2345,7 +2358,7 @@ def split_optimizer_states( and sorted_id_tensor[t].size(0) > 0 else emb_height ), - optimizer_dim, + self.optimizer_state_dim, ], dtype=dtype, row_offset=row_offset, @@ -2373,7 +2386,7 @@ def split_optimizer_states( # backend will return both weight and optimizer in one tensor, read the whole tensor # out could OOM CPU memory. tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper( - shape=[emb_height, optimizer_dim], + shape=[emb_height, self.optimizer_state_dim], dtype=dtype, row_offset=row_offset, snapshot_handle=snapshot_handle, @@ -2652,11 +2665,7 @@ def split_embedding_weights( ( metaheader_dim # metaheader is already padded + pad4(emb_dim) - + pad4( - self.optimizer.state_size_dim( - self.weights_precision.as_dtype() - ) - ) + + pad4(self.optimizer_state_dim) ) if self.backend_return_whole_row else emb_dim @@ -2802,8 +2811,7 @@ def streaming_write_weight_and_id_per_table( # TODO: make chunk_size configurable or dynamic chunk_size = 10000 row = weight_state.size(0) - optimizer_dim = self.optimizer.state_size_dim(dtype) - opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim) + opt_state_2d = opt_state.view(dtype).view(-1, self.optimizer_state_dim) for i in range(0, row, chunk_size): length = min(chunk_size, row - i) chunk_buffer = torch.empty( @@ -2813,9 +2821,9 @@ def streaming_write_weight_and_id_per_table( device="cpu", ) chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :] - chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[ - i : i + length, : - ] + chunk_buffer[:, D_rounded : D_rounded + self.optimizer_state_dim] = ( + opt_state_2d[i : i + length, :] + ) kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1)) @torch.jit.ignore @@ -3454,20 +3462,35 @@ def fetch_from_l1_sp_w_row_ids( Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids. @return: updated_weights/optimizer_states, mask of which rows are filled """ + if not self.enable_optimizer_offloading and only_get_optimizer_states: + raise RuntimeError( + "Optimizer states are not offloaded, while only_get_optimizer_states is True" + ) + + # NOTE: Remove this once there is support for fetching multiple + # optimizer states in fetch_from_l1_sp_w_row_ids + if self.optimizer != OptimType.EXACT_ROWWISE_ADAGRAD: + raise RuntimeError( + "Only rowwise adagrad is supported in fetch_from_l1_sp_w_row_ids at the moment" + ) + with torch.no_grad(): weights_dtype = self.weights_precision.as_dtype() step = self.step - if not self.enable_optimizer_offloading and only_get_optimizer_states: - raise RuntimeError( - "Optimizer states are not offloaded, while only_get_optimizer_states is True" - ) + if only_get_optimizer_states: start_pos = pad4(self.max_D) - row_dim = self.optimizer.state_size_dim(weights_dtype) - result_dtype = self.optimizer.dtype() + # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working + # until it is upgraded to support optimizers with multiple states + # and dtypes + row_dim = int( + math.ceil(torch.float32.itemsize / weights_dtype.itemsize) + ) + result_dtype = torch.float32 result_dim = int( ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize)) ) + else: start_pos = 0 # get the whole row diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py index 881900d321..b5e8dc4570 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py @@ -2678,7 +2678,7 @@ def test_raw_embedding_streaming_prefetch_pipeline( @given(**default_st) @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None) - def test_ssd_fetch_from_l1_sp_w_row_ids( + def test_ssd_fetch_from_l1_sp_w_row_ids_weight( self, T: int, D: int, @@ -2928,7 +2928,10 @@ def test_ssd_fetch_from_l1_sp_w_row_ids_opt_only( indices.numel(), 1, device=emb.current_device, - dtype=emb.optimizer.dtype(), + # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids unit test + # working until it is upgraded to support optimizers with multiple + # states and dtypes + dtype=torch.float32, ) linearized_indices = [] for f, idxes in enumerate(indices_list): @@ -3034,7 +3037,7 @@ def copy_opt_states_hook( torch.testing.assert_close( split_optimizer_states[t][indices].float(), - opt_states_per_tb.cpu(), + opt_states_per_tb.cpu().float(), atol=tolerance, rtol=tolerance, )