Skip to content

Commit 8feae04

Browse files
cyyever authored and facebook-github-bot committed
Use if constexpr in more places (#4436)
Summary:
Pull Request resolved: #4436
X-link: facebookresearch/FBGEMM#1501

A thorough code examination revealed that it's possible to use `if constexpr` in more places.

Pull Request resolved: #4434
Reviewed By: cthi
Differential Revision: D77642003
Pulled By: q10
fbshipit-source-id: 9e4480e13b7c83bbfb7f8899b5692f12c1c4babc
1 parent 1988fb5 commit 8feae04

25 files changed: +270 additions, -250 deletions (lines changed)

bench/ConvUnifiedBenchmark.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -230,42 +230,42 @@ static void performance_test(
230230
const int NITER = repetitions;
231231

232232
string header = "MB, IC, OC, ";
233-
if (SPATIAL_DIM == 3) {
233+
if constexpr (SPATIAL_DIM == 3) {
234234
header += "IT, ";
235235
}
236236
if (SPATIAL_DIM > 1) {
237237
header += "IH, ";
238238
}
239239
header += "IW, G, ";
240-
if (SPATIAL_DIM == 3) {
240+
if constexpr (SPATIAL_DIM == 3) {
241241
header += "KT, ";
242242
}
243243
if (SPATIAL_DIM > 1) {
244244
header += "KH, ";
245245
}
246246
header += "KW, ";
247-
if (SPATIAL_DIM == 3) {
247+
if constexpr (SPATIAL_DIM == 3) {
248248
header += "stride_t, ";
249249
}
250250
if (SPATIAL_DIM > 1) {
251251
header += "stride_h, ";
252252
}
253253
header += "stride_w, ";
254-
if (SPATIAL_DIM == 3) {
254+
if constexpr (SPATIAL_DIM == 3) {
255255
header += "pad_t, ";
256256
}
257257
if (SPATIAL_DIM > 1) {
258258
header += "pad_h, ";
259259
}
260260
header += "pad_w, ";
261-
if (SPATIAL_DIM == 3) {
261+
if constexpr (SPATIAL_DIM == 3) {
262262
header += "dilation_t, ";
263263
}
264264
if (SPATIAL_DIM > 1) {
265265
header += "dilation_h, ";
266266
}
267267
header += "dilation_w, ";
268-
if (SPATIAL_DIM == 3) {
268+
if constexpr (SPATIAL_DIM == 3) {
269269
header += "output_padding_t, ";
270270
}
271271
if (SPATIAL_DIM > 1) {

include/fbgemm/OutputProcessing-inl.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,20 @@ ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
8181
block.col_size <= ncol_per_group &&
8282
"ReQuantizeOutput should be called at most 1 group at a time.");
8383
int g = block.col_start / ncol_per_group;
84-
if (instSet == inst_set_t::anyarch || !std::is_same<outT, uint8_t>::value) {
84+
if constexpr (
85+
instSet == inst_set_t::anyarch || !std::is_same<outT, uint8_t>::value) {
8586
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
8687
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
8788
inT raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)];
8889
if (Aq_zero_point_) {
8990
raw -= Aq_zero_point_ * q_col_offsets_[j];
9091
}
9192
int Bq_zero_point_idx;
92-
if (Q_GRAN == QuantizationGranularity::TENSOR) {
93+
if constexpr (Q_GRAN == QuantizationGranularity::TENSOR) {
9394
Bq_zero_point_idx = 0;
94-
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
95+
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
9596
Bq_zero_point_idx = g;
96-
} else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
97+
} else if constexpr (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
9798
Bq_zero_point_idx = j;
9899
} else {
99100
assert(false && "unknown quantization granularity");
@@ -123,7 +124,8 @@ ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
123124
std::min(255l, rounded));
124125
}
125126
}
126-
} else if (instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
127+
} else if constexpr (
128+
instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
127129
bool b_symmetric =
128130
(Q_GRAN == QuantizationGranularity::TENSOR && Bq_zero_point_[0] == 0) ||
129131
q_row_offsets_ == nullptr;
@@ -211,19 +213,20 @@ inline int ReQuantizeForFloat<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
211213
block.col_size <= ncol_per_group &&
212214
"ReQuantizeOutput should be called at most 1 group at a time.");
213215
int g = block.col_start / ncol_per_group;
214-
if (instSet == inst_set_t::anyarch || !std::is_same<outT, float>::value) {
216+
if constexpr (
217+
instSet == inst_set_t::anyarch || !std::is_same<outT, float>::value) {
215218
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
216219
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
217220
inT raw = inp[(i - block.row_start) * ld_in + j - block.col_start];
218221
if (Aq_zero_point_) {
219222
raw -= Aq_zero_point_ * q_col_offsets_[j];
220223
}
221224
int Bq_zero_point_idx;
222-
if (Q_GRAN == QuantizationGranularity::TENSOR) {
225+
if constexpr (Q_GRAN == QuantizationGranularity::TENSOR) {
223226
Bq_zero_point_idx = 0;
224-
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
227+
} else if constexpr (Q_GRAN == QuantizationGranularity::GROUP) {
225228
Bq_zero_point_idx = g;
226-
} else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
229+
} else if constexpr (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
227230
Bq_zero_point_idx = j;
228231
} else {
229232
assert(false && "unknown quantization granularity");
@@ -242,7 +245,8 @@ inline int ReQuantizeForFloat<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
242245
}
243246
}
244247
}
245-
} else if (instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
248+
} else if constexpr (
249+
instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
246250
bool b_symmetric =
247251
(Q_GRAN == QuantizationGranularity::TENSOR && Bq_zero_point_[0] == 0) ||
248252
q_row_offsets_ == nullptr;

include/fbgemm/Utils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ bool isValidBlockingFactor(const BlockingFactors* const param) {
267267
constexpr bool is_16bit = std::is_same<accT, int16_t>::value;
268268
static const auto iset = fbgemmInstructionSet();
269269

270-
if (is_32bit) {
270+
if constexpr (is_32bit) {
271271
if (param->ROW_INTERLEAVE != 4)
272272
return false;
273273

@@ -278,7 +278,7 @@ bool isValidBlockingFactor(const BlockingFactors* const param) {
278278
if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
279279
return false;
280280
}
281-
} else if (is_16bit) {
281+
} else if constexpr (is_16bit) {
282282
if (param->ROW_INTERLEAVE != 2)
283283
return false;
284284

@@ -296,11 +296,11 @@ bool isValidBlockingFactor(const BlockingFactors* const param) {
296296
if (param->NCB % param->NR)
297297
return false;
298298
if (isZmm(iset)) {
299-
if (is_32bit) {
299+
if constexpr (is_32bit) {
300300
// Zmm register usage for C
301301
if (param->MR * (param->NR / param->NR_MIN) > 28)
302302
return false;
303-
} else if (is_16bit) {
303+
} else if constexpr (is_16bit) {
304304
// Zmm register usage for C + one row for loading B
305305
if ((param->MR * (param->NR / param->NR_MIN) +
306306
(param->NR / param->NR_MIN)) > 28)

src/DirectConv.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ class DirectConvCodeGenBase {
124124
<< "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB)
125125
<< "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR)
126126
<< "_NR-" + std::to_string(NR);
127-
if (instSet == inst_set_t::avx512_vnni) {
127+
if constexpr (instSet == inst_set_t::avx512_vnni) {
128128
oss << "_avx512vnni";
129-
} else if (instSet == inst_set_t::avx512) {
129+
} else if constexpr (instSet == inst_set_t::avx512) {
130130
oss << "_avx512";
131-
} else if (instSet == inst_set_t::avx512_ymm) {
131+
} else if constexpr (instSet == inst_set_t::avx512_ymm) {
132132
oss << "_avx512_ymm";
133-
} else if (instSet == inst_set_t::avx2) {
133+
} else if constexpr (instSet == inst_set_t::avx2) {
134134
oss << "_avx2";
135135
}
136136
oss << ".txt";

src/EmbeddingSpMDM.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ GenEmbeddingSpMDMLookup<
354354
asmjit::FuncFrame frame;
355355
frame.init(func);
356356

357-
if (instSet == inst_set_t::avx2) {
357+
if constexpr (instSet == inst_set_t::avx2) {
358358
frame.setDirtyRegs(
359359
asmjit::RegGroup::kVec,
360360
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
@@ -468,7 +468,7 @@ GenEmbeddingSpMDMLookup<
468468
}
469469

470470
if (remainder) {
471-
if (instSet == inst_set_t::avx2) {
471+
if constexpr (instSet == inst_set_t::avx2) {
472472
a->vmovups(
473473
mask_vreg,
474474
x86::ymmword_ptr(
@@ -524,7 +524,7 @@ GenEmbeddingSpMDMLookup<
524524

525525
// OK to use vreg0 because it's for out_vreg used in the main loop
526526
vec_reg_t temp_vreg(0);
527-
if (instSet == inst_set_t::avx2) {
527+
if constexpr (instSet == inst_set_t::avx2) {
528528
a->mov(scratchReg1_, 1);
529529
a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_);
530530
a->cvtsi2ss(temp_vreg.xmm(), lengths_R_);
@@ -752,7 +752,7 @@ GenEmbeddingSpMDMLookup<
752752
a->vfmadd231ps(out_vreg, src_vreg, scale_vreg);
753753
} else if (is_16bit_in) {
754754
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
755-
if (instSet == inst_set_t::avx2) {
755+
if constexpr (instSet == inst_set_t::avx2) {
756756
if (remainder % 2 == 0) {
757757
a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr);
758758
} else {
@@ -819,7 +819,7 @@ GenEmbeddingSpMDMLookup<
819819
}
820820
if (has_weight) {
821821
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
822-
if (instSet == inst_set_t::avx2) {
822+
if constexpr (instSet == inst_set_t::avx2) {
823823
a->vfmadd231ps(out_vreg, w_vreg, src_vreg);
824824
} else {
825825
a->k(x86::k(1)).vfmadd231ps(out_vreg, w_vreg, src_addr);
@@ -829,7 +829,7 @@ GenEmbeddingSpMDMLookup<
829829
}
830830
} else {
831831
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
832-
if (instSet == inst_set_t::avx2) {
832+
if constexpr (instSet == inst_set_t::avx2) {
833833
a->vaddps(out_vreg, out_vreg, src_vreg);
834834
} else {
835835
a->k(x86::k(1)).vaddps(out_vreg, out_vreg, src_addr);
@@ -864,7 +864,7 @@ GenEmbeddingSpMDMLookup<
864864

865865
if constexpr (std::is_same_v<outType, float>) {
866866
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
867-
if (instSet == inst_set_t::avx2) {
867+
if constexpr (instSet == inst_set_t::avx2) {
868868
a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm());
869869
} else {
870870
a->k(x86::k(1)).vmovups(dst_addr, out_vreg);
@@ -874,7 +874,7 @@ GenEmbeddingSpMDMLookup<
874874
}
875875
} else {
876876
// fp16/bf16 output
877-
if (instSet == inst_set_t::avx2) {
877+
if constexpr (instSet == inst_set_t::avx2) {
878878
// round nearest with no exception
879879
if (is_fp16_out) {
880880
a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8);

src/EmbeddingSpMDMAutovec.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,9 @@ static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
736736
float* out,
737737
const bool is_weight_positional,
738738
const bool use_offsets) {
739-
bool is8bit = std::is_same_v<InType, uint8_t>;
739+
constexpr bool is8bit = std::is_same_v<InType, uint8_t>;
740740

741-
if (is8bit) {
741+
if constexpr (is8bit) {
742742
// block_size is the number of elements and fused_block_size is the size
743743
// of an entire row, including scale and bias.
744744
const auto scale_bias_offset = 2 * sizeof(float);

src/EmbeddingSpMDMNBit.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ GenEmbeddingSpMDMNBitLookup<
285285
++reg_id;
286286
x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15
287287
x86::Gp scratchReg3_;
288-
if (instSet == inst_set_t::avx2) {
288+
if constexpr (instSet == inst_set_t::avx2) {
289289
scratchReg3_ = a->zax();
290290
}
291291

@@ -470,7 +470,7 @@ GenEmbeddingSpMDMNBitLookup<
470470
unroll_factor = unroll_factor / 4 * 4;
471471

472472
if (remainder) {
473-
if (instSet == inst_set_t::avx2) {
473+
if constexpr (instSet == inst_set_t::avx2) {
474474
a->vmovups(
475475
mask_vreg,
476476
x86::ymmword_ptr(
@@ -496,7 +496,7 @@ GenEmbeddingSpMDMNBitLookup<
496496
}
497497

498498
if (remainder_32bit_granularity) {
499-
if (instSet == inst_set_t::avx2) {
499+
if constexpr (instSet == inst_set_t::avx2) {
500500
a->lea(
501501
x86::rsp,
502502
x86::dword_ptr(
@@ -548,7 +548,7 @@ GenEmbeddingSpMDMNBitLookup<
548548
a->jl(IfLengthsEnd);
549549

550550
vec_reg_t temp_vreg0(0);
551-
if (instSet == inst_set_t::avx2) {
551+
if constexpr (instSet == inst_set_t::avx2) {
552552
a->mov(scratchReg1_, 1);
553553
a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_);
554554
a->cvtsi2ss(temp_vreg0.xmm(), lengths_R_);
@@ -755,7 +755,7 @@ GenEmbeddingSpMDMNBitLookup<
755755
if (bit_rate == 4) {
756756
if (num_vec_regs_per_block - (vec_idx + v) < 4 &&
757757
remainder_32bit_granularity) {
758-
if (instSet == inst_set_t::avx512) {
758+
if constexpr (instSet == inst_set_t::avx512) {
759759
a->k(x86::k(2)).vmovups(src_vreg.ymm(), src_addr);
760760
} else {
761761
a->vpmaskmovd(src_vreg.xmm(), mask2_vreg.xmm(), src_addr);
@@ -765,7 +765,7 @@ GenEmbeddingSpMDMNBitLookup<
765765
a->vpmovzxbw(src_vreg, src_addr);
766766
}
767767
a->vpslld(temp_vreg, src_vreg, asmjit::Imm(4));
768-
if (instSet == inst_set_t::avx512) {
768+
if constexpr (instSet == inst_set_t::avx512) {
769769
a->vpord(src_vreg, src_vreg, temp_vreg);
770770
a->vpandd(src_vreg, src_vreg, extract_mask_vreg);
771771
} else {
@@ -776,7 +776,7 @@ GenEmbeddingSpMDMNBitLookup<
776776
} else {
777777
if (num_vec_regs_per_block - (vec_idx + v) < 4 &&
778778
remainder_32bit_granularity) {
779-
if (instSet == inst_set_t::avx512) {
779+
if constexpr (instSet == inst_set_t::avx512) {
780780
a->k(x86::k(2)).vmovups(src_vreg.xmm(), src_addr);
781781
a->vpmovzxbd(src_vreg, src_vreg.xmm());
782782
} else {
@@ -788,13 +788,13 @@ GenEmbeddingSpMDMNBitLookup<
788788
}
789789
a->vpslld(temp_vreg, src_vreg, 2 * 8 + 2);
790790
a->vpslld(temp2_vreg, src_vreg, 8 + 4);
791-
if (instSet == inst_set_t::avx512) {
791+
if constexpr (instSet == inst_set_t::avx512) {
792792
a->vpord(temp_vreg, temp_vreg, temp2_vreg);
793793
} else {
794794
a->vpor(temp_vreg.ymm(), temp_vreg.ymm(), temp2_vreg.ymm());
795795
}
796796
a->vpslld(temp2_vreg, src_vreg, 6);
797-
if (instSet == inst_set_t::avx512) {
797+
if constexpr (instSet == inst_set_t::avx512) {
798798
a->vpord(temp_vreg, temp_vreg, temp2_vreg);
799799
a->vpord(src_vreg, temp_vreg, src_vreg);
800800
a->vpandd(src_vreg, src_vreg, extract_mask_vreg);
@@ -817,11 +817,11 @@ GenEmbeddingSpMDMNBitLookup<
817817
if (i == 0) {
818818
a->vpmovsxbd(temp_vreg, src_vreg.xmm());
819819
// this is only needed for avx2
820-
if (instSet == inst_set_t::avx2) {
820+
if constexpr (instSet == inst_set_t::avx2) {
821821
a->vmovups(temp2_vreg, src_vreg);
822822
}
823823
} else {
824-
if (instSet == inst_set_t::avx512) {
824+
if constexpr (instSet == inst_set_t::avx512) {
825825
// We could've used avx512_ymm for clock frequency advantage,
826826
// if there's an instruction to extract a 64-bit portion from
827827
// a YMM as an XMM register.
@@ -868,7 +868,7 @@ GenEmbeddingSpMDMNBitLookup<
868868

869869
if constexpr (std::is_same_v<outType, float>) {
870870
if (remainder && vec_idx + v == num_vec_regs_per_block - 1) {
871-
if (instSet == inst_set_t::avx512) {
871+
if constexpr (instSet == inst_set_t::avx512) {
872872
a->k(x86::k(1)).vmovups(dst_addr, out_vreg);
873873
} else {
874874
a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm());
@@ -878,7 +878,7 @@ GenEmbeddingSpMDMNBitLookup<
878878
}
879879
} else {
880880
// 16-bit output
881-
if (instSet == inst_set_t::avx2) {
881+
if constexpr (instSet == inst_set_t::avx2) {
882882
if (is_bf16_out) {
883883
a->vpaddd(out_vreg, out_vreg, ones_vreg);
884884
a->vpsrld(out_vreg, out_vreg, 16);

src/Fbgemm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ void fbgemmPacked(
209209

210210
template <int SPATIAL_DIM>
211211
bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
212-
if (SPATIAL_DIM == 1)
212+
if constexpr (SPATIAL_DIM == 1)
213213
return false;
214214

215215
int C_per_G = conv_p.IC / conv_p.G;

0 commit comments

Comments
 (0)