
Commit 7d8feca

q10 authored and facebook-github-bot committed
Expand split_optimizer_states() to support multiple optimizer states (#4495)
Summary:
Pull Request resolved: #4495
X-link: facebookresearch/FBGEMM#1548

- Expand `split_optimizer_states()` to support multiple optimizer states. This is necessary for unit tests involving new optimizers such as Partial Rowwise Adam to work.

There are 4 cases to handle when attempting to fetch the split optimizer states:

1. The no-KV ZCH case
2. The KV ZCH case, but where `self.load_state_dict` is `True` (i.e. fall back to `self._cached_kvzch_data`)
3. The KV ZCH case, where `self.load_state_dict` is `False`, and `self.enable_optimizer_offloading` is `False`
4. The KV ZCH case, where `self.load_state_dict` is `False`, and `self.enable_optimizer_offloading` is `True`

This diff completes the handling of returning optimizer states from SSD TBE for the non-KV ZCH case (case 1). The rest will be implemented in subsequent diffs along the stack.

Reviewed By: emlin, ionuthristodorescu

Differential Revision: D77337646

fbshipit-source-id: d010b347009867cc936dc177802adbf31066526b
1 parent 2168495 commit 7d8feca
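
As a quick orientation (not part of the diff below), here is a minimal caller-side sketch of what the `List[Tensor]` → `List[List[Tensor]]` change means in practice. The `emb` handle is an assumption: an already-constructed `SSDTableBatchedEmbeddingBags` instance, with `split_optimizer_states()` called using its default arguments.

```python
# Hypothetical caller-side sketch of the new nested return shape; `emb` is
# assumed to be an SSDTableBatchedEmbeddingBags instance constructed elsewhere.
from typing import Dict, List

import torch


def momentum1_for_table(emb, t: int) -> torch.Tensor:
    # Previously: split_optimizer_states()[t] was the momentum1 tensor itself.
    # Now: each table yields a list of state tensors, so [0] selects momentum1
    # for EXACT_ROWWISE_ADAGRAD.
    states: List[List[torch.Tensor]] = emb.split_optimizer_states()
    return states[t][0]


def adam_states_for_table(emb, t: int) -> Dict[str, torch.Tensor]:
    # For PARTIAL_ROWWISE_ADAM, two states come back per table, in the order
    # (momentum1, momentum2) documented in the new docstring.
    momentum1, momentum2 = emb.split_optimizer_states()[t]
    return {"momentum1": momentum1, "momentum2": momentum2}
```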

File tree

2 files changed: +146, -40 lines

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 92 additions & 20 deletions
```diff
@@ -2230,28 +2230,77 @@ def forward(
     @torch.jit.ignore
     def _split_optimizer_states_non_kv_zch(
         self,
-    ) -> List[torch.Tensor]:
+    ) -> List[List[torch.Tensor]]:
         """
-        Returns a list of optimizer states, split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
-        so only momentum1 state is returned.
+        Returns a list of optimizer states (view), split by table.
+
+        Returns:
+            A list of list of states. Shape = (the number of tables, the number
+            of states).
+
+            The following shows the list of states (in the returned order) for
+            each optimizer:
+
+            (1) `EXACT_ROWWISE_ADAGRAD`: `momentum1` (rowwise)
+
+            (1) `PARTIAL_ROWWISE_ADAM`: `momentum1`, `momentum2` (rowwise)
         """
+
         logging.info("_split_optimizer_states_non_kv_zch")
-        (rows, _) = zip(*self.embedding_specs)
 
-        rows_cumsum = [0] + list(itertools.accumulate(rows))
+        # Row count per table
+        (rows, dims) = zip(*self.embedding_specs)
+        # Cumulative row counts per table for rowwise states
+        row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows))
+        # Cumulative element counts per table for elementwise states
+        elem_count_cumsum: List[int] = [0] + list(
+            itertools.accumulate([r * d for r, d in self.embedding_specs])
+        )
+
+        # pyre-ignore[53]
+        def _slice(tensor: Tensor, t: int, rowwise: bool) -> Tensor:
+            d: int = dims[t]
+            e: int = rows[t]
+
+            if not rowwise:
+                # Optimizer state is element-wise - compute the table offset for
+                # the table, view the slice as 2D tensor
+                return tensor.detach()[
+                    elem_count_cumsum[t] : elem_count_cumsum[t + 1]
+                ].view(-1, d)
+            else:
+                # Optimizer state is row-wise - fetch elements in range and view
+                # slice as 1D
+                return tensor.detach()[
+                    row_count_cumsum[t] : row_count_cumsum[t + 1]
+                ].view(e)
 
-        return [
-            self.momentum1_dev.detach()[rows_cumsum[t] : rows_cumsum[t + 1]].view(row)
-            for t, row in enumerate(rows)
-        ]
+        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            return [
+                [_slice(self.momentum1_dev, t, rowwise=True)]
+                for t, _ in enumerate(rows)
+            ]
+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            return [
+                [
+                    _slice(self.momentum1_dev, t, rowwise=False),
+                    # pyre-ignore[6]
+                    _slice(self.momentum2_dev, t, rowwise=True),
+                ]
+                for t, _ in enumerate(rows)
+            ]
+        else:
+            raise NotImplementedError(
+                f"Getting optimizer states is not supported for {self.optimizer}"
+            )
 
     @torch.jit.export
     def split_optimizer_states(
         self,
         sorted_id_tensor: Optional[List[torch.Tensor]] = None,
         no_snapshot: bool = True,
         should_flush: bool = False,
-    ) -> List[torch.Tensor]:
+    ) -> List[List[torch.Tensor]]:
         """
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD,
         so only momentum1 state is returned.
@@ -2277,7 +2326,16 @@ def split_optimizer_states(
                 self._cached_kvzch_data is not None
                 and self._cached_kvzch_data.cached_optimizer_state_per_table
             ), "optimizer state is not initialized for load checkpointing"
-            return self._cached_kvzch_data.cached_optimizer_state_per_table
+
+            # NOTE: This is a temporary hack to have split_optimizer_states return a
+            # List[List[Tensor]] instead of List[Tensor] to match the behavior of
+            # _split_optimizer_states_non_kv_zch. This should be removed after
+            # proper support for multiple optimizers is added for the
+            # enable_optimizer_offloading=True case.
+            return [
+                [opt]
+                for opt in self._cached_kvzch_data.cached_optimizer_state_per_table
+            ]
 
         logging.info(
             f"split_optimizer_states for KV ZCH: {no_snapshot=}, {should_flush=}"
@@ -2401,7 +2459,13 @@ def split_optimizer_states(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
             f"num ids list: {None if not sorted_id_tensor else [ids.numel() for ids in sorted_id_tensor]}"
         )
-        return opt_list
+
+        # NOTE: This is a temporary hack to have split_optimizer_states return a
+        # List[List[Tensor]] instead of List[Tensor] to match the behavior of
+        # _split_optimizer_states_non_kv_zch. This should be removed after
+        # proper support for multiple optimizers is added for the
+        # enable_optimizer_offloading=True case.
+        return [[opt] for opt in opt_list]
 
     @torch.jit.export
     def get_offloaded_optimizer_states(
@@ -2438,14 +2502,22 @@ def get_optimizer_state(
         Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD
         so only momentum1 state is returned.
         """
-        return [
-            ({"momentum1": states})
-            for states in self.split_optimizer_states(
-                sorted_id_tensor=sorted_id_tensor,
-                no_snapshot=no_snapshot,
-                should_flush=should_flush,
+        states_list = self.split_optimizer_states(
+            sorted_id_tensor=sorted_id_tensor,
+            no_snapshot=no_snapshot,
+            should_flush=should_flush,
+        )
+
+        if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
+            keys = ["momentum1"]
+        elif self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
+            keys = ["momentum1", "momentum2"]
+        else:
+            raise NotImplementedError(
+                f"Getting optimizer states is not supported for {self.optimizer}"
             )
-        ]
+
+        return [dict(zip(keys, states)) for states in states_list]
 
     @torch.jit.export
     def debug_split_embedding_weights(self) -> List[torch.Tensor]:
@@ -2460,7 +2532,7 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
         splits = []
         get_event = torch.cuda.Event()
 
-        for t, (row, dim) in enumerate(self.embedding_specs):
+        for t, (row, _) in enumerate(self.embedding_specs):
             weights = torch.empty(
                 (row, self.max_D), dtype=self.weights_precision.as_dtype()
             )
```
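
To make the row-wise vs. element-wise slicing performed by the new `_slice` helper easier to follow, here is a small self-contained sketch with made-up table shapes; only the cumulative-offset arithmetic mirrors the code above, and the flat buffers merely stand in for `momentum1_dev` / `momentum2_dev`.

```python
# Standalone illustration of the slicing scheme used by _slice() above.
# Table shapes are arbitrary toy values; row-wise states are viewed as 1D per
# table, element-wise states as (rows, dim) per table.
import itertools
from typing import List, Tuple

import torch

embedding_specs: List[Tuple[int, int]] = [(4, 2), (3, 5)]  # (rows, dim) per table
rows, dims = zip(*embedding_specs)

row_count_cumsum = [0] + list(itertools.accumulate(rows))
elem_count_cumsum = [0] + list(itertools.accumulate(r * d for r, d in embedding_specs))

# Flat buffers standing in for the device-side optimizer state tensors
rowwise_state = torch.arange(sum(rows), dtype=torch.float32)
elementwise_state = torch.arange(sum(r * d for r, d in embedding_specs), dtype=torch.float32)

for t, (r, d) in enumerate(embedding_specs):
    per_row = rowwise_state[row_count_cumsum[t] : row_count_cumsum[t + 1]].view(r)
    per_elem = elementwise_state[elem_count_cumsum[t] : elem_count_cumsum[t + 1]].view(-1, d)
    print(f"table {t}: rowwise {tuple(per_row.shape)}, elementwise {tuple(per_elem.shape)}")
```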

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 54 additions & 20 deletions
```diff
@@ -757,7 +757,7 @@ def execute_ssd_forward_(
 
     def split_optimizer_states_(
         self, emb: SSDTableBatchedEmbeddingBags
-    ) -> List[torch.Tensor]:
+    ) -> List[List[torch.Tensor]]:
         _, bucket_asc_ids_list, _ = emb.split_embedding_weights(
             no_snapshot=False, should_flush=True
         )
@@ -962,7 +962,7 @@ def test_ssd_backward_adagrad(
             # pyre-fixme[16]: Optional type has no attribute `float`.
             ref_optimizer_state = emb_ref[f].weight.grad.float().to_dense().pow(2)
             torch.testing.assert_close(
-                split_optimizer_states[t].float(),
+                split_optimizer_states[t][0].float(),
                 ref_optimizer_state.mean(dim=1),
                 atol=tolerance,
                 rtol=tolerance,
@@ -978,7 +978,7 @@ def test_ssd_backward_adagrad(
                 emb_r.weight.float(),
                 value=-lr,
                 tensor1=emb_r.weight.grad.float().to_dense(),
-                tensor2=split_optimizer_states[t]
+                tensor2=split_optimizer_states[t][0]
                 .float()
                 .sqrt_()
                 .add_(eps)
@@ -1113,7 +1113,10 @@ def test_ssd_emb_state_dict(
                 emb_r.weight.float(),
                 value=-lr,
                 tensor1=emb_r.weight.grad.float().to_dense(), # pyre-ignore[16]
-                tensor2=split_optimizer_states[table_index]
+                # NOTE: The [0] index is a hack since the test is fixed to use
+                # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+                # be upgraded in the future to support multiple optimizers
+                tensor2=split_optimizer_states[table_index][0]
                 .float()
                 .sqrt_()
                 .add_(eps)
@@ -1188,7 +1191,8 @@ def execute_ssd_cache_pipeline_( # noqa C901
         )
 
         optimizer_states_ref = [
-            s.clone().float() for s in self.split_optimizer_states_(emb)
+            [s.clone().float() for s in states]
+            for states in self.split_optimizer_states_(emb)
         ]
 
         Es = [emb.embedding_specs[t][0] for t in range(T)]
@@ -1334,8 +1338,11 @@ def _prefetch(b_it: int) -> int:
         # Compare optimizer states
         split_optimizer_states = self.split_optimizer_states_(emb)
         for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
-            optim_state_r = optimizer_states_ref[t]
-            optim_state_t = split_optimizer_states[t]
+            optim_state_r = optimizer_states_ref[t][0]
+            # NOTE: The [0] index is a hack since the test is fixed to use
+            # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+            # be upgraded in the future to support multiple optimizers
+            optim_state_t = split_optimizer_states[t][0]
             emb_r = emb_ref[f]
 
             optim_state_r.add_(
@@ -1753,7 +1760,10 @@ def test_kv_emb_state_dict(
                 dim=1
             )
             torch.testing.assert_close(
-                split_optimizer_states[t].float(),
+                # NOTE: The [0] index is a hack since the test is fixed to use
+                # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+                # be upgraded in the future to support multiple optimizers
+                split_optimizer_states[t][0].float(),
                 ref_opt_mean.cpu(),
                 atol=tolerance,
                 rtol=tolerance,
@@ -1799,8 +1809,11 @@ def test_kv_emb_state_dict(
                 .to_dense()[bucket_asc_ids_list[table_index].view(-1)]
             )
             self.assertLess(table_index, len(emb_state_dict_list))
-            assert len(split_optimizer_states[table_index]) == num_ids
-            opt = split_optimizer_states[table_index]
+            assert len(split_optimizer_states[table_index][0]) == num_ids
+            # NOTE: The [0] index is a hack since the test is fixed to use
+            # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+            # be upgraded in the future to support multiple optimizers
+            opt = split_optimizer_states[table_index][0]
             new_ref_weight = torch.addcdiv(
                 emb_r_w.float(),
                 value=-lr,
@@ -1985,7 +1998,10 @@ def test_kv_opt_state_w_offloading(
             # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `__getitem__`.
             ref_kv_opt = ref_optimizer_state[bucket_asc_ids_list[t]].view(-1)
             torch.testing.assert_close(
-                split_optimizer_states[t].float(),
+                # NOTE: The [0] index is a hack since the test is fixed to use
+                # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+                # be upgraded in the future to support multiple optimizers
+                split_optimizer_states[t][0].float(),
                 ref_kv_opt,
                 atol=tolerance,
                 rtol=tolerance,
@@ -2031,8 +2047,11 @@ def test_kv_opt_state_w_offloading(
                 .to_dense()[bucket_asc_ids_list[table_index].view(-1)]
             )
             self.assertLess(table_index, len(emb_state_dict_list))
-            assert len(split_optimizer_states[table_index]) == num_ids
-            opt = split_optimizer_states[table_index]
+            # NOTE: The [0] index is a hack since the test is fixed to use
+            # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+            # be upgraded in the future to support multiple optimizers
+            assert len(split_optimizer_states[table_index][0]) == num_ids
+            opt = split_optimizer_states[table_index][0]
             new_ref_weight = torch.addcdiv(
                 emb_r_w.float(),
                 value=-lr,
@@ -2221,7 +2240,10 @@ def test_kv_state_dict_w_backend_return_whole_row(
             # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `__getitem__`.
             ref_kv_opt = ref_optimizer_state[bucket_asc_ids_list[t]].view(-1)
             opt = (
-                split_optimizer_states[t]
+                # NOTE: The [0] index is a hack since the test is fixed to use
+                # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+                # be upgraded in the future to support multiple optimizers
+                split_optimizer_states[t][0]
                 .narrow(0, 0, bucket_asc_ids_list[t].size(0))
                 .view(-1)
                 .view(torch.float32)
@@ -2276,7 +2298,10 @@ def test_kv_state_dict_w_backend_return_whole_row(
                 .to_dense()[bucket_asc_ids_list[table_index].view(-1)]
            )
             self.assertLess(table_index, len(emb_state_dict_list))
-            assert split_optimizer_states[table_index].size(0) == num_ids
+            # NOTE: The [0] index is a hack since the test is fixed to use
+            # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+            # be upgraded in the future to support multiple optimizers
+            assert split_optimizer_states[table_index][0].size(0) == num_ids
             new_ref_weight = torch.addcdiv(
                 emb_r_w.float(),
                 value=-lr,
@@ -2501,9 +2526,12 @@ def test_apply_kv_state_dict(
                 # pyre-fixme[16]: Undefined attribute: Item `torch._tensor.Tensor` of `typing.Uni...
                 emb_state_dict_list[i].full_tensor()
             )
+            # NOTE: The [0] index is a hack since the test is fixed to use
+            # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+            # be upgraded in the future to support multiple optimizers
             # pyre-ignore [16]
             emb2._cached_kvzch_data.cached_optimizer_state_per_table[i].copy_(
-                split_optimizer_states[i]
+                split_optimizer_states[i][0]
             )
             # pyre-ignore [16]
             emb2._cached_kvzch_data.cached_id_tensor_per_table[i].copy_(
@@ -2547,8 +2575,8 @@ def test_apply_kv_state_dict(
                 rtol=tolerance,
             )
             torch.testing.assert_close(
-                split_optimizer_states[t][sorted_ids.indices],
-                split_optimizer_states2[t][sorted_ids2.indices],
+                split_optimizer_states[t][0][sorted_ids.indices],
+                split_optimizer_states2[t][0][sorted_ids2.indices],
                 atol=tolerance,
                 rtol=tolerance,
             )
@@ -2820,7 +2848,10 @@ def copy_weights_hook(
             # pyre-fixme[16]: Optional type has no attribute `float`.
             ref_optimizer_state = emb_ref[f].weight.grad.float().to_dense().pow(2)
             torch.testing.assert_close(
-                split_optimizer_states[t].float(),
+                # NOTE: The [0] index is a hack since the test is fixed to use
+                # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+                # be upgraded in the future to support multiple optimizers
+                split_optimizer_states[t][0].float(),
                 ref_optimizer_state.mean(dim=1),
                 atol=tolerance,
                 rtol=tolerance,
@@ -3036,7 +3067,10 @@ def copy_opt_states_hook(
                 cursor += local_idxes.numel()
 
                 torch.testing.assert_close(
-                    split_optimizer_states[t][indices].float(),
+                    # NOTE: The [0] index is a hack since the test is fixed to use
+                    # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
+                    # be upgraded in the future to support multiple optimizers
+                    split_optimizer_states[t][0][indices].float(),
                     opt_states_per_tb.cpu().float(),
                     atol=tolerance,
                     rtol=tolerance,
```
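
The test changes above hard-code the `[0]` index because these tests run EXACT_ROWWISE_ADAGRAD; the name-keyed view produced by `get_optimizer_state()` avoids that positional indexing. Below is a toy, self-contained sketch of the `dict(zip(keys, states))` mapping that function performs, with made-up tensors standing in for real optimizer state.

```python
# Toy illustration of the keys-to-states mapping added to get_optimizer_state():
# per-table state lists become {name: tensor} dicts.
from typing import Dict, List

import torch

# Pretend output of split_optimizer_states() for two tables under
# PARTIAL_ROWWISE_ADAM: [momentum1 (element-wise), momentum2 (row-wise)] each.
states_list: List[List[torch.Tensor]] = [
    [torch.zeros(4, 2), torch.zeros(4)],
    [torch.zeros(3, 5), torch.zeros(3)],
]
keys = ["momentum1", "momentum2"]

optimizer_state: List[Dict[str, torch.Tensor]] = [
    dict(zip(keys, states)) for states in states_list
]
assert optimizer_state[0]["momentum2"].shape == (4,)
```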
