 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/library.h>
+#include <algorithm>
 #include "c10/core/ScalarType.h"
 #ifdef FBCODE_CAFFE2
 #include "common/stats/Stats.h"
@@ -73,27 +74,61 @@ void process_uvm_cache_stats(
     const int64_t total_cache_hash_size,
     const int64_t call_count,
     const bool gather_uvm_stats,
-    const Tensor& uvm_cache_stats) {
+    const Tensor& uvm_cache_stats,
+    const bool populate_uvm_stats) {
   if (gather_uvm_stats) {
+    static std::mutex cache_mutex;
+
+    // uvm_cache_stats_counters is a vector of size 4, storing the cumulated
+    // cache stats. Each element represents a different counter:
+    // uvm_cache_stats_counters[0]: num_req_indices
+    // uvm_cache_stats_counters[1]: num_unique_indices
+    // uvm_cache_stats_counters[2]: num_unique_misses
+    // uvm_cache_stats_counters[3]: num_unique_conflict_misses
+    // They should be zeroed out after the calculated rates are populated
+    // into the cache counters.
+    static std::vector<int64_t> uvm_cache_stats_counters(4);
+
     // Export cache stats.
     auto uvm_cache_stats_cpu = uvm_cache_stats.cpu();
     auto* uvm_cache_stats_ptr = uvm_cache_stats_cpu.data_ptr<int32_t>();
     if (uvm_cache_stats_ptr[1] > 0) {
       // Report cache stats in per-mille.
-      double num_requested_indices =
-          static_cast<double>(uvm_cache_stats_ptr[1]);
-      double unique_rate = static_cast<double>(uvm_cache_stats_ptr[2] * 1000) /
-          num_requested_indices;
-      double unique_miss_rate =
-          static_cast<double>(uvm_cache_stats_ptr[3] * 1000) /
-          num_requested_indices;
-      double unique_conflict_miss_rate =
-          static_cast<double>(uvm_cache_stats_ptr[4] * 1000) /
-          num_requested_indices;
-      STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
-      STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
-      STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
-          unique_conflict_miss_rate);
+      {
+        // Add cache stats values into the cumulated counters.
+        std::lock_guard<std::mutex> guard(cache_mutex);
+        std::transform(
+            uvm_cache_stats_counters.begin(),
+            uvm_cache_stats_counters.end(),
+            uvm_cache_stats_ptr + 1,
+            uvm_cache_stats_counters.begin(),
+            std::plus<int64_t>());
+
+        // Calculate cache-related ratios based on the cumulated numbers and
+        // push them into the counter pools.
+        if (populate_uvm_stats && uvm_cache_stats_counters[0] > 0) {
+          double unique_rate =
+              static_cast<double>(uvm_cache_stats_counters[1]) /
+              uvm_cache_stats_counters[0] * 1000;
+          double unique_miss_rate =
+              static_cast<double>(uvm_cache_stats_counters[2]) /
+              uvm_cache_stats_counters[0] * 1000;
+          double unique_conflict_miss_rate =
+              static_cast<double>(uvm_cache_stats_counters[3]) /
+              uvm_cache_stats_counters[0] * 1000;
+          STATS_tbe_uvm_cache_unique_rate.addValue(unique_rate);
+          STATS_tbe_uvm_cache_unique_miss_rate.addValue(unique_miss_rate);
+          STATS_tbe_uvm_cache_conflict_unique_miss_rate.addValue(
+              unique_conflict_miss_rate);
+
+          // Zero out the cumulated counters now that the rates have been
+          // pushed into the stats pools.
+          std::fill(
+              uvm_cache_stats_counters.begin(),
+              uvm_cache_stats_counters.end(),
+              0);
+        }
+      }
     }
     if (call_count % FLAGS_tbe_uvm_cache_stats_print_out_period == 0) {
       LOG(INFO) << "$Stats [" << signature << "] "
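For context, a minimal, self-contained sketch of the accumulate-then-report pattern this hunk introduces: a mutex-guarded static counter vector, element-wise accumulation via `std::transform`/`std::plus`, and per-mille rates computed from the cumulated totals. The function name, the stand-in stats vector, and the `main` driver are illustrative only, not part of this PR:

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

// Hypothetical stand-in for the per-call stats exported from the device
// tensor: {num_req_indices, num_unique_indices, num_unique_misses,
// num_unique_conflict_misses}.
void accumulate_and_maybe_report(
    const std::vector<int32_t>& per_call_stats,
    bool populate) {
  static std::mutex cache_mutex;
  static std::vector<int64_t> counters(4);

  std::lock_guard<std::mutex> guard(cache_mutex);
  // Element-wise add the per-call stats into the cumulated counters,
  // mirroring the std::transform/std::plus pattern in the hunk above.
  std::transform(
      counters.begin(),
      counters.end(),
      per_call_stats.begin(),
      counters.begin(),
      std::plus<int64_t>());

  if (populate && counters[0] > 0) {
    // Per-mille rate relative to the cumulated number of requested indices.
    const double unique_rate =
        static_cast<double>(counters[1]) / counters[0] * 1000;
    std::cout << "unique_rate: " << unique_rate << " per-mille\n";
    // Zero the counters so the next round starts fresh.
    std::fill(counters.begin(), counters.end(), 0);
  }
}

int main() {
  accumulate_and_maybe_report({100, 60, 20, 5}, /*populate=*/false);
  accumulate_and_maybe_report({100, 40, 10, 5}, /*populate=*/true);
  // Prints (60 + 40) / (100 + 100) * 1000 = 500 per-mille.
}
```

Note that accumulating through `std::plus<int64_t>` (rather than `std::plus<int>`) keeps the running sums in 64-bit arithmetic, which matters once the cumulated counts outgrow `int32_t`.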
@@ -358,6 +393,13 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         cache_hash_size_cumsum.value(), indices, offsets);

   bool gather_uvm_stats = false;
+  // populate_uvm_stats indicates whether to calculate cache-related ratios
+  // from the cumulated counters and populate them into the cache stats
+  // pools to get the percentile stats. We want to calculate weighted cache
+  // ratios, taking the number of requested indices of each TBE as the
+  // weight, so we populate stats only when the current lookup appears to be
+  // the last TBE call of the round.
+  bool populate_uvm_stats = true;
   Tensor uvm_cache_stats =
       at::empty({0}, lxu_cache_weights.value().options().dtype(at::kInt));
 #ifdef FBCODE_CAFFE2
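A quick worked example (numbers invented for illustration) of why the PR accumulates raw counts across TBE ops before dividing, instead of averaging per-op rates:

```cpp
#include <iostream>

int main() {
  // Two TBE ops in one round, with made-up stats:
  // TBE A: 1000 requested indices, 100 unique misses -> 100 per-mille.
  // TBE B:  100 requested indices,  50 unique misses -> 500 per-mille.
  const double weighted = (100.0 + 50.0) / (1000.0 + 100.0) * 1000;
  const double naive = (100.0 / 1000.0 + 50.0 / 100.0) / 2 * 1000;
  std::cout << "weighted miss rate: " << weighted << " per-mille\n";  // ~136
  std::cout << "naive average:      " << naive << " per-mille\n";     // 300
}
```

The cumulated-counter approach yields the weighted figure, where each TBE contributes in proportion to its request volume.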
@@ -370,6 +412,17 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
     }
     tbe_call_count[signature]++;
     call_count = tbe_call_count[signature];
+
+    // populate_uvm_stats indicates whether to push the cache-related ratios
+    // calculated from the cumulative counters into the cache stats pools.
+    // We want to wait until all the known TBE ops' data has been included,
+    // so that the ratios are properly weighted.
+    for (const auto& [sig, count] : tbe_call_count) {
+      if (count < call_count) {
+        populate_uvm_stats = false;
+        break;
+      }
+    }
   }

   if (FLAGS_tbe_uvm_cache_stat_report > 0 &&
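The loop above is a heuristic for "last TBE call of the round": after the current signature's count is bumped, the round is considered complete only if no other known signature lags behind. A standalone sketch of that check, with hypothetical names (in the PR, the real map is updated under a lock):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Returns true when every known TBE signature has been called at least as
// many times as the current one, i.e. the current lookup appears to be the
// last call of the round.
bool is_round_complete(
    const std::map<std::string, int64_t>& tbe_call_count,
    int64_t current_count) {
  for (const auto& [sig, count] : tbe_call_count) {
    if (count < current_count) {
      return false;  // Another TBE op has not run yet this round.
    }
  }
  return true;
}

int main() {
  std::map<std::string, int64_t> counts{{"tbe_a", 3}, {"tbe_b", 2}};
  std::cout << is_round_complete(counts, 3) << '\n';  // 0: tbe_b lags behind
  counts["tbe_b"] = 3;
  std::cout << is_round_complete(counts, 3) << '\n';  // 1: round complete
}
```

This assumes all TBE ops in a model are invoked once per round; ops invoked at different frequencies would shift when the stats get populated.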
@@ -413,7 +466,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         total_cache_hash_size.value(),
         call_count,
         gather_uvm_stats,
-        uvm_cache_stats);
+        uvm_cache_stats,
+        populate_uvm_stats);
 #endif
   }