
Commit fe9946e

q10 authored and facebook-github-bot committed
Optimize some code out of compilation in the table lookup kernel (#4371)
Summary:
Pull Request resolved: #4371
X-link: facebookresearch/FBGEMM#1440

- Optimize some code out of compilation in the table lookup kernel

Reviewed By: spcyppt

Differential Revision: D76865732

fbshipit-source-id: a90c14567ea8899a1f7ffc43988bfba720507b6e
1 parent 311f6f9 commit fe9946e

File tree: 4 files changed (+8 / -17 lines)


fbgemm_gpu/codegen/genscript/jinja_environment.py

Lines changed: 3 additions & 3 deletions
@@ -111,7 +111,7 @@ def generate_optimized_grad_sum_loop_access(
     smem_blob = blob.format(grad_vec="smem_grad_sum[d_vec]")
     reg_blob = blob.format(grad_vec="grad_sum[vec]")
     gen_blob = """
-        if (kUseVecBlocking) {
+        if constexpr (kUseVecBlocking) {
           // max_vecs is not known at compile time
           for (int32_t vec = 0;
                vec < max_vecs &&
@@ -121,8 +121,8 @@ def generate_optimized_grad_sum_loop_access(
             [[maybe_unused]] const int32_t d = d_vec * VEC_WIDTH;
             {smem_blob}
           }
-        }
-        else {
+
+        } else {
           // kFixedMaxVecsPerThread is known at compile time
           #pragma unroll kFixedMaxVecsPerThread
           for (int32_t vec = 0;
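The swap from `if` to `if constexpr` matters here because `kUseVecBlocking` is a compile-time constant in each instantiation of the generated kernel: the compiler now discards the untaken loop variant during template instantiation instead of carrying both through compilation. Below is a minimal standalone CUDA sketch of that effect; the function name and the raw `grad` pointer are placeholders for illustration, not FBGEMM code.

#include <cstdint>

// Sketch only: shows how `if constexpr` removes the unused branch from each
// template instantiation. Not the generated FBGEMM kernel.
template <bool kUseVecBlocking, int32_t kFixedMaxVecsPerThread>
__device__ inline float accumulate_sketch(const float* grad, int32_t max_vecs) {
  float sum = 0.0f;
  if constexpr (kUseVecBlocking) {
    // max_vecs is not known at compile time, so the loop bound is dynamic
    // (this mirrors the smem_grad_sum path above).
    for (int32_t vec = 0; vec < max_vecs; ++vec) {
      sum += grad[vec];
    }
  } else {
    // kFixedMaxVecsPerThread is known at compile time; the loop unrolls fully,
    // and the dynamic-bound branch above is never emitted for this specialization.
    #pragma unroll kFixedMaxVecsPerThread
    for (int32_t vec = 0; vec < kFixedMaxVecsPerThread; ++vec) {
      sum += grad[vec];
    }
  }
  return sum;
}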

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 3 additions & 12 deletions
@@ -1053,18 +1053,12 @@ def adam() -> Dict[str, Any]:
 
     split_weight_update = """
       Vec4T<cache_t> m_t(&momentum1[idx * D + d]);
-      m_t.acc.x *= beta1;
-      m_t.acc.y *= beta1;
-      m_t.acc.z *= beta1;
-      m_t.acc.w *= beta1;
+      m_t.mul_(beta1);
       m_t.fma_(grad, 1.0 - beta1);
       m_t.store(&momentum1[idx * D + d]);
 
       Vec4T<cache_t> v_t(&momentum2[idx * D + d]);
-      v_t.acc.x *= beta2;
-      v_t.acc.y *= beta2;
-      v_t.acc.z *= beta2;
-      v_t.acc.w *= beta2;
+      v_t.mul_(beta2);
 
       grad.acc.x *= grad.acc.x;
       grad.acc.y *= grad.acc.y;
@@ -1141,10 +1135,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
 
     split_weight_update = """
       Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
-      m_t.acc.x *= beta1;
-      m_t.acc.y *= beta1;
-      m_t.acc.z *= beta1;
-      m_t.acc.w *= beta1;
+      m_t.mul_(beta1);
       m_t.fma_(grad, 1.0 - beta1);
       m_t.store(&momentum1[idx * D + d]);
 
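In the generated Adam and partial-rowwise-Adam update strings, the four per-lane multiplies collapse into a single `mul_()` call on FBGEMM's `Vec4T`. The struct below is only a hedged stand-in for the lane-wise semantics that call expresses; it is not the real `Vec4T` from FBGEMM's CUDA headers.

// Illustrative stand-in, not the real FBGEMM Vec4T.
struct Vec4TSketch {
  float4 acc;  // four packed lanes, matching the acc.x/.y/.z/.w accesses above

  // Scalar multiply applied to every lane -- what m_t.mul_(beta1) expresses
  // in one statement instead of four.
  __device__ void mul_(float scale) {
    acc.x *= scale;
    acc.y *= scale;
    acc.z *= scale;
    acc.w *= scale;
  }
};

The generated source gets shorter and the intent clearer, while the compiler still emits the same four lane multiplies after inlining.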
fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 1 addition & 1 deletion
@@ -1178,7 +1178,7 @@ Tensor {{ embedding_cuda_op }}(
   int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize;
   int32_t warp_per_row_smem_bytes = 0;
 
-  if (kUseVecBlocking) {
+  if constexpr (kUseVecBlocking) {
     warp_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes(
         &num_warp_per_row_groups,
         // Use max_D to compute shmem_bytes (for smem_grad_sum)

fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
   emb_t* __restrict__ weights {nullptr};
   cache_t* __restrict__ cache_weights {nullptr};
   int32_t D_emb = D;
-  if (kIsInt8) {
+  if constexpr (kIsInt8) {
     D_emb += kINT8QparamsBytes;
   }
   const auto weights_placement = static_cast<PlacementType>(weights_placements[t]);
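In the table-update kernel, this branch guards the extra per-row quantization-parameter bytes that only INT8 tables carry; with `if constexpr`, instantiations where `kIsInt8` is false never contain the addition at all. A small sketch under assumed names follows; the constant's value is a placeholder, not FBGEMM's actual `kINT8QparamsBytes`.

#include <cstdint>

// Sketch only: row stride selection resolved at compile time.
constexpr int32_t kInt8QparamsBytesSketch = 8;  // placeholder value

template <bool kIsInt8>
__device__ inline int32_t row_stride_sketch(int32_t D) {
  int32_t D_emb = D;
  if constexpr (kIsInt8) {
    // Compiled only into the kIsInt8 == true specialization.
    D_emb += kInt8QparamsBytesSketch;
  }
  return D_emb;
}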
