Commit 16cf235

q10 authored and facebook-github-bot committed
Simplify the SplitState application for optimizers in TBE SSD (#4492)
Summary:
Pull Request resolved: #4492

Currently, the split state application for optimizer states in TBE SSD presumes Exact Rowwise Adagrad, which has only one optimizer state. This change extends the split application to support optimizers with more than one state, such as Partial Rowwise Adam. The code draws inspiration from `construct_split_state()` in non-SSD TBE, but is simpler and more declarative, and thus more ergonomic to use.

Reviewed By: sryap, emlin, ionuthristodorescu

Differential Revision: D76709101

fbshipit-source-id: e258713a7c0fd5aac8de0e07fbe6e0c4d7e76947
1 parent 66eb6cf commit 16cf235
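
To see why a single offset table no longer suffices: for Partial Rowwise Adam, `momentum1` is an elementwise state (one value per embedding element) while `momentum2` is rowwise (one value per row), so their split offsets come from different cumulative sums. A minimal sketch of the two tables, using two hypothetical embedding tables:

```python
import itertools

# Hypothetical embedding specs: (rows, dims) per table
embedding_specs = [(100, 16), (200, 32)]
rows = [r for r, _ in embedding_specs]

# Rowwise states (e.g. momentum2 in Partial Rowwise Adam): one value per row
row_count_cumsum = [0] + list(itertools.accumulate(rows))
# -> [0, 100, 300]

# Elementwise states (e.g. momentum1 in Partial Rowwise Adam): one value per element
table_size_cumsum = [0] + list(
    itertools.accumulate(r * d for r, d in embedding_specs)
)
# -> [0, 1600, 8000]
```

For each state, the last cumulative entry is the total device allocation (zero when optimizer offloading is enabled) and the remaining entries are the per-table offsets.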

2 files changed: +57 -15 lines changed

fbgemm_gpu/fbgemm_gpu/split_embedding_configs.py

Lines changed: 50 additions & 1 deletion
```diff
@@ -8,10 +8,16 @@
 # pyre-strict
 
 import enum
-from typing import Any, Dict  # noqa: F401
+import itertools
+from typing import Any, Dict, List, Tuple  # noqa: F401
 
 import torch
 
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    EmbeddingLocation,
+    SplitState,
+)
+
 
 @enum.unique
 class EmbOptimType(enum.Enum):
@@ -68,6 +74,49 @@ def state_size_nbytes(
         else:
             return 0
 
+    def ssd_state_splits(
+        self,
+        embedding_specs: List[Tuple[int, int]],  # Tuple of (rows, dims)
+        optimizer_state_dtypes: Dict[str, "SparseType"] = {},  # noqa: B006
+        enable_optimizer_offloading: bool = False,
+    ) -> List[Tuple[SplitState, str, torch.dtype]]:
+        """
+        Returns the split planning for the optimizer states
+        """
+        (rows, _) = zip(*embedding_specs)
+        T_ = len(embedding_specs)
+
+        # This is the cumulative row counts for rowwise states
+        row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
+        # This is the cumulative element counts for elementwise states
+        table_size_cumsum: List[int] = [0] + list(
+            itertools.accumulate([r * d for r, d in embedding_specs])
+        )
+
+        if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD:
+            params = {"momentum1": row_count_cumsum}
+        elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM:
+            params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum}
+        else:
+            params = {}
+
+        return [
+            (
+                SplitState(
+                    dev_size=(
+                        cumsum_table[-1] if not enable_optimizer_offloading else 0
+                    ),
+                    host_size=0,
+                    uvm_size=0,
+                    placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
+                    offsets=cumsum_table[:-1],
+                ),
+                name,
+                self._extract_dtype(optimizer_state_dtypes, name),
+            )
+            for (name, cumsum_table) in params.items()
+        ]
+
     def dtype(self) -> torch.dtype:
         """
         Returns the dtype of the optimizer state
```

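For illustration, a sketch of calling the new method directly; the table shapes and dtypes below are hypothetical, and `EmbOptimType` / `SparseType` are imported from the same module as in the diff above:

```python
from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType

# Hypothetical table shapes: (rows, dims)
embedding_specs = [(100, 16), (200, 32)]

splits = EmbOptimType.PARTIAL_ROWWISE_ADAM.ssd_state_splits(
    embedding_specs,
    optimizer_state_dtypes={"momentum1": SparseType.FP32, "momentum2": SparseType.FP32},
    enable_optimizer_offloading=False,
)

# Expect two (SplitState, name, dtype) templates:
#   momentum1 -> elementwise offsets [0, 1600], dev_size 8000
#   momentum2 -> rowwise offsets [0, 100], dev_size 300
for split_state, name, dtype in splits:
    print(name, split_state.offsets, split_state.dev_size, dtype)
```
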
fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 7 additions & 14 deletions
```diff
@@ -852,21 +852,14 @@ def __init__(
             dtype=table_embedding_dtype,
         )
 
-        momentum1_offsets = [0] + list(itertools.accumulate(rows))
-        self._apply_split(
-            SplitState(
-                dev_size=(
-                    self.total_hash_size if not self.enable_optimizer_offloading else 0
-                ),
-                host_size=0,
-                uvm_size=0,
-                placements=[EmbeddingLocation.DEVICE for _ in range(T_)],
-                offsets=momentum1_offsets[:-1],
-            ),
-            "momentum1",
+        # Create the optimizer state tensors
+        for template in self.optimizer.ssd_state_splits(
+            self.embedding_specs,
+            self.optimizer_state_dtypes,
+            self.enable_optimizer_offloading,
+        ):
             # pyre-fixme[6]: For 3rd argument expected `Type[dtype]` but got `dtype`.
-            dtype=torch.float32,
-        )
+            self._apply_split(*template)
 
         # For storing current iteration data
         self.current_iter_data: Optional[IterData] = None
```

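With a single-state optimizer the new loop degenerates to the old behaviour: Exact Rowwise Adagrad yields exactly one `momentum1` template whose offsets are the rowwise cumulative sums, so the refactor is a strict generalization of the previously hard-coded `_apply_split` call. A sketch under the same hypothetical specs as above:

```python
from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType

templates = EmbOptimType.EXACT_ROWWISE_ADAGRAD.ssd_state_splits(
    [(100, 16), (200, 32)],  # hypothetical (rows, dims)
    optimizer_state_dtypes={"momentum1": SparseType.FP32},
)

# Exactly one template, matching the previously hand-written momentum1 call:
#   momentum1 -> rowwise offsets [0, 100], dev_size 300
for split_state, name, dtype in templates:
    print(name, split_state.offsets, split_state.dev_size, dtype)
```
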