Update the cache row dim calculation in TBE SSD #4480

Closed
wants to merge 1 commit into from

Changes from all commits

36 changes: 23 additions & 13 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py
@@ -8,7 +8,6 @@
# pyre-strict

import enum
import math
from typing import Any, Dict # noqa: F401

import torch
@@ -41,22 +40,33 @@ class EmbOptimType(enum.Enum):
def __str__(self) -> str:
return self.value

def state_size(self) -> int:
def _extract_dtype(
self, optimizer_state_dtypes: Dict[str, "SparseType"], name: str
) -> torch.dtype:
if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
return torch.float32
return optimizer_state_dtypes[name].as_dtype()

def state_size_nbytes(
self, D: int, optimizer_state_dtypes: Dict[str, "SparseType"] = {} # noqa: B006
) -> int:
"""
Returns the size of the data (in bytes) required to hold the optimizer
state (per table row), or 0 if none needed
state (per table row)
"""
return {
# Only holds the momentum float value per row
EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32.itemsize,
}.get(self, 0)
if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
# Store one value for momentum per row
return momentum1_dtype.itemsize

def state_size_dim(self, dtype: torch.dtype) -> int:
"""
Returns the size of the data (in units of elements of dtype) required to
hold optimizer information (per table row)
"""
return int(math.ceil(self.state_size() / dtype.itemsize))
elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
# Store one value for momentum2 plus D values for momentum1 per row
return momentum2_dtype.itemsize + (D * momentum1_dtype.itemsize)

else:
return 0

def dtype(self) -> torch.dtype:
"""
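
For context, a minimal sketch of how the new state_size_nbytes contract behaves for the two optimizers it currently covers. The table dimension and dtype overrides below are illustrative assumptions (and the sketch assumes SparseType is importable alongside EmbOptimType from fbgemm_gpu.split_embedding_configs), not values taken from this PR.

import torch

from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType

D = 128  # illustrative embedding dimension for one table

# Rowwise Adagrad stores a single momentum1 scalar per row; with no dtype
# override it falls back to FP32, i.e. 4 bytes per row regardless of D.
rowwise_nbytes = EmbOptimType.EXACT_ROWWISE_ADAGRAD.state_size_nbytes(D)
assert rowwise_nbytes == torch.float32.itemsize  # 4 bytes

# Partial rowwise Adam stores D momentum1 values plus one momentum2 scalar;
# requesting FP16 momentum1 shrinks the payload while momentum2 stays FP32.
adam_nbytes = EmbOptimType.PARTIAL_ROWWISE_ADAM.state_size_nbytes(
    D, {"momentum1": SparseType.FP16}
)
assert adam_nbytes == torch.float32.itemsize + D * torch.float16.itemsize  # 260 bytes
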
69 changes: 46 additions & 23 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -12,6 +12,7 @@
import functools
import itertools
import logging
import math
import os
import tempfile
import threading
@@ -172,6 +173,7 @@ def __init__(
res_params: Optional[RESParams] = None, # raw embedding streaming sharding info
flushing_block_size: int = 2_000_000_000, # 2GB
table_names: Optional[List[str]] = None,
optimizer_state_dtypes: Dict[str, SparseType] = {}, # noqa: B006
) -> None:
super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -185,6 +187,7 @@ def __init__(
assert weights_precision in (SparseType.FP32, SparseType.FP16)
self.weights_precision = weights_precision
self.output_dtype: int = output_dtype.as_int()
self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes

# Zero collision TBE configurations
self.kv_zch_params = kv_zch_params
@@ -987,13 +990,24 @@ def cache_row_dim(self) -> int:
"""
if self.enable_optimizer_offloading:
return self.max_D + pad4(
# Compute the number of elements of cache_dtype needed to store the
# optimizer state
self.optimizer.state_size_dim(self.weights_precision.as_dtype())
# Compute the number of elements of cache_dtype needed to store
# the optimizer state
self.optimizer_state_dim
)
else:
return self.max_D

@cached_property
def optimizer_state_dim(self) -> int:
return int(
math.ceil(
self.optimizer.state_size_nbytes(
self.max_D, self.optimizer_state_dtypes
)
/ self.weights_precision.as_dtype().itemsize
)
)

@property
# pyre-ignore
def ssd_db(self):
Expand Down Expand Up @@ -2285,9 +2299,8 @@ def split_optimizer_states(
table_offset = 0

dtype = self.weights_precision.as_dtype()
optimizer_dim = self.optimizer.state_size_dim(dtype)
logging.info(
f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
f"split_optimizer_states: {self.optimizer_state_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
)

for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -2345,7 +2358,7 @@ def split_optimizer_states(
and sorted_id_tensor[t].size(0) > 0
else emb_height
),
optimizer_dim,
self.optimizer_state_dim,
],
dtype=dtype,
row_offset=row_offset,
Expand Down Expand Up @@ -2373,7 +2386,7 @@ def split_optimizer_states(
# backend will return both weight and optimizer in one tensor, read the whole tensor
# out could OOM CPU memory.
tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
shape=[emb_height, optimizer_dim],
shape=[emb_height, self.optimizer_state_dim],
dtype=dtype,
row_offset=row_offset,
snapshot_handle=snapshot_handle,
@@ -2652,11 +2665,7 @@ def split_embedding_weights(
(
metaheader_dim # metaheader is already padded
+ pad4(emb_dim)
+ pad4(
self.optimizer.state_size_dim(
self.weights_precision.as_dtype()
)
)
+ pad4(self.optimizer_state_dim)
)
if self.backend_return_whole_row
else emb_dim
@@ -2802,8 +2811,7 @@ def streaming_write_weight_and_id_per_table(
# TODO: make chunk_size configurable or dynamic
chunk_size = 10000
row = weight_state.size(0)
optimizer_dim = self.optimizer.state_size_dim(dtype)
opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
opt_state_2d = opt_state.view(dtype).view(-1, self.optimizer_state_dim)
for i in range(0, row, chunk_size):
length = min(chunk_size, row - i)
chunk_buffer = torch.empty(
@@ -2813,9 +2821,9 @@
device="cpu",
)
chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
i : i + length, :
]
chunk_buffer[:, D_rounded : D_rounded + self.optimizer_state_dim] = (
opt_state_2d[i : i + length, :]
)
kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

@torch.jit.ignore
@@ -3454,20 +3462,35 @@ def fetch_from_l1_sp_w_row_ids(
Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
@return: updated_weights/optimizer_states, mask of which rows are filled
"""
if not self.enable_optimizer_offloading and only_get_optimizer_states:
raise RuntimeError(
"Optimizer states are not offloaded, while only_get_optimizer_states is True"
)

# NOTE: Remove this once there is support for fetching multiple
# optimizer states in fetch_from_l1_sp_w_row_ids
if self.optimizer != OptimType.EXACT_ROWWISE_ADAGRAD:
raise RuntimeError(
"Only rowwise adagrad is supported in fetch_from_l1_sp_w_row_ids at the moment"
)

with torch.no_grad():
weights_dtype = self.weights_precision.as_dtype()
step = self.step
if not self.enable_optimizer_offloading and only_get_optimizer_states:
raise RuntimeError(
"Optimizer states are not offloaded, while only_get_optimizer_states is True"
)

if only_get_optimizer_states:
start_pos = pad4(self.max_D)
row_dim = self.optimizer.state_size_dim(weights_dtype)
result_dtype = self.optimizer.dtype()
# NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working
# until it is upgraded to support optimizers with multiple states
# and dtypes
row_dim = int(
math.ceil(torch.float32.itemsize / weights_dtype.itemsize)
)
result_dtype = torch.float32
result_dim = int(
ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
)

else:
start_pos = 0
# get the whole row
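
To summarize the training.py change: with optimizer offloading enabled, each cache row now reserves pad4(optimizer_state_dim) extra columns, where optimizer_state_dim converts the byte count reported by state_size_nbytes into elements of the weights cache dtype. A standalone re-derivation of that arithmetic, assuming pad4 rounds up to a multiple of 4 elements:

import math

import torch


def pad4(n: int) -> int:
    # Assumed behavior of fbgemm's pad4 helper: round up to a multiple of 4.
    return -(-n // 4) * 4


def cache_row_dim(max_D: int, state_nbytes: int, weights_dtype: torch.dtype) -> int:
    # Elements of the cache dtype needed to hold the per-row optimizer bytes.
    optimizer_state_dim = math.ceil(state_nbytes / weights_dtype.itemsize)
    # The cache row is the embedding width plus the padded optimizer width.
    return max_D + pad4(optimizer_state_dim)


# Rowwise Adagrad (one FP32 scalar per row) with an FP16 weights cache:
# ceil(4 / 2) = 2 elements, padded to 4, so each cache row gains 4 columns.
print(cache_row_dim(max_D=128, state_nbytes=4, weights_dtype=torch.float16))  # 132
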
9 changes: 6 additions & 3 deletions fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py
@@ -2678,7 +2678,7 @@ def test_raw_embedding_streaming_prefetch_pipeline(

@given(**default_st)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_ssd_fetch_from_l1_sp_w_row_ids(
def test_ssd_fetch_from_l1_sp_w_row_ids_weight(
self,
T: int,
D: int,
@@ -2928,7 +2928,10 @@ def test_ssd_fetch_from_l1_sp_w_row_ids_opt_only(
indices.numel(),
1,
device=emb.current_device,
dtype=emb.optimizer.dtype(),
# NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids unit test
# working until it is upgraded to support optimizers with multiple
# states and dtypes
dtype=torch.float32,
)
linearized_indices = []
for f, idxes in enumerate(indices_list):
@@ -3034,7 +3037,7 @@ def copy_opt_states_hook(

torch.testing.assert_close(
split_optimizer_states[t][indices].float(),
opt_states_per_tb.cpu(),
opt_states_per_tb.cpu().float(),
atol=tolerance,
rtol=tolerance,
)
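
The dtype=torch.float32 switch in the test above mirrors the temporary single-state assumption baked into fetch_from_l1_sp_w_row_ids. A small illustration of the row_dim/result_dim arithmetic, assuming an FP16 weights cache (values are illustrative):

import math

import torch

weights_dtype = torch.float16  # cache rows are stored in the weights precision
result_dtype = torch.float32   # the hack assumes a single FP32 optimizer state

# Cache elements (of weights_dtype) occupied by the optimizer state per row.
row_dim = math.ceil(result_dtype.itemsize / weights_dtype.itemsize)  # ceil(4 / 2) = 2
# The same payload expressed in elements of the result dtype handed back to
# the caller: two FP16 slots hold exactly one FP32 value.
result_dim = math.ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))  # 1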