Skip to content

Commit 048376f

Browse files
Kaiwei Tufacebook-github-bot
authored andcommitted
Implement a stat library for fbgemm embedding (#4339)
Summary: Pull Request resolved: #4339 X-link: facebookresearch/FBGEMM#1408 This diff implements a framework to collect statistics within the FBGEMM kernel named `EmbeddingStatsTracker`. We have implemented a class method called `recordPattern` to track embedding access patterns inside `EmbeddingSpMDMAutovec.cc`, which generates a log file that counts the frequency of different access patterns. This feature is controlled by the following environment variables: - `FBGEMM_STATS_ENABLE` (required) enables or disables the feature. Set to 1 to enable, or leave unset to disable. - `FBGEMM_STATS_FREQ` (optional) controls how many samples are collected for each log file update and has a default value of `1000000`. - `FBGEMM_STATS_LOGPATH` (optional) specifies where the log file is written and has a default value of `"/tmp/fbgemm_embedding_stats.txt"`. We made the following changes: - Implemented a statistics framework for FBGEMM embedding in `EmbeddingStatsTracker.h` and `EmbeddingStatsTracker.cc`. - Added the corresponding tracking logic inside `EmbeddingSpMDMAutovec.cc`. - Added environment variable loading logic in `Utils.h` and `Utils.cc`. Reviewed By: excelle08 Differential Revision: D76060846 fbshipit-source-id: ad3445a15b7d0b6f0f8976914ca57def04fc68fd
1 parent 77febf6 commit 048376f

File tree

8 files changed

+510
-1
lines changed

8 files changed

+510
-1
lines changed

defs.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def get_fbgemm_inline_neon_srcs(msvc = False, buck = False):
202202
def get_fbgemm_autovec_srcs():
203203
return [
204204
"src/EmbeddingSpMDMAutovec.cc",
205+
"src/EmbeddingStatsTracker.cc",
205206
]
206207

207208
def get_fbgemm_tests(skip_tests = ["test/FP32Test.cc"]):

fbgemm_gpu/cmake/Fbgemm.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ set(fbgemm_sources_normal
1818
"${FBGEMM}/src/Utils.cc")
1919

2020
if(NOT DISABLE_FBGEMM_AUTOVEC)
21-
list(APPEND fbgemm_sources_normal "${FBGEMM}/src/EmbeddingSpMDMAutovec.cc")
21+
list(APPEND fbgemm_sources_normal "${FBGEMM}/src/EmbeddingSpMDMAutovec.cc" "${FBGEMM}/src/EmbeddingStatsTracker.cc")
2222
endif()
2323

2424
set(fbgemm_sources_avx2

include/fbgemm/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ FBGEMM_API bool is_radix_sort_accelerated_with_openmp();
431431
FBGEMM_API bool is_autovec_disabled();
432432
FBGEMM_API bool is_autovec_forced();
433433
FBGEMM_API bool is_asmjit_disabled();
434+
FBGEMM_API bool is_stats_enabled();
434435

435436
/**
436437
* @brief A function to check if the input parameter in the nbit CPU TBE kernel

src/EmbeddingSpMDMAutovec.cc

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#define FBGEMM_EXPORTS
1212
#include "./EmbeddingSpMDMAutovec.h" // @manual
13+
#include "./EmbeddingStatsTracker.h"
1314
#include "./RefImplementations.h" // @manual
1415
#include "fbgemm/FbgemmBuild.h"
1516
#include "fbgemm/FloatConversion.h"
@@ -69,6 +70,18 @@ static inline void fill_output(
6970
}
7071
}
7172

73+
template <typename OutType>
74+
static inline EmbeddingStatsTracker::DataType get_output_type(
75+
const bool is_bf16_out) {
76+
if (std::is_same<OutType, float>::value) {
77+
return EmbeddingStatsTracker::DataType::FP32;
78+
} else if (std::is_same<OutType, uint16_t>::value && is_bf16_out) {
79+
return EmbeddingStatsTracker::DataType::BF16;
80+
} else {
81+
return EmbeddingStatsTracker::DataType::FP16;
82+
}
83+
}
84+
7285
template <typename IndexType, typename OffsetType, typename OutType>
7386
static bool ALWAYS_INLINE EmbeddingSpMDM8Bit_autovec(
7487
const int64_t block_size,
@@ -171,6 +184,15 @@ static bool ALWAYS_INLINE EmbeddingSpMDM8Bit_autovec(
171184
}
172185
out += output_stride;
173186
} // m
187+
// Track every forward pass in the no_bag case
188+
EmbeddingStatsTracker::getInstance().recordPattern(
189+
data_size,
190+
block_size,
191+
EmbeddingStatsTracker::DataType::INT8,
192+
isOutput8bit ? EmbeddingStatsTracker::DataType::INT8
193+
: get_output_type<OutType>(is_bf16_out),
194+
output_size,
195+
1);
174196
return true;
175197
} // no_bag
176198

@@ -185,6 +207,15 @@ static bool ALWAYS_INLINE EmbeddingSpMDM8Bit_autovec(
185207
return false;
186208
}
187209

210+
// Track every forward inference with the actual bag size (len)
211+
EmbeddingStatsTracker::getInstance().recordPattern(
212+
data_size,
213+
block_size,
214+
EmbeddingStatsTracker::DataType::INT8,
215+
get_output_type<OutType>(is_bf16_out),
216+
output_size,
217+
len);
218+
188219
const float* weights_addr = weights != nullptr
189220
? (is_weight_positional ? weights : weights + current)
190221
: nullptr;
@@ -316,6 +347,7 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
316347
WARN_ONCE("no_bag is only supported for int4 to int4");
317348
return false;
318349
}
350+
319351
for (int64_t i = 0; i < output_size; ++i) {
320352
const auto idx = indices[i];
321353
if (idx < 0 || idx > data_size) {
@@ -325,6 +357,15 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
325357
memcpy(out, input_row, sizeof(uint8_t) * input_stride);
326358
out += input_stride;
327359
}
360+
361+
// Track every forward pass in the no_bag case (bag size is always 1)
362+
EmbeddingStatsTracker::getInstance().recordPattern(
363+
data_size,
364+
block_size,
365+
EmbeddingStatsTracker::DataType::INT4,
366+
EmbeddingStatsTracker::DataType::INT4,
367+
output_size,
368+
1);
328369
return true;
329370
}
330371

@@ -348,6 +389,16 @@ static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
348389
if (end > index_size) {
349390
return false;
350391
}
392+
393+
// Track every forward pass with the actual bag size (len)
394+
EmbeddingStatsTracker::getInstance().recordPattern(
395+
data_size,
396+
block_size,
397+
input_bit_rate == 4 ? EmbeddingStatsTracker::DataType::INT4
398+
: EmbeddingStatsTracker::DataType::INT2,
399+
get_output_type<OutType>(is_bf16_out),
400+
output_size,
401+
len);
351402
memset(buf, 0, sizeof(float) * rounded_block_size);
352403

353404
const float* weights_addr = weights != nullptr
@@ -558,6 +609,16 @@ static bool ALWAYS_INLINE EmbeddingSpMDM_autovec(
558609
fill_output(out, buf, block_size, is_bf16_out);
559610
out += output_stride;
560611
} // m
612+
613+
EmbeddingStatsTracker::getInstance().recordPattern(
614+
data_size,
615+
block_size,
616+
is_bf16_in ? EmbeddingStatsTracker::DataType::BF16
617+
: EmbeddingStatsTracker::DataType::FP32,
618+
get_output_type<OutType>(is_bf16_out),
619+
output_size,
620+
1);
621+
561622
return true;
562623
} // no_bag
563624

@@ -592,6 +653,15 @@ static bool ALWAYS_INLINE EmbeddingSpMDM_autovec(
592653
if (current + len > index_size) {
593654
return false;
594655
}
656+
// Track every inference for actual bag size (len)
657+
EmbeddingStatsTracker::getInstance().recordPattern(
658+
data_size,
659+
block_size,
660+
is_bf16_in ? EmbeddingStatsTracker::DataType::BF16
661+
: EmbeddingStatsTracker::DataType::FP32,
662+
get_output_type<OutType>(is_bf16_out),
663+
output_size,
664+
len);
595665

596666
for (int i = 0; i < len; ++i) {
597667
int64_t idx = indices[current];
@@ -683,6 +753,13 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
683753
if (end > index_size) {
684754
return false;
685755
}
756+
EmbeddingStatsTracker::getInstance().recordPattern(
757+
uncompressed_data_size,
758+
block_size,
759+
EmbeddingStatsTracker::DataType::SPARSE_INT8,
760+
EmbeddingStatsTracker::DataType::FP32,
761+
output_size,
762+
len);
686763
const float* weights_addr = weights != nullptr
687764
? (is_weight_positional ? weights : weights + current)
688765
: nullptr;
@@ -749,6 +826,14 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
749826
return false;
750827
}
751828

829+
EmbeddingStatsTracker::getInstance().recordPattern(
830+
uncompressed_data_size,
831+
block_size,
832+
EmbeddingStatsTracker::DataType::SPARSE_FP32,
833+
EmbeddingStatsTracker::DataType::FP32,
834+
output_size,
835+
len);
836+
752837
const float* weights_addr = weights != nullptr
753838
? (is_weight_positional ? weights : weights + current)
754839
: nullptr;
@@ -926,6 +1011,14 @@ static bool ALWAYS_INLINE EmbeddingSpMDMFP8_autovec(
9261011
return false;
9271012
}
9281013

1014+
EmbeddingStatsTracker::getInstance().recordPattern(
1015+
data_size,
1016+
block_size,
1017+
EmbeddingStatsTracker::DataType::FP8,
1018+
get_output_type<OutType>(is_bf16_out),
1019+
output_size,
1020+
len);
1021+
9291022
// Adjust these as necessary to reflect actual batch size
9301023
const int batch_size = block_size; // Assuming the entire block is
9311024
// processed at once; adjust if needed

src/EmbeddingStatsTracker.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#define FBGEMM_EXPORTS
10+
#include "./EmbeddingStatsTracker.h"
11+
#include <iostream>
12+
#include "fbgemm/Utils.h"
13+
14+
namespace fbgemm {
15+
16+
/// Returns the shared EmbeddingStatsTracker.
///
/// The tracker is a function-local static, so it is constructed the first
/// time this accessor is called and the same object is returned afterwards.
EmbeddingStatsTracker& EmbeddingStatsTracker::getInstance() {
  static EmbeddingStatsTracker tracker;
  return tracker;
}
20+
21+
void EmbeddingStatsTracker::recordPattern(
22+
int64_t rows,
23+
int64_t dims,
24+
DataType input_type,
25+
DataType output_type,
26+
int64_t batch_size,
27+
int64_t bag_size) {
28+
if (!is_stats_enabled() || bag_size == 0) {
29+
return;
30+
}
31+
std::lock_guard<std::mutex> lock(mutex_);
32+
33+
// Create the entry and ensure the pattern exists
34+
AccessPatternEntry key(
35+
rows, dims, batch_size, bag_size, input_type, output_type);
36+
auto result = tables_.find(key);
37+
if (result == tables_.end()) {
38+
tables_[key] = 1;
39+
} else {
40+
result->second += 1;
41+
}
42+
43+
sampleCount_ += 1;
44+
45+
if (sampleCount_ % config_.getLogFreq() == 0) {
46+
// Log the table statistics - only try to open the file if it's not
47+
logFile_.open(config_.getLogFilePath(), std::ios::out | std::ios::trunc);
48+
49+
if (!logFile_) {
50+
std::cerr << "Failed to open log file: " << config_.getLogFilePath()
51+
<< '\n';
52+
return;
53+
}
54+
for (const auto& pair : tables_) {
55+
const auto& pattern = pair.first;
56+
logFile_ << pattern.toString() << "freq=" << pair.second << ";"
57+
<< std::endl;
58+
}
59+
logFile_.flush();
60+
logFile_.close();
61+
}
62+
}
63+
64+
} // namespace fbgemm

0 commit comments

Comments
 (0)