
Commit 1d7d50e

cyyever authored and facebook-github-bot committed
Fix performance issues identified by clang-tidy (pytorch#4444)
Summary:
Pull Request resolved: pytorch#4444

X-link: facebookresearch/FBGEMM#1506

Fix performance issues identified by clang-tidy. The most notable change is adding references to asmjit classes.

Pull Request resolved: pytorch#4442

Reviewed By: gchalump

Differential Revision: D77746424

Pulled By: q10

fbshipit-source-id: a452c82f5fe1f57d7a271b495973f97b8bb82ac9
1 parent c409ff6 commit 1d7d50e

27 files changed (+210, -213 lines)
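The bulk of the diff applies two clang-tidy performance fixes: passing read-only, copy-expensive parameters by const reference (the performance-unnecessary-value-param check) and replacing runtime type tests with compile-time ones. A minimal sketch of the const-reference pattern, for context only (not FBGEMM code):

#include <iostream>
#include <string>

// Pass by value: the string is copied on every call.
static void printByValue(std::string name) {
  std::cout << name << '\n';
}

// Pass by const reference: the callee reads the caller's object directly,
// which is what clang-tidy suggests for read-only non-trivial parameters.
static void printByRef(const std::string& name) {
  std::cout << name << '\n';
}

int main() {
  const std::string who = "fbgemm";
  printByValue(who); // makes a copy
  printByRef(who);   // no copy
  return 0;
}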

.clang-tidy

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ modernize*,
 -modernize-use-ranges,
 -modernize-use-integer-sign-comparison
 -modernize-use-nodiscard,
+performance*,
+-performance-avoid-endl
 '
 CheckOptions:
   - key: facebook-cuda-safe-api-call-check.HandlerName
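A note on the one check excluded above: performance-avoid-endl warns that std::endl writes '\n' and also flushes the stream. These benchmarks keep their endl prints deliberately, so the check is disabled rather than rewriting every print statement. A small illustration of the difference:

#include <iostream>

int main() {
  std::cout << "flushed immediately" << std::endl; // '\n' plus an explicit flush
  std::cout << "possibly buffered\n";              // '\n' only; flushed later
  return 0;
}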

bench/EmbeddingQuantizeBenchmark.cc

Lines changed: 4 additions & 4 deletions
@@ -29,7 +29,7 @@ static void performance_test() {
   constexpr int NWARMUP = 4;
   constexpr int NITER = 256;
 
-  if (is_same_v<T, float16>) {
+  if constexpr (is_same_v<T, float16>) {
     cout << "With scale and bias as float16" << endl;
   } else {
     cout << "With scale and bias as float" << endl;
@@ -38,7 +38,7 @@ static void performance_test() {
        << "cols" << "," << setw(16) << "elems_per_usec" << "," << setw(10)
        << "GB/Sec" << endl;
   std::vector<int> bit_rates;
-  if (is_same_v<T, float16>) {
+  if constexpr (is_same_v<T, float16>) {
     bit_rates = {2, 4, 8};
   } else {
     // float
@@ -52,7 +52,7 @@ static void performance_test() {
 
   int out_emb_cols = colSize;
 
-  if (is_same<T, float16>::value) {
+  if constexpr (is_same_v<T, float16>) {
     int elements_per_byte = 8 / bit_rate;
     out_emb_cols = (colSize + elements_per_byte - 1) / elements_per_byte;
   }
@@ -63,7 +63,7 @@ static void performance_test() {
 
   duration = measureWithWarmup(
       [&]() {
-        is_same<T, float16>::value
+        is_same_v<T, float16>
             ? FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
                   bit_rate,
                   inpVec.data(),
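The hunks above swap a runtime if for if constexpr on conditions known at compile time: the untaken branch is discarded during template instantiation, so no dead branch is compiled or executed. A self-contained sketch of the pattern, using stand-in types rather than FBGEMM's float16:

#include <iostream>
#include <type_traits>

template <typename T>
void report() {
  // The condition depends only on T, so it is evaluated at compile time;
  // only the matching branch is instantiated for each T.
  if constexpr (std::is_same_v<T, float>) {
    std::cout << "float path\n";
  } else {
    std::cout << "other path\n";
  }
}

int main() {
  report<float>();  // prints "float path"
  report<double>(); // prints "other path"
  return 0;
}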

bench/EmbeddingQuantizeFloatToFloatOrHalfBenchmark.cc

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@ static void performance_test() {
   constexpr int NWARMUP = 4;
   constexpr int NITER = 256;
 
-  if (is_same_v<T, float16>) {
+  if constexpr (is_same_v<T, float16>) {
     cout << "With result as float16" << endl;
   } else {
     cout << "With result as float" << endl;
@@ -44,15 +44,15 @@ static void performance_test() {
 
   int out_emb_cols = colSize;
 
-  if (is_same<T, float16>::value) {
+  if constexpr (is_same_v<T, float16>) {
     out_emb_cols /= 2;
   }
   int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(T));
   aligned_vector<T> outVec(outVecSize);
 
   double duration = 0.0f;
 
-  int constexpr kNumRepeats = is_same<T, float16>::value ? 16 : 32;
+  int constexpr kNumRepeats = is_same_v<T, float16> ? 16 : 32;
 
   duration = measureWithWarmup(
       [&]() {

bench/EmbeddingSpMDM8BitBenchmark.cc

Lines changed: 8 additions & 14 deletions
@@ -12,7 +12,6 @@
 #endif
 #include <algorithm>
 #include <cassert>
-#include <chrono>
 #include <cmath>
 #include <cstdint>
 #include <iomanip>
@@ -262,10 +261,10 @@ static int run_benchmark(
   for (size_t i = 0; i < output.size(); ++i) {
     float tmp1 = 0;
     float tmp2 = 0;
-    if constexpr (std::is_same<OutType, float>::value) {
+    if constexpr (std::is_same_v<OutType, float>) {
       tmp1 = output[i];
       tmp2 = output_ref[i];
-    } else if constexpr (std::is_same<OutType, uint16_t>::value) {
+    } else if constexpr (std::is_same_v<OutType, uint16_t>) {
       if (is_bf16_out) {
         tmp1 = cpu_bf162float(output[i]);
         tmp2 = cpu_bf162float(output_ref[i]);
@@ -289,9 +288,9 @@ static int run_benchmark(
 #pragma omp barrier
 #endif
   if (fbgemm_get_thread_num() == 0) {
-    if constexpr (std::is_same<OutType, float>::value) {
+    if constexpr (std::is_same_v<OutType, float>) {
       cout << "out type fp32";
-    } else if constexpr (std::is_same<OutType, uint16_t>::value) {
+    } else if constexpr (std::is_same_v<OutType, uint16_t>) {
       if (is_bf16_out) {
         cout << "out type bf16";
       } else {
@@ -340,22 +339,17 @@ static int run_benchmark(
 }
 
 int main() {
-  int batch_size;
-  int num_rows;
-  int embedding_dim;
-  int average_len;
-
   bool stress_multi_threading = false;
 
   vector<vector<int>> inputs(GetInputs_());
   benchmarkTimes.resize(fbgemm_get_max_threads());
 
   for (auto& input : inputs) {
     assert(input.size() > 3);
-    batch_size = input[0];
-    num_rows = input[1];
-    embedding_dim = input[2];
-    average_len = input[3];
+    int batch_size = input[0];
+    int num_rows = input[1];
+    int embedding_dim = input[2];
+    int average_len = input[3];
 
     cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows"
          << setw(16) << num_rows << setw(10) << "emb dim" << setw(6)
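The final hunk above narrows variable scope: the four shape variables move from function scope into the loop body, so each iteration gets fresh, locally scoped values. A reduced sketch of the same refactor (the shape numbers here are made up):

#include <cassert>
#include <vector>

int main() {
  // Hypothetical shapes: {batch_size, num_rows, embedding_dim, average_len}.
  std::vector<std::vector<int>> inputs = {{10, 4000, 32, 100}, {10, 4000, 64, 100}};
  for (auto& input : inputs) {
    assert(input.size() > 3);
    // Declared at first use: no stale values can leak across iterations.
    int batch_size = input[0];
    int num_rows = input[1];
    int embedding_dim = input[2];
    int average_len = input[3];
    (void)batch_size; (void)num_rows; (void)embedding_dim; (void)average_len;
  }
  return 0;
}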

bench/EmbeddingSpMDMBenchmark.cc

Lines changed: 5 additions & 5 deletions
@@ -64,8 +64,8 @@ static void run_benchmark(
 
   vector<float> embedding_table(num_rows * embedding_dim);
   normal_distribution<float> embedding_distribution;
-  for (size_t i = 0; i < embedding_table.size(); ++i) {
-    embedding_table[i] = embedding_distribution(generator);
+  for (float& i : embedding_table) {
+    i = embedding_distribution(generator);
   }
   vector<float16> embedding_table_fp16;
   vector<bfloat16> embedding_table_bf16;
@@ -235,15 +235,15 @@ static void run_benchmark(
       prefetch ? 16 : 0,
       /*is_weight_positional=*/false,
       /*use_offsets=*/true,
-      /*isbf16=*/true);
+      /*is_bf16_out=*/true);
   auto kernel_bf16_i64 = GenerateEmbeddingSpMDM<bfloat16, int64_t>(
       embedding_dim,
       has_weight,
       normalize_by_lengths,
       prefetch ? 16 : 0,
       /*is_weight_positional=*/false,
-      /*is_weight_positional=*/true,
-      /*isbf16=*/true);
+      /*use_offsets=*/true,
+      /*is_bf16_out=*/true);
 
   vector<float>& output = has_weight ? output_slws : output_sls;
   for (bool flush_cache : {false, true}) {
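The first hunk above converts an index loop into a range-based for, which drops the index bookkeeping and repeated operator[] calls when the index is only used for element access. A standalone version of the pattern:

#include <random>
#include <vector>

int main() {
  std::vector<float> embedding_table(1024);
  std::default_random_engine generator;
  std::normal_distribution<float> embedding_distribution;
  // Range-based for: each element is bound by reference, no index needed.
  for (float& v : embedding_table) {
    v = embedding_distribution(generator);
  }
  return 0;
}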

bench/EmbeddingSpMDMNBit2Benchmark.cc

Lines changed: 10 additions & 15 deletions
@@ -143,10 +143,10 @@ static void print_benchmark_results() {
       << "asmjit b/w (GB/s), asmjit effective b/w (GB/s), asmjit time, "
       << "autovec b/w (GB/s), autovec effective b/w (GB/s), autovec time, "
       << "ref b/w (GB/s), ref effective b/w (GB/s), ref time, "
-      << "asmjit speedup ratio, autovec speedup ratio" << std::endl;
-  for (size_t i = 0; i < benchmarks.size(); ++i) {
-    BenchmarkSpec& spec = benchmarks[i].first;
-    BenchmarkResult& res = benchmarks[i].second;
+      << "asmjit speedup ratio, autovec speedup ratio" << endl;
+  for (auto& benchmark : benchmarks) {
+    BenchmarkSpec& spec = benchmark.first;
+    BenchmarkResult& res = benchmark.second;
     float asmjit_speedup = res.ref_bw > 0.0 ? res.asmjit_bw / res.ref_bw : 0;
     float autovec_speedup = res.ref_bw > 0.0 ? res.autovec_bw / res.ref_bw : 0;
     std::cout << spec.bit_rate << ", " << spec.batch_size << ", "
@@ -158,7 +158,7 @@ static void print_benchmark_results() {
               << res.asmjit_time << ", " << res.autovec_bw << ", "
               << res.autovec_eff_bw << ", " << res.autovec_time << ", "
               << res.ref_bw << ", " << res.ref_eff_bw << ", " << res.ref_time
-              << ", " << asmjit_speedup << ", " << autovec_speedup << std::endl;
+              << ", " << asmjit_speedup << ", " << autovec_speedup << endl;
   }
 }
 
@@ -457,7 +457,7 @@ static int run_benchmark(
     find_benchmark_record(spec).set_asmjit_result(
         bytes / 1e9 / t, bytes_padded / 1e9 / t, t);
   } else {
-    std::cerr << "Bad kern_type parameter: " << kern_type << std::endl;
+    std::cerr << "Bad kern_type parameter: " << kern_type << endl;
     assert(false);
   }
   if (!success) {
@@ -469,20 +469,15 @@
 }
 
 static void sweep_benchmark(KernelType kern_type) {
-  int batch_size;
-  int num_rows;
-  int embedding_dim;
-  int average_len;
-
   vector<vector<int>> inputs(GetInputs_());
 
   for (int bit_rate : {4, 2}) {
     for (auto& input : inputs) {
       assert(input.size() > 3);
-      batch_size = input[0];
-      num_rows = input[1];
-      embedding_dim = input[2];
-      average_len = input[3];
+      int batch_size = input[0];
+      int num_rows = input[1];
+      int embedding_dim = input[2];
+      int average_len = input[3];
 
       auto run_benchmark_with_above_shape = [&](bool use_32_bit_indices,
                                                 bool prefetch) {

bench/EmbeddingSpMDMNBitBenchmark.cc

Lines changed: 1 addition & 1 deletion
@@ -485,7 +485,7 @@ static int run_benchmark(
 #ifndef OUT_TYPE_FLOAT16
       cout << ", asmjit speedup, " << t_ref / t;
 #endif
-      cout << std::endl;
+      cout << endl;
     } // flush_cache
   } // has_weight
   return 0;

include/fbgemm/Fbgemm.h

Lines changed: 5 additions & 5 deletions
@@ -379,7 +379,7 @@ class FBGEMM_API PackAMatrix final
   /**
    * @brief Print the packed block.
    */
-  void printPackedMatrix(std::string name);
+  void printPackedMatrix(const std::string& name);
 
  private:
   matrix_op_t trans_;
@@ -464,7 +464,7 @@ class FBGEMM_API PackBMatrix final
    * @brief Print the packed block.
    */
   void printPackedMatrix(
-      std::string name,
+      const std::string& name,
       const BlockingFactors* params = nullptr);
 
   /**
@@ -745,7 +745,7 @@ class FBGEMM_API PackAWithIm2Col
   /**
    * @brief Print the packed block.
    */
-  void printPackedMatrix(std::string name);
+  void printPackedMatrix(const std::string& name);
 
   /**
    * @return Size of row offset buffer in number of elements
@@ -835,7 +835,7 @@ class FBGEMM_API PackAWithRowOffset final
   /**
    * @brief Print the packed block.
    */
-  void printPackedMatrix(std::string name);
+  void printPackedMatrix(const std::string& name);
 
   /**
    * @return size of row offset buffer in number of elements
@@ -927,7 +927,7 @@ class FBGEMM_API PackAWithQuantRowOffset final
   /**
    * @brief Print the packed block.
    */
-  void printPackedMatrix(std::string name);
+  void printPackedMatrix(const std::string& name);
 
   /**
    * @return Size of row offset buffer in number of elements

src/DirectConv.h

Lines changed: 15 additions & 15 deletions
@@ -80,8 +80,8 @@ class DirectConvCodeGenBase {
       x86::Emitter* a,
       int rowRegs,
       int colRegs,
-      x86::Gp C_Offset,
-      x86::Gp ldcReg,
+      const x86::Gp& C_Offset,
+      const x86::Gp& ldcReg,
       bool accum);
 
   /**
@@ -93,9 +93,9 @@ class DirectConvCodeGenBase {
       x86::Emitter* a,
       int rowRegs,
       int colRegs,
-      x86::Gp C_offset,
-      x86::Gp o1XocReg,
-      x86::Gp ldcReg,
+      const x86::Gp& C_offset,
+      const x86::Gp& o1XocReg,
+      const x86::Gp& ldcReg,
       bool accum);
 
   /**
@@ -167,9 +167,9 @@ class DirectConvCodeGenBase {
   template <inst_set_t instSet>
   void genComputeBlock(
       x86::Emitter* a,
-      x86::Gp buffer_A,
-      x86::Gp buffer_B,
-      x86::Gp B_pf,
+      const x86::Gp& buffer_A,
+      const x86::Gp& buffer_B,
+      const x86::Gp& B_pf,
       int rowRegs,
       int colRegs,
       int lda);
@@ -179,9 +179,9 @@ class DirectConvCodeGenBase {
   template <inst_set_t instSet>
   void genComputeBlockDirectConv(
       x86::Emitter* a,
-      x86::Gp buffer_A,
-      x86::Gp buffer_B,
-      x86::Gp B_pf,
+      const x86::Gp& buffer_A,
+      const x86::Gp& buffer_B,
+      const x86::Gp& B_pf,
       int rowRegs,
       int colRegs,
       int strideXich);
@@ -192,10 +192,10 @@ class DirectConvCodeGenBase {
   template <inst_set_t instSet>
   void genComputeBlockDirectConvTrans(
       x86::Emitter* a,
-      x86::Gp buffer_A,
-      x86::Gp buffer_B,
-      x86::Gp icReg,
-      x86::Gp C_offset,
+      const x86::Gp& buffer_A,
+      const x86::Gp& buffer_B,
+      const x86::Gp& icReg,
+      const x86::Gp& C_offset,
       int rowRegs,
       int colRegs);
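These are the "references to asmjit classes" the summary mentions: x86::Gp is asmjit's general-purpose-register operand, and these helpers only read it, so taking it by const reference saves a copy per call. A reduced sketch with a stand-in struct (not the real asmjit type or FBGEMM signatures):

// RegOperand stands in for a small-but-nontrivial operand class
// such as asmjit::x86::Gp.
struct RegOperand {
  int id = 0;
};

// Before: the operand is copied into the parameter on every call.
static int useByValue(RegOperand reg) { return reg.id; }

// After: a const reference reads the caller's object directly.
static int useByRef(const RegOperand& reg) { return reg.id; }

int main() {
  RegOperand r{7};
  return useByValue(r) - useByRef(r); // 0
}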

src/EmbeddingSpMDM.cc

Lines changed: 8 additions & 8 deletions
@@ -235,15 +235,15 @@ GenEmbeddingSpMDMLookup<
     offsetType,
     outType,
     ROWWISE_SPARSE>::jit_embedding_kernel {
-  bool is_8bit_in = std::is_same_v<inType, uint8_t>;
-  bool is_16bit_in = std::is_same_v<inType, uint16_t>;
-  bool is_16bit_out = std::is_same_v<outType, uint16_t>;
+  constexpr bool is_8bit_in = std::is_same_v<inType, uint8_t>;
+  constexpr bool is_16bit_in = std::is_same_v<inType, uint16_t>;
+  constexpr bool is_16bit_out = std::is_same_v<outType, uint16_t>;
   bool is_fp16_in = is_16bit_in && !is_bf16_in;
   bool is_fp16_out = is_16bit_out && !is_bf16_out;
 
   // TODO: Make this tunable
   int pref_dist = prefetch;
-  bool areIndices64b = std::is_same_v<indxType, int64_t>;
+  constexpr bool areIndices64b = std::is_same_v<indxType, int64_t>;
 
   asmjit::CodeHolder code;
   code.init(runtime().environment());
@@ -576,15 +576,15 @@ GenEmbeddingSpMDMLookup<
   a->jl(LoopDataIndexEnd);
 
   // Array out of bound check
-  if (areIndices64b) {
+  if constexpr (areIndices64b) {
     a->mov(scratchReg1_, x86::qword_ptr(indices));
   } else {
     a->mov(scratchReg1_.r32(), x86::dword_ptr(indices));
   }
   if (!scale_bias_last) {
     // When scale_bias_last == false, assume this is for table batched
     // embedding (TBE) that can get -1 for pruned rows.
-    if (areIndices64b) {
+    if constexpr (areIndices64b) {
       a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1));
     } else {
       a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
@@ -623,7 +623,7 @@ GenEmbeddingSpMDMLookup<
   a->cmp(scratchReg2_, index_size);
   a->jge(pref_dist_reset_start);
 
-  if (areIndices64b) {
+  if constexpr (areIndices64b) {
     a->mov(
         scratchReg2_,
         x86::qword_ptr(indices, pref_dist * sizeof(indxType)));
@@ -638,7 +638,7 @@ GenEmbeddingSpMDMLookup<
   a->bind(pref_dist_reset_start);
   // things are not okay just get the current row
   // this can be improved to getting the max dist row.
-  if (areIndices64b) {
+  if constexpr (areIndices64b) {
     a->mov(scratchReg2_, x86::qword_ptr(indices));
   } else {
     a->mov(scratchReg2_.r32(), x86::dword_ptr(indices));
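Marking these booleans constexpr is what permits the later if constexpr branches: the index width is fixed by the indxType template parameter, so each instantiation of the kernel generator compiles only its own 64-bit or 32-bit index-loading path. A simplified sketch of that structure:

#include <cstdint>
#include <iostream>
#include <type_traits>

template <typename indxType>
void emitIndexLoad() {
  // Known at compile time for each instantiation.
  constexpr bool areIndices64b = std::is_same_v<indxType, int64_t>;
  if constexpr (areIndices64b) {
    std::cout << "emit 64-bit index load\n";
  } else {
    std::cout << "emit 32-bit index load\n";
  }
}

int main() {
  emitIndexLoad<int64_t>();
  emitIndexLoad<int32_t>();
  return 0;
}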
