add monitoring metrics for dram cache perf #4383

Open
wants to merge 1 commit into base: main
194 changes: 192 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
@@ -915,6 +915,12 @@ def __init__(
self.l2_cache_capacity_stats_name: str = (
f"l2_cache.mem.tbe_id{tbe_unique_id}.capacity_bytes"
)
self.dram_kv_actual_used_chunk_bytes_stats_name: str = (
f"dram_kv.mem.tbe_id{tbe_unique_id}.actual_used_chunk_bytes"
)
self.dram_kv_allocated_bytes_stats_name: str = (
f"dram_kv.mem.tbe_id{tbe_unique_id}.allocated_bytes"
)
if self.stats_reporter:
self.ssd_prefetch_read_timer = AsyncSeriesTimer(
functools.partial(
@@ -939,6 +945,10 @@ def __init__(
self.stats_reporter.register_stats(self.l2_num_cache_evictions_stats_name)
self.stats_reporter.register_stats(self.l2_cache_free_mem_stats_name)
self.stats_reporter.register_stats(self.l2_cache_capacity_stats_name)
self.stats_reporter.register_stats(self.dram_kv_allocated_bytes_stats_name)
self.stats_reporter.register_stats(
self.dram_kv_actual_used_chunk_bytes_stats_name
)

self.bounds_check_version: int = get_bounds_check_version_for_platform()
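
The metric names registered above are consumed through the TBEStatsReporter interface, and this diff relies on only a handful of its members: `register_stats`, `should_report`, `report_duration`, `report_data_amount`, and the `report_interval` attribute. As a point of reference, here is a minimal logging-backed sketch of that call pattern; `LoggingTBEStatsReporter` is a hypothetical stand-in for illustration, not a class in fbgemm_gpu:

```python
import logging
from typing import Set


class LoggingTBEStatsReporter:
    """Hypothetical stand-in for TBEStatsReporter that logs instead of exporting."""

    def __init__(self, report_interval: int = 100) -> None:
        self.report_interval = report_interval
        self._registered: Set[str] = set()

    def register_stats(self, stats_name: str) -> None:
        # Called once per metric at TBE init, as in __init__ above.
        self._registered.add(stats_name)

    def should_report(self, step: int) -> bool:
        # Throttle reporting to every `report_interval` steps.
        return step % self.report_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        time_unit: str = "ms",
    ) -> None:
        logging.info(
            "[step %d] %s = %s %s", iteration_step, event_name, duration_ms, time_unit
        )

    def report_data_amount(
        self, iteration_step: int, event_name: str, data_bytes: int
    ) -> None:
        logging.info("[step %d] %s = %d bytes", iteration_step, event_name, data_bytes)
```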

@@ -1890,7 +1900,7 @@ def _prefetch( # noqa C901
self.ssd_cache_stats = torch.add(
self.ssd_cache_stats, self.local_ssd_cache_stats
)
self._report_ssd_stats()
self._report_kv_backend_stats()

# Fetch data from SSD
if linear_cache_indices.numel() > 0:
@@ -2881,7 +2891,7 @@ def prepare_inputs(
return indices, offsets, per_sample_weights, vbe_metadata

@torch.jit.ignore
def _report_ssd_stats(self) -> None:
def _report_kv_backend_stats(self) -> None:
"""
Entry point for reporting all SSD/KV-backend stats.
"""
@@ -2896,6 +2906,8 @@ def _report_ssd_stats(self) -> None:
self._report_ssd_io_stats()
self._report_ssd_mem_usage()
self._report_l2_cache_perf_stats()
if self.backend_type == BackendType.DRAM:
self._report_dram_kv_perf_stats()

@torch.jit.ignore
def _report_ssd_l1_cache_stats(self) -> None:
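
Taken together, the two hunks above rename `_report_ssd_stats` to `_report_kv_backend_stats` and gate the new DRAM reporting on the backend type. A self-contained toy sketch of that dispatch shape (the real BackendType enum lives in fbgemm_gpu; its member values here are assumed):

```python
from enum import Enum


class BackendType(Enum):
    # Mirrors the enum imported by training.py; member values here are assumed.
    SSD = 0
    DRAM = 1


class KVStatsDispatchSketch:
    """Toy consolidation of the _report_kv_backend_stats flow in the diff above."""

    def __init__(self, backend_type: BackendType) -> None:
        self.backend_type = backend_type

    def report_kv_backend_stats(self) -> None:
        # Stats common to every KV backend (L1 cache, IO, mem usage, L2 perf).
        self.report_common_stats()
        # DRAM perf counters only exist for the DRAM backend, hence the gate.
        if self.backend_type == BackendType.DRAM:
            self.report_dram_kv_perf_stats()

    def report_common_stats(self) -> None:
        print("ssd l1/io/mem/l2 stats")

    def report_dram_kv_perf_stats(self) -> None:
        print("dram kv perf stats")


KVStatsDispatchSketch(BackendType.DRAM).report_kv_backend_stats()
```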
@@ -3162,6 +3174,184 @@ def _report_l2_cache_perf_stats(self) -> None:
time_unit="us",
)

@torch.jit.ignore
def _report_dram_kv_perf_stats(self) -> None:
"""
EmbeddingKVDB holds DRAM cache performance stats for fwd/bwd.
This function fetches those stats from EmbeddingKVDB and reports them via stats_reporter.
"""
if self.stats_reporter is None:
return

stats_reporter: TBEStatsReporter = self.stats_reporter
if not stats_reporter.should_report(self.step):
return

dram_kv_perf_stats = self.ssd_db.get_dram_kv_perf(
self.step, stats_reporter.report_interval # pyre-ignore
)

if len(dram_kv_perf_stats) != 22:
logging.error(
"dram cache perf stats should have 22 elements, got %d",
len(dram_kv_perf_stats),
)
return

dram_read_duration = dram_kv_perf_stats[0]
dram_read_sharding_duration = dram_kv_perf_stats[1]
dram_read_cache_hit_copy_duration = dram_kv_perf_stats[2]
dram_read_fill_row_storage_duration = dram_kv_perf_stats[3]
dram_read_lookup_cache_duration = dram_kv_perf_stats[4]
dram_read_acquire_lock_duration = dram_kv_perf_stats[5]
dram_read_missing_load = dram_kv_perf_stats[6]
dram_write_sharding_duration = dram_kv_perf_stats[7]

dram_fwd_l1_eviction_write_duration = dram_kv_perf_stats[8]
dram_fwd_l1_eviction_write_allocate_duration = dram_kv_perf_stats[9]
dram_fwd_l1_eviction_write_cache_copy_duration = dram_kv_perf_stats[10]
dram_fwd_l1_eviction_write_lookup_cache_duration = dram_kv_perf_stats[11]
dram_fwd_l1_eviction_write_acquire_lock_duration = dram_kv_perf_stats[12]
dram_fwd_l1_eviction_write_missing_load = dram_kv_perf_stats[13]

dram_bwd_l1_cnflct_miss_write_duration = dram_kv_perf_stats[14]
dram_bwd_l1_cnflct_miss_write_allocate_duration = dram_kv_perf_stats[15]
dram_bwd_l1_cnflct_miss_write_cache_copy_duration = dram_kv_perf_stats[16]
dram_bwd_l1_cnflct_miss_write_lookup_cache_duration = dram_kv_perf_stats[17]
dram_bwd_l1_cnflct_miss_write_acquire_lock_duration = dram_kv_perf_stats[18]
dram_bwd_l1_cnflct_miss_write_missing_load = dram_kv_perf_stats[19]

dram_kv_allocated_bytes = dram_kv_perf_stats[20]
dram_kv_actual_used_chunk_bytes = dram_kv_perf_stats[21]

stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_duration_us",
duration_ms=dram_read_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_sharding_duration_us",
duration_ms=dram_read_sharding_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_cache_hit_copy_duration_us",
duration_ms=dram_read_cache_hit_copy_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_fill_row_storage_duration_us",
duration_ms=dram_read_fill_row_storage_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_lookup_cache_duration_us",
duration_ms=dram_read_lookup_cache_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_acquire_lock_duration_us",
duration_ms=dram_read_acquire_lock_duration,
time_unit="us",
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.get.dram_read_missing_load",
data_bytes=dram_read_missing_load,
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_write_sharing_duration_us",
duration_ms=dram_write_sharing_duration,
time_unit="us",
)

stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_duration_us",
duration_ms=dram_fwd_l1_eviction_write_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_allocate_duration_us",
duration_ms=dram_fwd_l1_eviction_write_allocate_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_cache_copy_duration_us",
duration_ms=dram_fwd_l1_eviction_write_cache_copy_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_lookup_cache_duration_us",
duration_ms=dram_fwd_l1_eviction_write_lookup_cache_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_acquire_lock_duration_us",
duration_ms=dram_fwd_l1_eviction_write_acquire_lock_duration,
time_unit="us",
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_fwd_l1_eviction_write_missing_load",
data_bytes=dram_fwd_l1_eviction_write_missing_load,
)

stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_duration_us",
duration_ms=dram_bwd_l1_cnflct_miss_write_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_allocate_duration_us",
duration_ms=dram_bwd_l1_cnflct_miss_write_allocate_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_cache_copy_duration_us",
duration_ms=dram_bwd_l1_cnflct_miss_write_cache_copy_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_lookup_cache_duration_us",
duration_ms=dram_bwd_l1_cnflct_miss_write_lookup_cache_duration,
time_unit="us",
)
stats_reporter.report_duration(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_acquire_lock_duration_us",
duration_ms=dram_bwd_l1_cnflct_miss_write_acquire_lock_duration,
time_unit="us",
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name="dram_kv.perf.set.dram_bwd_l1_cnflct_miss_write_missing_load",
data_bytes=dram_bwd_l1_cnflct_miss_write_missing_load,
)

stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_allocated_bytes_stats_name,
data_bytes=dram_kv_allocated_bytes,
)
stats_reporter.report_data_amount(
iteration_step=self.step,
event_name=self.dram_kv_actual_used_chunk_bytes_stats_name,
data_bytes=dram_kv_actual_used_chunk_bytes,
)

# pyre-ignore
def _recording_to_timer(
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
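_report_dram_kv_perf_stats above unpacks the flat 22-element list positionally, matching the order in which `get_dram_kv_perf` returns it. A minimal sketch of an alternative, assuming the same element order, replaces the magic indices with a dataclass; `DramKVPerfStats` is hypothetical and not part of this PR:

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class DramKVPerfStats:
    """Hypothetical container; field order must match get_dram_kv_perf's output."""

    dram_read_duration: float
    dram_read_sharding_duration: float
    dram_read_cache_hit_copy_duration: float
    dram_read_fill_row_storage_duration: float
    dram_read_lookup_cache_duration: float
    dram_read_acquire_lock_duration: float
    dram_read_missing_load: float
    dram_write_sharding_duration: float
    dram_fwd_l1_eviction_write_duration: float
    dram_fwd_l1_eviction_write_allocate_duration: float
    dram_fwd_l1_eviction_write_cache_copy_duration: float
    dram_fwd_l1_eviction_write_lookup_cache_duration: float
    dram_fwd_l1_eviction_write_acquire_lock_duration: float
    dram_fwd_l1_eviction_write_missing_load: float
    dram_bwd_l1_cnflct_miss_write_duration: float
    dram_bwd_l1_cnflct_miss_write_allocate_duration: float
    dram_bwd_l1_cnflct_miss_write_cache_copy_duration: float
    dram_bwd_l1_cnflct_miss_write_lookup_cache_duration: float
    dram_bwd_l1_cnflct_miss_write_acquire_lock_duration: float
    dram_bwd_l1_cnflct_miss_write_missing_load: float
    dram_kv_allocated_bytes: float
    dram_kv_actual_used_chunk_bytes: float

    @classmethod
    def from_list(cls, raw: List[float]) -> "DramKVPerfStats":
        # Fails loudly if the backend contract (22 elements) ever changes.
        assert len(raw) == len(fields(cls)), f"expected {len(fields(cls))} stats"
        return cls(*raw)


stats = DramKVPerfStats.from_list([0.0] * 22)  # smoke-test the layout
```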
17 changes: 14 additions & 3 deletions fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h
@@ -58,13 +58,24 @@ class SynchronizedShardedMap {
return shards_.size();
}

auto getUsedMemSize() const {
auto getUsedMemSizeInBytes() const {
size_t used_mem_size = 0;
size_t block_size = mempools_[0]->get_aligned_block_size();
for (size_t i = 0; i < shards_.size(); ++i) {
auto rlmap = shards_[i].rlock();
int64_t mempool_idx = i % mempools_.size();
// only calculate the sizes of K, V and block that are used
used_mem_size += rlmap->size() * (sizeof(K) + sizeof(V) + block_size);
if (mempools_[mempool_idx]->get_allocated_chunk_bytes() > 0) {
auto rlmap = shards_[i].rlock();
used_mem_size += rlmap->size() * (sizeof(K) + sizeof(V) + block_size);
}
}
return used_mem_size;
}

auto getActualUsedChunkInBytes() const {
size_t used_mem_size = 0;
for (size_t i = 0; i < mempools_.size(); ++i) {
used_mem_size += mempools_[i]->get_allocated_chunk_bytes();
}
return used_mem_size;
}
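
For intuition on the two accessors: getActualUsedChunkInBytes sums the chunk bytes each mempool has actually allocated, while getUsedMemSizeInBytes estimates the bytes occupied by live rows (key + value + aligned block each), skipping shards whose pool has not allocated any chunk. A small Python model (the chunk size, sizeof(K)/sizeof(V), block size, and one-pool-per-shard layout are all invented for the example) shows why the two diverge when chunks are allocated but sparsely filled:

```python
# Illustrative model only; all sizes below are assumptions, not fbgemm constants.
SIZEOF_K, SIZEOF_V, BLOCK_SIZE = 8, 8, 512  # bytes per key, value, aligned block
CHUNK_BYTES = 1 << 20  # assume each mempool grows in 1 MiB chunks

# Two shards: one holds 100 rows; the other is empty but its pool kept a chunk.
shard_row_counts = [100, 0]
pool_allocated_chunks = [1, 1]

# getActualUsedChunkInBytes: what the pools have actually allocated.
allocated = sum(n * CHUNK_BYTES for n in pool_allocated_chunks)

# getUsedMemSizeInBytes: what live rows occupy, with the > 0 guard from the diff.
used = sum(
    rows * (SIZEOF_K + SIZEOF_V + BLOCK_SIZE)
    for rows, chunks in zip(shard_row_counts, pool_allocated_chunks)
    if chunks > 0
)

print(allocated, used)  # 2097152 vs 52800: allocated can far exceed actual use
```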