Skip to content

Commit 9e343d1

Browse files
YuzeDaiMetafacebook-github-bot
authored andcommitted
Introduce Weighted UVM Caching Stats Report (#1623)
Summary: Pull Request resolved: #1623 previous cache stats report (e.g. cache miss rate, unique rate, etc) didn't take the number of requested indices of each TBE op into consideration, which could easily report a cache miss rate that is inaccurate relative to the real situation, especially with unevenly requested TBEs. Now we update the cache stats report using the weighted caching ratios by accumulating all the TBE's cache data within the same round of lookup call. Reviewed By: doehyun Differential Revision: D43729139 fbshipit-source-id: 491b8613e70cf17c3d0f89dbd5c808ee9fe33587
1 parent 6a63116 commit 9e343d1

File tree

1 file changed

+70
-16
lines changed

1 file changed

+70
-16
lines changed

fbgemm_gpu/codegen/embedding_forward_quantized_host.cpp

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <ATen/core/op_registration/op_registration.h>
1010
#include <ATen/cuda/CUDAContext.h>
1111
#include <torch/library.h>
12+
#include <algorithm>
1213
#include "c10/core/ScalarType.h"
1314
#ifdef FBCODE_CAFFE2
1415
#include "common/stats/Stats.h"
@@ -73,27 +74,61 @@ void process_uvm_cache_stats(
7374
const int64_t total_cache_hash_size,
7475
const int64_t call_count,
7576
const bool gather_uvm_stats,
76-
const Tensor& uvm_cache_stats) {
77+
const Tensor& uvm_cache_stats,
78+
const bool populate_uvm_stats) {
7779
if (gather_uvm_stats) {
80+
static std::mutex cache_mutex;
81+
82+
// uvm_cache_stats_counters is a vector of size 4, storing the cumulated
83+
// cache stats. Each element represents different counter respectively:
84+
// uvm_cache_stats_counters[0]: num_req_indices
85+
// uvm_cache_stats_counters[1]: num_unique_indices
86+
// uvm_cache_stats_counters[2]: num_unique_misses
87+
// uvm_cache_stats_counters[3]: num_unique_conflict_misses
88+
// They should be zero-out after the calculated rates are populated into
89+
// cache counters.
90+
static std::vector<int64_t> uvm_cache_stats_counters(4);
91+
7892
// Export cache stats.
7993
auto uvm_cache_stats_cpu = uvm_cache_stats.cpu();
8094
auto* uvm_cache_stats_ptr = uvm_cache_stats_cpu.data_ptr<int32_t>();
8195
if (uvm_cache_stats_ptr[1] > 0) {
8296
// Report cache stats in per-mille.
83-
double num_requested_indices =
84-
static_cast<double>(uvm_cache_stats_ptr[1]);
85-
double unique_rate = static_cast<double>(uvm_cache_stats_ptr[2] * 1000) /
86-
num_requested_indices;
87-
double unique_miss_rate =
88-
static_cast<double>(uvm_cache_stats_ptr[3] * 1000) /
89-
num_requested_indices;
90-
double unique_conflict_miss_rate =
91-
static_cast<double>(uvm_cache_stats_ptr[4] * 1000) /
92-
num_requested_indices;
93-
STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
94-
STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
95-
STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
96-
unique_conflict_miss_rate);
97+
{
98+
// Add cache stats values into the cumulated variables.
99+
std::lock_guard<std::mutex> guard(cache_mutex);
100+
std::transform(
101+
uvm_cache_stats_counters.begin(),
102+
uvm_cache_stats_counters.end(),
103+
uvm_cache_stats_ptr + 1,
104+
uvm_cache_stats_counters.begin(),
105+
std::plus<int>());
106+
107+
// Calculate cache related ratios based on the cumulated numbers and
108+
// push them into the counter pools.
109+
if (populate_uvm_stats && uvm_cache_stats_counters[0] > 0) {
110+
double unique_rate =
111+
static_cast<double>(uvm_cache_stats_counters[1]) /
112+
uvm_cache_stats_counters[0] * 1000;
113+
double unique_miss_rate =
114+
static_cast<double>(uvm_cache_stats_counters[2]) /
115+
uvm_cache_stats_counters[0] * 1000;
116+
double unique_conflict_miss_rate =
117+
static_cast<double>(uvm_cache_stats_counters[3]) /
118+
uvm_cache_stats_counters[0] * 1000;
119+
STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
120+
STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
121+
STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
122+
unique_conflict_miss_rate);
123+
124+
// Fill all the elements of the vector uvm_cache_stats_counters as 0
125+
// to zero out the cumulated counters.
126+
std::fill(
127+
uvm_cache_stats_counters.begin(),
128+
uvm_cache_stats_counters.end(),
129+
0);
130+
}
131+
}
97132
}
98133
if (call_count % FLAGS_tbe_uvm_cache_stats_print_out_period == 0) {
99134
LOG(INFO) << "$Stats [" << signature << "] "
@@ -358,6 +393,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
358393
cache_hash_size_cumsum.value(), indices, offsets);
359394

360395
bool gather_uvm_stats = false;
396+
// populate_uvm_stats indicates whether to calculate cache related ratios,
397+
// using the data from cumulated counters, and populate them into the cache
398+
// stats pools to get the percentile stats. We want to calculate the weighted
399+
// cache ratios, taking the # req indices of each TBE as the weight, so we
400+
// will populate stats when we think the current lookup is for the last TBE
401+
// call of the same round.
402+
bool populate_uvm_stats = true;
361403
Tensor uvm_cache_stats =
362404
at::empty({0}, lxu_cache_weights.value().options().dtype(at::kInt));
363405
#ifdef FBCODE_CAFFE2
@@ -370,6 +412,17 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
370412
}
371413
tbe_call_count[signature]++;
372414
call_count = tbe_call_count[signature];
415+
416+
// populate_uvm_stats is used as an indicator whether to push the cache
417+
// related ratios calculated from cumulative counters into the cache stats
418+
// pools. We want to wait until all the known TBE ops' data has been included
419+
// to get the weighted ratios.
420+
for (const auto& [sig, count] : tbe_call_count) {
421+
if (count < call_count) {
422+
populate_uvm_stats = false;
423+
break;
424+
}
425+
}
373426
}
374427

375428
if (FLAGS_tbe_uvm_cache_stat_report > 0 &&
@@ -413,7 +466,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
413466
total_cache_hash_size.value(),
414467
call_count,
415468
gather_uvm_stats,
416-
uvm_cache_stats);
469+
uvm_cache_stats,
470+
populate_uvm_stats);
417471
#endif
418472
}
419473

0 commit comments

Comments
 (0)