@@ -915,6 +915,12 @@ def __init__(
915
915
self .l2_cache_capacity_stats_name : str = (
916
916
f"l2_cache.mem.tbe_id{ tbe_unique_id } .capacity_bytes"
917
917
)
918
+ self .dram_kv_actual_used_chunk_bytes_stats_name : str = (
919
+ f"dram_kv.mem.tbe_id{ tbe_unique_id } .actual_used_chunk_bytes"
920
+ )
921
+ self .dram_kv_allocated_bytes_stats_name : str = (
922
+ f"dram_kv.mem.tbe_id{ tbe_unique_id } .allocated_bytes"
923
+ )
918
924
if self .stats_reporter :
919
925
self .ssd_prefetch_read_timer = AsyncSeriesTimer (
920
926
functools .partial (
@@ -939,6 +945,10 @@ def __init__(
939
945
self .stats_reporter .register_stats (self .l2_num_cache_evictions_stats_name )
940
946
self .stats_reporter .register_stats (self .l2_cache_free_mem_stats_name )
941
947
self .stats_reporter .register_stats (self .l2_cache_capacity_stats_name )
948
+ self .stats_reporter .register_stats (self .dram_kv_allocated_bytes_stats_name )
949
+ self .stats_reporter .register_stats (
950
+ self .dram_kv_actual_used_chunk_bytes_stats_name
951
+ )
942
952
943
953
self .bounds_check_version : int = get_bounds_check_version_for_platform ()
944
954
@@ -1890,7 +1900,7 @@ def _prefetch( # noqa C901
1890
1900
self .ssd_cache_stats = torch .add (
1891
1901
self .ssd_cache_stats , self .local_ssd_cache_stats
1892
1902
)
1893
- self ._report_ssd_stats ()
1903
+ self ._report_kv_backend_stats ()
1894
1904
1895
1905
# Fetch data from SSD
1896
1906
if linear_cache_indices .numel () > 0 :
@@ -2881,7 +2891,7 @@ def prepare_inputs(
2881
2891
return indices , offsets , per_sample_weights , vbe_metadata
2882
2892
2883
2893
@torch .jit .ignore
2884
- def _report_ssd_stats (self ) -> None :
2894
+ def _report_kv_backend_stats (self ) -> None :
2885
2895
"""
2886
2896
All ssd stats report function entrance
2887
2897
"""
@@ -2896,6 +2906,8 @@ def _report_ssd_stats(self) -> None:
2896
2906
self ._report_ssd_io_stats ()
2897
2907
self ._report_ssd_mem_usage ()
2898
2908
self ._report_l2_cache_perf_stats ()
2909
+ if self .backend_type == BackendType .DRAM :
2910
+ self ._report_dram_kv_perf_stats ()
2899
2911
2900
2912
@torch .jit .ignore
2901
2913
def _report_ssd_l1_cache_stats (self ) -> None :
@@ -3162,6 +3174,184 @@ def _report_l2_cache_perf_stats(self) -> None:
3162
3174
time_unit = "us" ,
3163
3175
)
3164
3176
3177
+ @torch .jit .ignore
3178
+ def _report_dram_kv_perf_stats (self ) -> None :
3179
+ """
3180
+ EmbeddingKVDB will hold stats for DRAM cache performance in fwd/bwd
3181
+ this function fetch the stats from EmbeddingKVDB and report it with stats_reporter
3182
+ """
3183
+ if self .stats_reporter is None :
3184
+ return
3185
+
3186
+ stats_reporter : TBEStatsReporter = self .stats_reporter
3187
+ if not stats_reporter .should_report (self .step ):
3188
+ return
3189
+
3190
+ dram_kv_perf_stats = self .ssd_db .get_dram_kv_perf (
3191
+ self .step , stats_reporter .report_interval # pyre-ignore
3192
+ )
3193
+
3194
+ if len (dram_kv_perf_stats ) != 22 :
3195
+ logging .error ("dram cache perf stats should have 22 elements" )
3196
+ return
3197
+
3198
+ dram_read_duration = dram_kv_perf_stats [0 ]
3199
+ dram_read_sharding_duration = dram_kv_perf_stats [1 ]
3200
+ dram_read_cache_hit_copy_duration = dram_kv_perf_stats [2 ]
3201
+ dram_read_fill_row_storage_duration = dram_kv_perf_stats [3 ]
3202
+ dram_read_lookup_cache_duration = dram_kv_perf_stats [4 ]
3203
+ dram_read_acquire_lock_duration = dram_kv_perf_stats [5 ]
3204
+ dram_read_missing_load = dram_kv_perf_stats [6 ]
3205
+ dram_write_sharing_duration = dram_kv_perf_stats [7 ]
3206
+
3207
+ dram_fwd_l1_eviction_write_duration = dram_kv_perf_stats [8 ]
3208
+ dram_fwd_l1_eviction_write_allocate_duration = dram_kv_perf_stats [9 ]
3209
+ dram_fwd_l1_eviction_write_cache_copy_duration = dram_kv_perf_stats [10 ]
3210
+ dram_fwd_l1_eviction_write_lookup_cache_duration = dram_kv_perf_stats [11 ]
3211
+ dram_fwd_l1_eviction_write_acquire_lock_duration = dram_kv_perf_stats [12 ]
3212
+ dram_fwd_l1_eviction_write_missing_load = dram_kv_perf_stats [13 ]
3213
+
3214
+ dram_bwd_l1_cnflct_miss_write_duration = dram_kv_perf_stats [14 ]
3215
+ dram_bwd_l1_cnflct_miss_write_allocate_duration = dram_kv_perf_stats [15 ]
3216
+ dram_bwd_l1_cnflct_miss_write_cache_copy_duration = dram_kv_perf_stats [16 ]
3217
+ dram_bwd_l1_cnflct_miss_write_lookup_cache_duration = dram_kv_perf_stats [17 ]
3218
+ dram_bwd_l1_cnflct_miss_write_acquire_lock_duration = dram_kv_perf_stats [18 ]
3219
+ dram_bwd_l1_cnflct_miss_write_missing_load = dram_kv_perf_stats [19 ]
3220
+
3221
+ dram_kv_allocated_bytes = dram_kv_perf_stats [20 ]
3222
+ dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats [21 ]
3223
+
3224
+ stats_reporter .report_duration (
3225
+ iteration_step = self .step ,
3226
+ event_name = "dram_kv.perf.get.dram_read_duration_us" ,
3227
+ duration_ms = dram_read_duration ,
3228
+ time_unit = "us" ,
3229
+ )
3230
+ stats_reporter .report_duration (
3231
+ iteration_step = self .step ,
3232
+ event_name = "dram_kv.perf.get.dram_read_sharding_duration_us" ,
3233
+ duration_ms = dram_read_sharding_duration ,
3234
+ time_unit = "us" ,
3235
+ )
3236
+ stats_reporter .report_duration (
3237
+ iteration_step = self .step ,
3238
+ event_name = "dram_kv.perf.get.dram_read_cache_hit_copy_duration_us" ,
3239
+ duration_ms = dram_read_cache_hit_copy_duration ,
3240
+ time_unit = "us" ,
3241
+ )
3242
+ stats_reporter .report_duration (
3243
+ iteration_step = self .step ,
3244
+ event_name = "dram_kv.perf.get.dram_read_fill_row_storage_duration_us" ,
3245
+ duration_ms = dram_read_fill_row_storage_duration ,
3246
+ time_unit = "us" ,
3247
+ )
3248
+ stats_reporter .report_duration (
3249
+ iteration_step = self .step ,
3250
+ event_name = "dram_kv.perf.get.dram_read_lookup_cache_duration_us" ,
3251
+ duration_ms = dram_read_lookup_cache_duration ,
3252
+ time_unit = "us" ,
3253
+ )
3254
+ stats_reporter .report_duration (
3255
+ iteration_step = self .step ,
3256
+ event_name = "dram_kv.perf.get.dram_read_acquire_lock_duration_us" ,
3257
+ duration_ms = dram_read_acquire_lock_duration ,
3258
+ time_unit = "us" ,
3259
+ )
3260
+ stats_reporter .report_data_amount (
3261
+ iteration_step = self .step ,
3262
+ event_name = "dram_kv.perf.get.dram_read_missing_load" ,
3263
+ data_bytes = dram_read_missing_load ,
3264
+ )
3265
+ stats_reporter .report_duration (
3266
+ iteration_step = self .step ,
3267
+ event_name = "dram_kv.perf.set.dram_write_sharing_duration_us" ,
3268
+ duration_ms = dram_write_sharing_duration ,
3269
+ time_unit = "us" ,
3270
+ )
3271
+
3272
+ stats_reporter .report_duration (
3273
+ iteration_step = self .step ,
3274
+ event_name = "dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us" ,
3275
+ duration_ms = dram_fwd_l1_eviction_write_duration ,
3276
+ time_unit = "us" ,
3277
+ )
3278
+ stats_reporter .report_duration (
3279
+ iteration_step = self .step ,
3280
+ event_name = "dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us" ,
3281
+ duration_ms = dram_fwd_l1_eviction_write_allocate_duration ,
3282
+ time_unit = "us" ,
3283
+ )
3284
+ stats_reporter .report_duration (
3285
+ iteration_step = self .step ,
3286
+ event_name = "dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us" ,
3287
+ duration_ms = dram_fwd_l1_eviction_write_cache_copy_duration ,
3288
+ time_unit = "us" ,
3289
+ )
3290
+ stats_reporter .report_duration (
3291
+ iteration_step = self .step ,
3292
+ event_name = "dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us" ,
3293
+ duration_ms = dram_fwd_l1_eviction_write_lookup_cache_duration ,
3294
+ time_unit = "us" ,
3295
+ )
3296
+ stats_reporter .report_duration (
3297
+ iteration_step = self .step ,
3298
+ event_name = "dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us" ,
3299
+ duration_ms = dram_fwd_l1_eviction_write_acquire_lock_duration ,
3300
+ time_unit = "us" ,
3301
+ )
3302
+ stats_reporter .report_data_amount (
3303
+ iteration_step = self .step ,
3304
+ event_name = "dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load" ,
3305
+ data_bytes = dram_fwd_l1_eviction_write_missing_load ,
3306
+ )
3307
+
3308
+ stats_reporter .report_duration (
3309
+ iteration_step = self .step ,
3310
+ event_name = "dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us" ,
3311
+ duration_ms = dram_bwd_l1_cnflct_miss_write_duration ,
3312
+ time_unit = "us" ,
3313
+ )
3314
+ stats_reporter .report_duration (
3315
+ iteration_step = self .step ,
3316
+ event_name = "dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us" ,
3317
+ duration_ms = dram_bwd_l1_cnflct_miss_write_allocate_duration ,
3318
+ time_unit = "us" ,
3319
+ )
3320
+ stats_reporter .report_duration (
3321
+ iteration_step = self .step ,
3322
+ event_name = "dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us" ,
3323
+ duration_ms = dram_bwd_l1_cnflct_miss_write_cache_copy_duration ,
3324
+ time_unit = "us" ,
3325
+ )
3326
+ stats_reporter .report_duration (
3327
+ iteration_step = self .step ,
3328
+ event_name = "dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us" ,
3329
+ duration_ms = dram_bwd_l1_cnflct_miss_write_lookup_cache_duration ,
3330
+ time_unit = "us" ,
3331
+ )
3332
+ stats_reporter .report_duration (
3333
+ iteration_step = self .step ,
3334
+ event_name = "dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us" ,
3335
+ duration_ms = dram_bwd_l1_cnflct_miss_write_acquire_lock_duration ,
3336
+ time_unit = "us" ,
3337
+ )
3338
+ stats_reporter .report_data_amount (
3339
+ iteration_step = self .step ,
3340
+ event_name = "dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load" ,
3341
+ data_bytes = dram_bwd_l1_cnflct_miss_write_missing_load ,
3342
+ )
3343
+
3344
+ stats_reporter .report_data_amount (
3345
+ iteration_step = self .step ,
3346
+ event_name = self .dram_kv_allocated_bytes_stats_name ,
3347
+ data_bytes = dram_kv_allocated_bytes ,
3348
+ )
3349
+ stats_reporter .report_data_amount (
3350
+ iteration_step = self .step ,
3351
+ event_name = self .dram_kv_actual_used_chunk_bytes_stats_name ,
3352
+ data_bytes = dram_kv_actual_used_chunk_bytes ,
3353
+ )
3354
+
3165
3355
# pyre-ignore
3166
3356
def _recording_to_timer (
3167
3357
self , timer : Optional [AsyncSeriesTimer ], ** kwargs : Any
0 commit comments