@@ -940,7 +940,6 @@ def __init__(
940
940
)
941
941
# pyre-ignore
942
942
self .stats_reporter .register_stats (self .l2_num_cache_misses_stats_name )
943
- # pyre-ignore
944
943
self .stats_reporter .register_stats (self .l2_num_cache_lookups_stats_name )
945
944
self .stats_reporter .register_stats (self .l2_num_cache_evictions_stats_name )
946
945
self .stats_reporter .register_stats (self .l2_cache_free_mem_stats_name )
@@ -1083,7 +1082,7 @@ def _report_duration(
1083
1082
"""
1084
1083
recorded_itr , stream_cnt , report_val = self .prefetch_duration_us
1085
1084
duration = dur_ms
1086
- if time_unit == "us" : # pyre-ignore
1085
+ if time_unit == "us" :
1087
1086
duration = dur_ms * 1000
1088
1087
if it_step == recorded_itr :
1089
1088
report_val = max (report_val , duration )
@@ -1124,7 +1123,6 @@ def record_function_via_dummy_profile_factory(
1124
1123
1125
1124
def func (
1126
1125
name : str ,
1127
- # pyre-ignore[2]
1128
1126
fn : Callable [..., Any ],
1129
1127
* args : Any ,
1130
1128
** kwargs : Any ,
@@ -2168,64 +2166,10 @@ def forward(
2168
2166
)
2169
2167
2170
2168
if self .optimizer == OptimType .EXACT_ROWWISE_ADAGRAD :
2171
- # pyre-fixme[7]: Expected `Tensor` but got implicit return value of `None`.
2172
2169
return invokers .lookup_rowwise_adagrad_ssd .invoke (
2173
2170
common_args , self .optimizer_args , momentum1
2174
2171
)
2175
2172
2176
- @torch .jit .ignore
2177
- def debug_split_optimizer_states (self ) -> List [Tuple [torch .Tensor , int , int ]]:
2178
- """
2179
- Returns a list of optimizer states, table_input_id_start, table_input_id_end, split by table
2180
- Testing only
2181
- """
2182
- (rows , _ ) = zip (* self .embedding_specs )
2183
-
2184
- rows_cumsum = [0 ] + list (itertools .accumulate (rows ))
2185
- if self .kv_zch_params :
2186
- opt_list = []
2187
- table_offset = 0
2188
- for t , row in enumerate (rows ):
2189
- # pyre-ignore
2190
- bucket_id_start , bucket_id_end = self .kv_zch_params .bucket_offsets [t ]
2191
- # pyre-ignore
2192
- bucket_size = self .kv_zch_params .bucket_sizes [t ]
2193
- table_input_id_start = (
2194
- min (bucket_id_start * bucket_size , row ) + table_offset
2195
- )
2196
- table_input_id_end = (
2197
- min (bucket_id_end * bucket_size , row ) + table_offset
2198
- )
2199
-
2200
- # TODO: this is a hack for preallocated optimizer, update this part once we have optimizer offloading
2201
- unlinearized_id_tensor = self ._ssd_db .get_keys_in_range_by_snapshot (
2202
- table_input_id_start ,
2203
- table_input_id_end ,
2204
- 0 , # no need for table offest, as optimizer is preallocated using table offset
2205
- None ,
2206
- )
2207
- sorted_offsets , _ = torch .sort (unlinearized_id_tensor .view (- 1 ))
2208
- opt_list .append (
2209
- (
2210
- self .momentum1_dev .detach ()[sorted_offsets ],
2211
- table_input_id_start - table_offset ,
2212
- table_input_id_end - table_offset ,
2213
- )
2214
- )
2215
- table_offset += row
2216
- return opt_list
2217
- else :
2218
- return [
2219
- (
2220
- self .momentum1_dev .detach ()[
2221
- rows_cumsum [t ] : rows_cumsum [t + 1 ]
2222
- ].view (row ),
2223
- - 1 ,
2224
- - 1 ,
2225
- )
2226
- for t , row in enumerate (rows )
2227
- ]
2228
-
2229
2173
@torch .jit .ignore
2230
2174
def _split_optimizer_states_non_kv_zch (
2231
2175
self ,
@@ -2344,6 +2288,7 @@ def split_optimizer_states(
2344
2288
table_offset += emb_height
2345
2289
logging .info (
2346
2290
f"KV ZCH tables split_optimizer_states query latency: { (time .time () - start_time ) * 1000 } ms, "
2291
+ # pyre-ignore [16]
2347
2292
f"num ids list: { [ids .numel () for ids in sorted_id_tensor ]} "
2348
2293
)
2349
2294
return opt_list
@@ -2623,6 +2568,7 @@ def split_embedding_weights(
2623
2568
)
2624
2569
if self .kv_zch_params is not None :
2625
2570
logging .info (
2571
+ # pyre-ignore [16]
2626
2572
f"num ids list: { [ids .numel () for ids in bucket_sorted_id_splits ]} "
2627
2573
)
2628
2574
@@ -2946,7 +2892,7 @@ def _report_ssd_l1_cache_stats(self) -> None:
2946
2892
/ passed_steps
2947
2893
),
2948
2894
)
2949
- # pyre-ignore
2895
+
2950
2896
self .stats_reporter .report_data_amount (
2951
2897
iteration_step = self .step ,
2952
2898
event_name = f"ssd_tbe.prefetch.cache_stats.{ stat_index .name .lower ()} " ,
@@ -2973,35 +2919,35 @@ def _report_ssd_io_stats(self) -> None:
2973
2919
bwd_l1_cnflct_miss_write_back_dur = ssd_io_duration [3 ]
2974
2920
flush_write_dur = ssd_io_duration [4 ]
2975
2921
2976
- # pyre-ignore
2922
+ # pyre-ignore [16]
2977
2923
self .stats_reporter .report_duration (
2978
2924
iteration_step = self .step ,
2979
2925
event_name = "ssd.io_duration.read_us" ,
2980
2926
duration_ms = ssd_read_dur_us ,
2981
2927
time_unit = "us" ,
2982
2928
)
2983
- # pyre-ignore
2929
+
2984
2930
self .stats_reporter .report_duration (
2985
2931
iteration_step = self .step ,
2986
2932
event_name = "ssd.io_duration.write.fwd_rocksdb_read_us" ,
2987
2933
duration_ms = fwd_rocksdb_read_dur ,
2988
2934
time_unit = "us" ,
2989
2935
)
2990
- # pyre-ignore
2936
+
2991
2937
self .stats_reporter .report_duration (
2992
2938
iteration_step = self .step ,
2993
2939
event_name = "ssd.io_duration.write.fwd_l1_eviction_us" ,
2994
2940
duration_ms = fwd_l1_eviction_dur ,
2995
2941
time_unit = "us" ,
2996
2942
)
2997
- # pyre-ignore
2943
+
2998
2944
self .stats_reporter .report_duration (
2999
2945
iteration_step = self .step ,
3000
2946
event_name = "ssd.io_duration.write.bwd_l1_cnflct_miss_write_back_us" ,
3001
2947
duration_ms = bwd_l1_cnflct_miss_write_back_dur ,
3002
2948
time_unit = "us" ,
3003
2949
)
3004
- # pyre-ignore
2950
+
3005
2951
self .stats_reporter .report_duration (
3006
2952
iteration_step = self .step ,
3007
2953
event_name = "ssd.io_duration.write.flush_write_us" ,
@@ -3023,25 +2969,25 @@ def _report_ssd_mem_usage(
3023
2969
memtable_usage = mem_usage_list [2 ]
3024
2970
block_cache_pinned_usage = mem_usage_list [3 ]
3025
2971
3026
- # pyre-ignore
2972
+ # pyre-ignore [16]
3027
2973
self .stats_reporter .report_data_amount (
3028
2974
iteration_step = self .step ,
3029
2975
event_name = "ssd.mem_usage.block_cache" ,
3030
2976
data_bytes = block_cache_usage ,
3031
2977
)
3032
- # pyre-ignore
2978
+
3033
2979
self .stats_reporter .report_data_amount (
3034
2980
iteration_step = self .step ,
3035
2981
event_name = "ssd.mem_usage.estimate_table_reader" ,
3036
2982
data_bytes = estimate_table_reader_usage ,
3037
2983
)
3038
- # pyre-ignore
2984
+
3039
2985
self .stats_reporter .report_data_amount (
3040
2986
iteration_step = self .step ,
3041
2987
event_name = "ssd.mem_usage.memtable" ,
3042
2988
data_bytes = memtable_usage ,
3043
2989
)
3044
- # pyre-ignore
2990
+
3045
2991
self .stats_reporter .report_data_amount (
3046
2992
iteration_step = self .step ,
3047
2993
event_name = "ssd.mem_usage.block_cache_pinned" ,
0 commit comments