
Commit 619b6ab

q10 authored and facebook-github-bot committed
Update the cache row dim calculation in TBE SSD (#4480)
Summary:
Pull Request resolved: #4480
X-link: facebookresearch/FBGEMM#1537

The current cache row dim calculation in TBE SSD assumes that optimizers have state sizes that are fixed relative to table dimensions. This change updates the cache row dim calculation to account for optimizers whose state sizes depend on the row length, such as Partial Rowwise Adam.

Reviewed By: sryap, emlin, jiawenliu64

Differential Revision: D77321062

fbshipit-source-id: 001002e945c03eb4d28dd35b837797cf87ebe45b
1 parent 87a03b6 commit 619b6ab
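For example, with a row dimension of D = 128 and FP32 optimizer state, Partial Rowwise Adam needs one momentum2 value plus D momentum1 values per row (4 + 128 x 4 = 516 bytes), whereas Rowwise Adagrad needs a single momentum value (4 bytes) regardless of D, so a fixed per-row state size cannot describe both.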

File tree: 3 files changed (+75, -39 lines)


fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

Lines changed: 23 additions & 13 deletions
@@ -8,7 +8,6 @@
 # pyre-strict

 import enum
-import math
 from typing import Any, Dict  # noqa: F401

 import torch
@@ -41,22 +40,33 @@ class EmbOptimType(enum.Enum):
     def __str__(self) -> str:
         return self.value

-    def state_size(self) -> int:
+    def _extract_dtype(
+        self, optimizer_state_dtypes: Dict[str, "SparseType"], name: str
+    ) -> torch.dtype:
+        if optimizer_state_dtypes is None or name not in optimizer_state_dtypes:
+            return torch.float32
+        return optimizer_state_dtypes[name].as_dtype()
+
+    def state_size_nbytes(
+        self, D: int, optimizer_state_dtypes: Dict[str, "SparseType"] = {}  # noqa: B006
+    ) -> int:
         """
         Returns the size of the data (in bytes) required to hold the optimizer
-        state (per table row), or 0 if none needed
+        state (per table row)
         """
-        return {
-            # Only holds the momentum float value per row
-            EmbOptimType.EXACT_ROWWISE_ADAGRAD: torch.float32.itemsize,
-        }.get(self, 0)
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+            # Store one value for momentum per row
+            return momentum1_dtype.itemsize

-    def state_size_dim(self, dtype: torch.dtype) -> int:
-        """
-        Returns the size of the data (in units of elements of dtype) rquired to
-        hold optimizer information (per table row)
-        """
-        return int(math.ceil(self.state_size() / dtype.itemsize))
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            momentum1_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum1")
+            momentum2_dtype = self._extract_dtype(optimizer_state_dtypes, "momentum2")
+            # Store one value for momentum2 plus D values for momentum1 per row
+            return momentum2_dtype.itemsize + (D * momentum1_dtype.itemsize)
+
+        else:
+            return 0

     def dtype(self) -> torch.dtype:
         """

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 46 additions & 23 deletions
@@ -12,6 +12,7 @@
 import functools
 import itertools
 import logging
+import math
 import os
 import tempfile
 import threading
@@ -172,6 +173,7 @@ def __init__(
         res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
         flushing_block_size: int = 2_000_000_000,  # 2GB
         table_names: Optional[List[str]] = None,
+        optimizer_state_dtypes: Dict[str, SparseType] = {},  # noqa: B006
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()

@@ -185,6 +187,7 @@ def __init__(
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
+        self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes

         # Zero collision TBE configurations
         self.kv_zch_params = kv_zch_params
@@ -987,13 +990,24 @@ def cache_row_dim(self) -> int:
         """
         if self.enable_optimizer_offloading:
             return self.max_D + pad4(
-                # Compute the number of elements of cache_dtype needed to store the
-                # optimizer state
-                self.optimizer.state_size_dim(self.weights_precision.as_dtype())
+                # Compute the number of elements of cache_dtype needed to store
+                # the optimizer state
+                self.optimizer_state_dim
             )
         else:
             return self.max_D

+    @cached_property
+    def optimizer_state_dim(self) -> int:
+        return int(
+            math.ceil(
+                self.optimizer.state_size_nbytes(
+                    self.max_D, self.optimizer_state_dtypes
+                )
+                / self.weights_precision.as_dtype().itemsize
+            )
+        )
+
     @property
     # pyre-ignore
     def ssd_db(self):
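As a rough illustration of the new sizing path, here is a standalone sketch of the optimizer_state_dim / cache_row_dim arithmetic under assumed settings (FP16 weights cache, max_D = 128, Partial Rowwise Adam with FP32 states); pad4 is re-implemented locally on the assumption that it rounds up to a multiple of 4 elements:

import math

import torch


def pad4(n: int) -> int:
    # Assumed behavior: round up to the next multiple of 4 elements
    return (n + 3) // 4 * 4


max_D = 128                   # widest embedding dim across tables (assumed)
cache_dtype = torch.float16   # weights_precision.as_dtype() (assumed FP16)

# Partial Rowwise Adam with FP32 states: one momentum2 value + D momentum1 values
state_nbytes = torch.float32.itemsize + max_D * torch.float32.itemsize  # 516 bytes

# optimizer_state_dim: number of cache_dtype elements needed to hold that state
optimizer_state_dim = int(math.ceil(state_nbytes / cache_dtype.itemsize))  # 258

# cache_row_dim with optimizer offloading: weights plus padded optimizer state
cache_row_dim = max_D + pad4(optimizer_state_dim)  # 128 + 260 = 388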
@@ -2285,9 +2299,8 @@ def split_optimizer_states(
         table_offset = 0

         dtype = self.weights_precision.as_dtype()
-        optimizer_dim = self.optimizer.state_size_dim(dtype)
         logging.info(
-            f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+            f"split_optimizer_states: {self.optimizer_state_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
         )

         for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -2345,7 +2358,7 @@ def split_optimizer_states(
                         and sorted_id_tensor[t].size(0) > 0
                         else emb_height
                     ),
-                    optimizer_dim,
+                    self.optimizer_state_dim,
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
@@ -2373,7 +2386,7 @@ def split_optimizer_states(
                 # backend will return both weight and optimizer in one tensor, read the whole tensor
                 # out could OOM CPU memory.
                 tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                    shape=[emb_height, optimizer_dim],
+                    shape=[emb_height, self.optimizer_state_dim],
                     dtype=dtype,
                     row_offset=row_offset,
                     snapshot_handle=snapshot_handle,
@@ -2652,11 +2665,7 @@ def split_embedding_weights(
                     (
                         metaheader_dim  # metaheader is already padded
                         + pad4(emb_dim)
-                        + pad4(
-                            self.optimizer.state_size_dim(
-                                self.weights_precision.as_dtype()
-                            )
-                        )
+                        + pad4(self.optimizer_state_dim)
                     )
                     if self.backend_return_whole_row
                     else emb_dim
@@ -2802,8 +2811,7 @@ def streaming_write_weight_and_id_per_table(
         # TODO: make chunk_size configurable or dynamic
         chunk_size = 10000
         row = weight_state.size(0)
-        optimizer_dim = self.optimizer.state_size_dim(dtype)
-        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        opt_state_2d = opt_state.view(dtype).view(-1, self.optimizer_state_dim)
         for i in range(0, row, chunk_size):
             length = min(chunk_size, row - i)
             chunk_buffer = torch.empty(
@@ -2813,9 +2821,9 @@ def streaming_write_weight_and_id_per_table(
                 device="cpu",
             )
             chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
-            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
-                i : i + length, :
-            ]
+            chunk_buffer[:, D_rounded : D_rounded + self.optimizer_state_dim] = (
+                opt_state_2d[i : i + length, :]
+            )
             kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

     @torch.jit.ignore
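The chunked copy above lays each row out as [weights | padding | optimizer state], with the optimizer state starting at D_rounded. A minimal sketch of that layout follows; the sizes (10 rows, 10 weight elements, 3 optimizer-state elements) are illustrative, and pad4 semantics are assumed as before:

import torch

rows, D, optimizer_state_dim = 10, 10, 3
D_rounded = (D + 3) // 4 * 4  # pad4(D) under the assumed semantics -> 12
dtype = torch.float16

weight_state = torch.randn(rows, D, dtype=dtype)
opt_state_2d = torch.randn(rows, optimizer_state_dim, dtype=dtype)

# One chunk: weights fill the first D columns, optimizer state starts at D_rounded
chunk_buffer = torch.empty(rows, D_rounded + optimizer_state_dim, dtype=dtype, device="cpu")
chunk_buffer[:, : weight_state.size(1)] = weight_state
chunk_buffer[:, D_rounded : D_rounded + optimizer_state_dim] = opt_state_2d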
@@ -3454,20 +3462,35 @@ def fetch_from_l1_sp_w_row_ids(
         Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
         @return: updated_weights/optimizer_states, mask of which rows are filled
         """
+        if not self.enable_optimizer_offloading and only_get_optimizer_states:
+            raise RuntimeError(
+                "Optimizer states are not offloaded, while only_get_optimizer_states is True"
+            )
+
+        # NOTE: Remove this once there is support for fetching multiple
+        # optimizer states in fetch_from_l1_sp_w_row_ids
+        if self.optimizer != OptimType.EXACT_ROWWISE_ADAGRAD:
+            raise RuntimeError(
+                "Only rowwise adagrad is supported in fetch_from_l1_sp_w_row_ids at the moment"
+            )
+
         with torch.no_grad():
             weights_dtype = self.weights_precision.as_dtype()
             step = self.step
-            if not self.enable_optimizer_offloading and only_get_optimizer_states:
-                raise RuntimeError(
-                    "Optimizer states are not offloaded, while only_get_optimizer_states is True"
-                )
+
             if only_get_optimizer_states:
                 start_pos = pad4(self.max_D)
-                row_dim = self.optimizer.state_size_dim(weights_dtype)
-                result_dtype = self.optimizer.dtype()
+                # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working
+                # until it is upgraded to support optimizers with multiple states
+                # and dtypes
+                row_dim = int(
+                    math.ceil(torch.float32.itemsize / weights_dtype.itemsize)
+                )
+                result_dtype = torch.float32
                 result_dim = int(
                     ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
                 )
+
             else:
                 start_pos = 0
                 # get the whole row
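To make the interim rowwise-adagrad-only path above concrete, a small worked sketch of the row_dim / result_dim arithmetic under an assumed FP16 weights cache:

import math

import torch

weights_dtype = torch.float16  # assumed weights_precision.as_dtype()
result_dtype = torch.float32   # the single FP32 momentum1 state

# Number of cache-dtype elements holding the one FP32 optimizer value per row
row_dim = int(math.ceil(torch.float32.itemsize / weights_dtype.itemsize))  # 2

# Number of result_dtype elements those cache elements map back to
result_dim = int(math.ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize)))  # 1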

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 6 additions & 3 deletions
@@ -2678,7 +2678,7 @@ def test_raw_embedding_streaming_prefetch_pipeline(

     @given(**default_st)
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_ssd_fetch_from_l1_sp_w_row_ids(
+    def test_ssd_fetch_from_l1_sp_w_row_ids_weight(
         self,
         T: int,
         D: int,
@@ -2928,7 +2928,10 @@ def test_ssd_fetch_from_l1_sp_w_row_ids_opt_only(
             indices.numel(),
             1,
             device=emb.current_device,
-            dtype=emb.optimizer.dtype(),
+            # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids unit test
+            # working until it is upgraded to support optimizers with multiple
+            # states and dtypes
+            dtype=torch.float32,
         )
         linearized_indices = []
         for f, idxes in enumerate(indices_list):
@@ -3034,7 +3037,7 @@ def copy_opt_states_hook(

         torch.testing.assert_close(
             split_optimizer_states[t][indices].float(),
-            opt_states_per_tb.cpu(),
+            opt_states_per_tb.cpu().float(),
             atol=tolerance,
             rtol=tolerance,
         )
