
Commit f6100fc

cyyever authored and facebook-github-bot committed
Use if constexpr (#4426)
Summary:
Pull Request resolved: #4426
X-link: facebookresearch/FBGEMM#1493

Add `if constexpr` to all possible if statements.

Pull Request resolved: #4422

Reviewed By: gchalump

Differential Revision: D77571436

Pulled By: q10

fbshipit-source-id: 056aee5283dfb6b9f2c39ba987383fa6ce394a6b
1 parent 80ed942 · commit f6100fc

26 files changed: +87 −83 lines changed
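
The pattern applied across all 26 files is the same: a runtime `if (std::is_same<...>::value)` on a template parameter becomes `if constexpr`, so the branch is selected during template instantiation and the not-taken arm is discarded instead of surviving as dead code for the optimizer to remove. A minimal standalone sketch of the before/after (illustrative only, not FBGEMM code; `out_type_name` is a hypothetical function):

#include <cstdint>
#include <iostream>
#include <type_traits>

// Illustrative sketch of the transformation in this commit: with
// `if constexpr` the condition is evaluated at compile time, and the
// discarded branches generate no code in the instantiation.
template <typename OutType>
const char* out_type_name() {
  if constexpr (std::is_same_v<OutType, float>) { // was: if (std::is_same<OutType, float>::value)
    return "fp32";
  } else if constexpr (std::is_same_v<OutType, std::uint16_t>) {
    return "fp16/bf16";
  } else {
    return "unknown";
  }
}

int main() {
  std::cout << out_type_name<float>() << '\n';         // prints: fp32
  std::cout << out_type_name<std::uint16_t>() << '\n'; // prints: fp16/bf16
}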

bench/EmbeddingSpMDM8BitBenchmark.cc

Lines changed: 4 additions & 4 deletions
@@ -261,10 +261,10 @@ int run_benchmark(
   for (size_t i = 0; i < output.size(); ++i) {
     float tmp1 = 0;
     float tmp2 = 0;
-    if (std::is_same<OutType, float>::value) {
+    if constexpr (std::is_same<OutType, float>::value) {
       tmp1 = output[i];
       tmp2 = output_ref[i];
-    } else if (std::is_same<OutType, uint16_t>::value) {
+    } else if constexpr (std::is_same<OutType, uint16_t>::value) {
       if (is_bf16_out) {
         tmp1 = cpu_bf162float(output[i]);
         tmp2 = cpu_bf162float(output_ref[i]);
@@ -288,9 +288,9 @@ int run_benchmark(
 #pragma omp barrier
 #endif
   if (fbgemm_get_thread_num() == 0) {
-    if (std::is_same<OutType, float>::value) {
+    if constexpr (std::is_same<OutType, float>::value) {
       cout << "out type fp32";
-    } else if (std::is_same<OutType, uint16_t>::value) {
+    } else if constexpr (std::is_same<OutType, uint16_t>::value) {
       if (is_bf16_out) {
         cout << "out type bf16";
       } else {

bench/EmbeddingSpMDMNBitBenchmark.cc

Lines changed: 6 additions & 6 deletions
@@ -375,10 +375,10 @@ int run_benchmark(
   for (size_t i = 0; i < output.size(); ++i) {
     float tmp1 = 0;
     float tmp2 = 0;
-    if (std::is_same<OutType, float>::value) {
+    if constexpr (std::is_same<OutType, float>::value) {
       tmp1 = output[i];
       tmp2 = output_ref[i];
-    } else if (std::is_same<OutType, uint16_t>::value) {
+    } else if constexpr (std::is_same<OutType, uint16_t>::value) {
       if (is_bf16_out) {
         tmp1 = cpu_bf162float(output[i]);
         tmp2 = cpu_bf162float(output_ref[i]);
@@ -411,10 +411,10 @@ int run_benchmark(
   for (size_t i = 0; i < output_autovec.size(); ++i) {
     float tmp1 = 0;
     float tmp2 = 0;
-    if (std::is_same<OutType, float>::value) {
+    if constexpr (std::is_same<OutType, float>::value) {
       tmp1 = output_autovec[i];
       tmp2 = output_ref[i];
-    } else if (std::is_same<OutType, uint16_t>::value) {
+    } else if constexpr (std::is_same<OutType, uint16_t>::value) {
       if (is_bf16_out) {
         tmp1 = cpu_bf162float(output_autovec[i]);
         tmp2 = cpu_bf162float(output_ref[i]);
@@ -437,9 +437,9 @@ int run_benchmark(
 #endif
   }
 
-  if (std::is_same<OutType, float>::value) {
+  if constexpr (std::is_same<OutType, float>::value) {
     cout << "out type fp32, ";
-  } else if (std::is_same<OutType, uint16_t>::value) {
+  } else if constexpr (std::is_same<OutType, uint16_t>::value) {
     if (is_bf16_out) {
       cout << "out type bf16, ";
     } else {

fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp

Lines changed: 6 additions & 6 deletions
@@ -130,7 +130,7 @@ Tensor _fusednbitrowwise_to_float_cpu(
       (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
 
   Tensor output;
-  if (std::is_same<output_t, float>::value) {
+  if constexpr (std::is_same<output_t, float>::value) {
     output = at::empty(
         {nrows, output_columns}, // 4 = sizeof(float)
         input.options().dtype(at::kFloat));
@@ -167,15 +167,15 @@ Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
       (ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;
 
   Tensor output;
-  if (std::is_same<output_t, float>::value) {
+  if constexpr (std::is_same<output_t, float>::value) {
     output = at::empty(
         {nrows, output_columns}, // 4 = sizeof(float)
         input.options().dtype(at::kFloat));
-  } else if (std::is_same<output_t, at::Half>::value) {
+  } else if constexpr (std::is_same<output_t, at::Half>::value) {
     output = at::empty(
         {nrows, output_columns}, // 2 = sizeof(half)
         input.options().dtype(at::kHalf));
-  } else if (std::is_same<output_t, at::BFloat16>::value) {
+  } else if constexpr (std::is_same<output_t, at::BFloat16>::value) {
     output = at::empty(
         {nrows, output_columns}, // 2 = sizeof(half)
         input.options().dtype(at::kBFloat16));
@@ -258,7 +258,7 @@ Tensor float_or_half_to_fused8bitrowwise_cpu(const Tensor& input) {
       input.options().dtype(at::kByte)); // at::kBytes for uint8_t
   FBGEMM_DISPATCH_FLOAT_AND_HALF(
       input.scalar_type(), "float_or_half_to_fused8bitrowwise_cpu", [&] {
-        if (std::is_same<scalar_t, float>::value) {
+        if constexpr (std::is_same<scalar_t, float>::value) {
          _float_to_fused8bitrowwise_cpu_out(output, input);
        } else { // scalar_t = at::Half
          _half_to_fused8bitrowwise_cpu_out(output, input);
@@ -419,7 +419,7 @@ Tensor float_or_half_to_fusednbitrowwise_cpu(
   Tensor output;
   FBGEMM_DISPATCH_FLOAT_AND_HALF(
       input.scalar_type(), "float_or_half_to_fusednbitrowwise_cpu", [&] {
-        if (std::is_same<scalar_t, float>::value) {
+        if constexpr (std::is_same<scalar_t, float>::value) {
          output = _float_to_fusednbitrowwise_cpu<float>(input, bit_rate);
        } else { // scalar_t = at::Half
          output =

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ Tensor masked_index_impl(
   const auto func_name = is_index_put ? "masked_index_put_kernel"
                                       : "masked_index_select_kernel";
 #endif
-  if (std::is_same_v<value_t, uint8_t>) {
+  if constexpr (std::is_same_v<value_t, uint8_t>) {
     TORCH_CHECK(D % 16 == 0, "D needs to be padded to be multiple of 16");
   }
   FBGEMM_DISPATCH_INTEGRAL_TYPES(

include/fbgemm/FbgemmPackMatrixB.h

Lines changed: 3 additions & 3 deletions
@@ -65,7 +65,7 @@ class PackedGemmMatrixB {
       const int brow = 512)
       : nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) {
 #ifdef FBGEMM_ENABLE_KLEIDIAI
-    if (std::is_same<T, float16>::value) {
+    if constexpr (std::is_same<T, float16>::value) {
       kernel_ncol_blocks_ = 1;
     }
 #endif
@@ -94,7 +94,7 @@ class PackedGemmMatrixB {
         size_(size),
         kernel_ncol_blocks_(2) {
 #ifdef FBGEMM_ENABLE_KLEIDIAI
-    if (std::is_same<T, float16>::value) {
+    if constexpr (std::is_same<T, float16>::value) {
       kernel_ncol_blocks_ = 1;
     }
 #endif
@@ -122,7 +122,7 @@ class PackedGemmMatrixB {
         size_(size),
         kernel_ncol_blocks_(kernel_ncol_blocks) {
 #ifdef FBGEMM_ENABLE_KLEIDIAI
-    if (std::is_same<T, float16>::value) {
+    if constexpr (std::is_same<T, float16>::value) {
       kernel_ncol_blocks_ = 1;
     }
 #endif

include/fbgemm/OutputProcessing-inl.h

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
       }
       float raw_f;
       if (bias_) {
-        if (std::is_same<BIAS_TYPE, float>::value) {
+        if constexpr (std::is_same<BIAS_TYPE, float>::value) {
           raw_f = raw;
           raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx];
         } else {

include/fbgemm/Utils.h

Lines changed: 1 addition & 1 deletion
@@ -447,7 +447,7 @@ void nbit_embedding_sanity_check(
   assert(
       (input_bit_rate == 2 || input_bit_rate == 4) &&
       "input_bit_rate must be 2 or 4");
-  if (std::is_same<OutType, uint8_t>::value) {
+  if constexpr (std::is_same<OutType, uint8_t>::value) {
     assert(
         (no_bag && input_bit_rate == 4 && output_bit_rate == 4) &&
         "we currently only support int4 to int4 for sequential TBE");

src/DirectConv.h

Lines changed: 2 additions & 2 deletions
@@ -113,9 +113,9 @@ class DirectConvCodeGenBase {
       int NR) {
     std::ostringstream oss;
     oss << "directconv_";
-    if (std::is_same<accT, std::int16_t>::value) {
+    if constexpr (std::is_same<accT, std::int16_t>::value) {
       oss << "acc16_";
-    } else if (std::is_same<accT, std::int32_t>::value) {
+    } else if constexpr (std::is_same<accT, std::int32_t>::value) {
       oss << "acc32_";
     } else {
       oss << "unknown_";

src/EmbeddingSpMDM.cc

Lines changed: 3 additions & 3 deletions
@@ -862,7 +862,7 @@ GenEmbeddingSpMDMLookup<
           a->vmulps(out_vreg, out_vreg, vlen_inv_vreg);
         }
 
-        if (std::is_same_v<outType, float>) {
+        if constexpr (std::is_same_v<outType, float>) {
           if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
             if (instSet == inst_set_t::avx2) {
               a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm());
@@ -1042,7 +1042,7 @@ typename EmbeddingSpMDMKernelSignature<inType, indxType, offsetType, outType>::
     output_stride = block_size;
   }
   if (input_stride == -1) {
-    if (std::is_same_v<inType, uint8_t>) {
+    if constexpr (std::is_same_v<inType, uint8_t>) {
       const auto scale_bias_offset =
           2 * (scale_bias_last ? sizeof(float) : sizeof(uint16_t));
       input_stride = block_size + scale_bias_offset;
@@ -1351,7 +1351,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
     bool use_offsets) {
 #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
   int64_t input_stride = block_size;
-  if (std::is_same_v<inType, uint8_t>) {
+  if constexpr (std::is_same_v<inType, uint8_t>) {
     const auto scale_bias_offset = 2 * sizeof(float);
     input_stride = block_size + scale_bias_offset;
   }

src/EmbeddingSpMDMAutovec.cc

Lines changed: 4 additions & 4 deletions
@@ -54,7 +54,7 @@ static inline void fill_output(
     const float* src,
     const int64_t block_size,
     const bool is_bf16_out) {
-  if (std::is_same_v<OutType, float>) {
+  if constexpr (std::is_same_v<OutType, float>) {
     for (int j = 0; j < block_size; ++j) {
       out[j] = src[j];
     }
@@ -72,7 +72,7 @@ static inline void fill_output(
 template <typename OutType>
 static inline EmbeddingStatsTracker::DataType get_output_type(
     const bool is_bf16_out) {
-  if (std::is_same_v<OutType, float>) {
+  if constexpr (std::is_same_v<OutType, float>) {
     return EmbeddingStatsTracker::DataType::FP32;
   } else if (std::is_same_v<OutType, uint16_t> && is_bf16_out) {
     return EmbeddingStatsTracker::DataType::BF16;
@@ -1139,7 +1139,7 @@ template <typename InType>
 static int64_t stride_SpMDMWithStrides(
     int64_t block_size,
     bool scale_bias_last) {
-  if (std::is_same_v<InType, uint8_t>) {
+  if constexpr (std::is_same_v<InType, uint8_t>) {
     const size_t scale_bias_offset =
         2 * (scale_bias_last ? sizeof(float) : sizeof(uint16_t));
     return block_size + scale_bias_offset;
@@ -1215,7 +1215,7 @@ typename EmbeddingSpMDMKernelSignature<InType, IndexType, OffsetType, OutType>::
   } else { \
     weights = nullptr; \
   } \
-  if (std::is_same<InType, uint8_t>::value) { \
+  if constexpr (std::is_same<InType, uint8_t>::value) { \
    assert(!specialize(IS_BF16_IN, is_bf16_in)); \
    return EmbeddingSpMDM8Bit_autovec( \
        specialize(BLOCK_SIZE, block_size), \
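
Worth noting: in the hunks above, both arms of each rewritten `if` already compiled for every instantiation (the conversions involved are all well-formed), so the change is chiefly about pruning dead branches at compile time. `if constexpr` additionally allows branches that are only well-formed for one type, which a runtime `if` cannot express. A hedged sketch of that stronger property (hypothetical `decode_bf16` and `as_float`, stand-ins loosely modeled on `cpu_bf162float` and `fill_output`, not the actual FBGEMM implementations):

#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical stand-in for cpu_bf162float: widen a bf16 bit pattern
// into the corresponding float. It takes a pointer, so calling it with
// a float* would be a hard compile error, not just a warning.
inline float decode_bf16(const std::uint16_t* p) {
  const std::uint32_t bits = static_cast<std::uint32_t>(*p) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f)); // bf16 occupies the high 16 bits of fp32
  return f;
}

template <typename OutType>
float as_float(const OutType& v, bool is_bf16_out) {
  if constexpr (std::is_same_v<OutType, float>) {
    return v;
  } else {
    // This branch is discarded when OutType = float, so `&v` only has
    // to be a const uint16_t* in the uint16_t instantiation; a plain
    // runtime `if` would reject the whole template for float.
    return is_bf16_out ? decode_bf16(&v) : static_cast<float>(v);
  }
}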
