
Commit 79f47fa

q10 authored and facebook-github-bot committed
Remove debug_split_optimizer_states (#4397)
Summary:
Pull Request resolved: #4397

Remove `debug_split_optimizer_states` from training.py, since it has been superseded by `split_optimizer_states`.

Reviewed By: duduyi2013

Differential Revision: D77256897

fbshipit-source-id: 2897a7dba9b0477be9ee00feacf1273e6158027e
1 parent 8ba5184 commit 79f47fa
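
For callers that still use the removed debug helper, the test changes below suggest the replacement pattern: query the bucket-sorted ids via `split_embedding_weights`, then hand them to `split_optimizer_states`. A minimal sketch, assuming the `fbgemm_gpu.tbe.ssd` import path and a hypothetical `get_split_optimizer_states` wrapper name; it simply mirrors the `split_optimizer_states_` helper added to ssd_split_tbe_training_test.py in this commit:

# Migration sketch (assumed wrapper name and import path), mirroring the
# split_optimizer_states_ test helper added in this commit.
from typing import List

import torch

from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags


def get_split_optimizer_states(
    emb: SSDTableBatchedEmbeddingBags,
) -> List[torch.Tensor]:
    # Before this commit (removed API): states came back as
    # (tensor, table_input_id_start, table_input_id_end) tuples:
    #   states = [s for (s, _, _) in emb.debug_split_optimizer_states()]
    #
    # After: fetch the bucket-sorted id list from split_embedding_weights,
    # then pass it to the public split_optimizer_states API.
    _, bucket_asc_ids_list, _ = emb.split_embedding_weights(
        no_snapshot=False, should_flush=True
    )
    return emb.split_optimizer_states(
        bucket_asc_ids_list, no_snapshot=False, should_flush=True
    )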

File tree

2 files changed: +30 −76 lines


fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 13 additions & 67 deletions
@@ -940,7 +940,6 @@ def __init__(
             )
             # pyre-ignore
             self.stats_reporter.register_stats(self.l2_num_cache_misses_stats_name)
-            # pyre-ignore
             self.stats_reporter.register_stats(self.l2_num_cache_lookups_stats_name)
             self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name)
             self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
@@ -1083,7 +1082,7 @@ def _report_duration(
         """
         recorded_itr, stream_cnt, report_val = self.prefetch_duration_us
         duration = dur_ms
-        if time_unit == "us":  # pyre-ignore
+        if time_unit == "us":
             duration = dur_ms * 1000
         if it_step == recorded_itr:
             report_val = max(report_val, duration)
@@ -1124,7 +1123,6 @@ def record_function_via_dummy_profile_factory(

         def func(
             name: str,
-            # pyre-ignore[2]
             fn: Callable[..., Any],
             *args: Any,
             **kwargs: Any,
@@ -2168,64 +2166,10 @@ def forward(
         )

         if self.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD:
-            # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
             return invokers.lookup_rowwise_adagrad_ssd.invoke(
                 common_args, self.optimizer_args, momentum1
             )

-    @torch.jit.ignore
-    def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
-        """
-        Returns a list of optimizer states, table_input_id_start, table_input_id_end, split by table
-        Testing only
-        """
-        (rows, _) = zip(*self.embedding_specs)
-
-        rows_cumsum = [0] + list(itertools.accumulate(rows))
-        if self.kv_zch_params:
-            opt_list = []
-            table_offset = 0
-            for t, row in enumerate(rows):
-                # pyre-ignore
-                bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[t]
-                # pyre-ignore
-                bucket_size = self.kv_zch_params.bucket_sizes[t]
-                table_input_id_start = (
-                    min(bucket_id_start * bucket_size, row) + table_offset
-                )
-                table_input_id_end = (
-                    min(bucket_id_end * bucket_size, row) + table_offset
-                )
-
-                # TODO: this is a hack for preallocated optimizer, update this part once we have optimizer offloading
-                unlinearized_id_tensor = self._ssd_db.get_keys_in_range_by_snapshot(
-                    table_input_id_start,
-                    table_input_id_end,
-                    0,  # no need for table offest, as optimizer is preallocated using table offset
-                    None,
-                )
-                sorted_offsets, _ = torch.sort(unlinearized_id_tensor.view(-1))
-                opt_list.append(
-                    (
-                        self.momentum1_dev.detach()[sorted_offsets],
-                        table_input_id_start - table_offset,
-                        table_input_id_end - table_offset,
-                    )
-                )
-                table_offset += row
-            return opt_list
-        else:
-            return [
-                (
-                    self.momentum1_dev.detach()[
-                        rows_cumsum[t] : rows_cumsum[t + 1]
-                    ].view(row),
-                    -1,
-                    -1,
-                )
-                for t, row in enumerate(rows)
-            ]
-
     @torch.jit.ignore
     def _split_optimizer_states_non_kv_zch(
         self,
@@ -2344,6 +2288,7 @@ def split_optimizer_states(
             table_offset += emb_height
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
+            # pyre-ignore [16]
             f"num ids list: {[ids.numel() for ids in sorted_id_tensor]}"
         )
         return opt_list
@@ -2623,6 +2568,7 @@ def split_embedding_weights(
             )
         if self.kv_zch_params is not None:
             logging.info(
+                # pyre-ignore [16]
                 f"num ids list: {[ids.numel() for ids in bucket_sorted_id_splits]}"
             )

@@ -2946,7 +2892,7 @@ def _report_ssd_l1_cache_stats(self) -> None:
                     / passed_steps
                 ),
             )
-            # pyre-ignore
+
            self.stats_reporter.report_data_amount(
                iteration_step=self.step,
                event_name=f"ssd_tbe.prefetch.cache_stats.{stat_index.name.lower()}",
@@ -2973,35 +2919,35 @@ def _report_ssd_io_stats(self) -> None:
         bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration[3]
         flush_write_dur = ssd_io_duration[4]

-        # pyre-ignore
+        # pyre-ignore [16]
         self.stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="ssd.io_duration.read_us",
             duration_ms=ssd_read_dur_us,
             time_unit="us",
         )
-        # pyre-ignore
+
         self.stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="ssd.io_duration.write.fwd_rocksdb_read_us",
             duration_ms=fwd_rocksdb_read_dur,
             time_unit="us",
         )
-        # pyre-ignore
+
         self.stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="ssd.io_duration.write.fwd_l1_eviction_us",
             duration_ms=fwd_l1_eviction_dur,
             time_unit="us",
         )
-        # pyre-ignore
+
         self.stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us",
             duration_ms=bwd_l1_cnflct_miss_write_back_dur,
             time_unit="us",
         )
-        # pyre-ignore
+
         self.stats_reporter.report_duration(
             iteration_step=self.step,
             event_name="ssd.io_duration.write.flush_write_us",
@@ -3023,25 +2969,25 @@ def _report_ssd_mem_usage(
         memtable_usage = mem_usage_list[2]
         block_cache_pinned_usage = mem_usage_list[3]

-        # pyre-ignore
+        # pyre-ignore [16]
         self.stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="ssd.mem_usage.block_cache",
             data_bytes=block_cache_usage,
         )
-        # pyre-ignore
+
         self.stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="ssd.mem_usage.estimate_table_reader",
             data_bytes=estimate_table_reader_usage,
         )
-        # pyre-ignore
+
         self.stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="ssd.mem_usage.memtable",
             data_bytes=memtable_usage,
         )
-        # pyre-ignore
+
         self.stats_reporter.report_data_amount(
             iteration_step=self.step,
             event_name="ssd.mem_usage.block_cache_pinned",

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 17 additions & 9 deletions
@@ -526,8 +526,6 @@ def generate_ssd_tbes(
         if share_table:
             # autograd with shared embedding only works for exact
             table_to_replicate = T // 2
-            # pyre-fixme[6]: For 2nd param expected `Embedding` but got
-            # `Union[Embedding, EmbeddingBag]`.
             feature_table_map.insert(table_to_replicate, table_to_replicate)
             emb_ref.insert(table_to_replicate, emb_ref[table_to_replicate])

@@ -746,6 +744,17 @@ def execute_ssd_forward_(
         )
         return output_ref_list, output

+    def split_optimizer_states_(
+        self, emb: SSDTableBatchedEmbeddingBags
+    ) -> List[torch.Tensor]:
+        _, bucket_asc_ids_list, _ = emb.split_embedding_weights(
+            no_snapshot=False, should_flush=True
+        )
+
+        return emb.split_optimizer_states(
+            bucket_asc_ids_list, no_snapshot=False, should_flush=True
+        )
+
     @given(
         **default_st, backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM])
     )
@@ -937,7 +946,7 @@ def test_ssd_backward_adagrad(
         )

         # Compare optimizer states
-        split_optimizer_states = [s for (s, _, _) in emb.debug_split_optimizer_states()]
+        split_optimizer_states = self.split_optimizer_states_(emb)
         for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
             # pyre-fixme[16]: Optional type has no attribute `float`.
             ref_optimizer_state = emb_ref[f].weight.grad.float().to_dense().pow(2)
@@ -1079,7 +1088,7 @@ def test_ssd_emb_state_dict(
             else 1.0e-2
         )

-        split_optimizer_states = [s for (s, _, _) in emb.debug_split_optimizer_states()]
+        split_optimizer_states = self.split_optimizer_states_(emb)
         emb.flush()

         # Compare emb state dict with expected values from nn.EmbeddingBag
@@ -1168,7 +1177,7 @@ def execute_ssd_cache_pipeline_( # noqa C901
         )

         optimizer_states_ref = [
-            s.clone().float() for (s, _, _) in emb.debug_split_optimizer_states()
+            s.clone().float() for s in self.split_optimizer_states_(emb)
         ]

         Es = [emb.embedding_specs[t][0] for t in range(T)]
@@ -1312,15 +1321,12 @@ def _prefetch(b_it: int) -> int:
         emb.flush()

         # Compare optimizer states
-        split_optimizer_states = [
-            s for (s, _, _) in emb.debug_split_optimizer_states()
-        ]
+        split_optimizer_states = self.split_optimizer_states_(emb)
         for f, t in self.get_physical_table_arg_indices_(emb.feature_table_map):
             optim_state_r = optimizer_states_ref[t]
             optim_state_t = split_optimizer_states[t]
             emb_r = emb_ref[f]

-            # pyre-fixme[16]: Optional type has no attribute `float`.
             optim_state_r.add_(
                 # pyre-fixme[16]: `Optional` has no attribute `float`.
                 emb_r.weight.grad.float()
@@ -2252,7 +2258,9 @@ def test_apply_kv_state_dict(
         )

         torch.testing.assert_close(
+            # pyre-ignore [16]
             emb_state_dict_list[t].full_tensor()[sorted_ids.indices],
+            # pyre-ignore [16]
             emb_state_dict_list2[t].full_tensor()[sorted_ids2.indices],
             atol=tolerance,
             rtol=tolerance,
