Skip to content

Commit d01549a

Browse files
q10 authored and facebook-github-bot committed
Support optimizer state offloading for partial rowwise adam optimizer (#4477)
Summary: Pull Request resolved: #4477 X-link: facebookresearch/FBGEMM#1534 Support optimizer state offloading for partial rowwise adam optimizer in the backend C++ code. This does not yet expose support in the frontend Python code, which requires a lot more code changes. The existing non-offloading codepath should not be affected by the changes. This is a re-land of D76491848, but with the backend code enabled instead of the frontend, which was breaking downstream compatibility tests Reviewed By: bobbyliujb, cthi Differential Revision: D78177062 fbshipit-source-id: 72f636d7231409750c5f4d5a6ddfab32c33abbf1
1 parent b3052b7 commit d01549a

File tree

1 file changed

+39
-5
lines changed

1 file changed

+39
-5
lines changed

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,24 +1120,58 @@ def partial_rowwise_adam() -> Dict[str, Any]:
11201120
"""
11211121
)
11221122
split_precomputation += """
1123+
1124+
// Define the optimizer state (for use with optimizer offloading)
1125+
struct OptimizerState {
1126+
// momentum2 is a single value so it will be accessed directly as a struct field
1127+
momentum2_ph_t momentum2;
1128+
1129+
// momentum1 is an array of D values, so a method to return a pointer given the offset is defined instead
1130+
DEVICE_INLINE momentum1_ph_t* momentum1_ptr() {
1131+
// Re-cast the address to momentum1_ph_t* and return
1132+
return reinterpret_cast<momentum1_ph_t *>(
1133+
// Cast the address this to momentum2_t* and increment by 1 to skip over the momentum2 value
1134+
reinterpret_cast<momentum2_ph_t *>(this) + 1
1135+
);
1136+
}
1137+
};
1138+
1139+
// Fetch the pointer to the optimizer state along the cache row
1140+
[[maybe_unused]] auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
1141+
1142+
// Fetch the pointer to the momentum1 value
1143+
// Define the fetch here instead of in split_weight_update to avoid conditionals inside a loop
1144+
auto* momentum1_start = enable_optimizer_offloading ?
1145+
(optimizer->momentum1_ptr()) :
1146+
(&momentum1[idx * D]);
1147+
11231148
const at::acc_type<cache_t, true> g_avg_square =
11241149
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
11251150
11261151
at::acc_type<cache_t, true> v_hat_t;
11271152
v_hat_t = 0.0;
11281153
if (threadIdx.x == 0) {
1129-
at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
1130-
momentum2[idx] = v_t;
1154+
auto v_t = g_avg_square * (1.0 - beta2);
1155+
1156+
if (enable_optimizer_offloading) {
1157+
v_t += optimizer->momentum2 * beta2;
1158+
optimizer->momentum2 = v_t;
1159+
} else {
1160+
v_t += momentum2[idx] * beta2;
1161+
momentum2[idx] = v_t;
1162+
}
1163+
11311164
v_hat_t = v_t / (1.0 - powf(beta2, iter));
11321165
}
11331166
v_hat_t = SHFL_SYNC(v_hat_t, 0);
11341167
"""
11351168

11361169
split_weight_update = """
1137-
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
1170+
auto* momentum1_ptr = momentum1_start + d;
1171+
Vec4T<momentum1_ph_t> m_t(momentum1_ptr);
11381172
m_t.mul_(beta1);
11391173
m_t.fma_(grad, 1.0 - beta1);
1140-
m_t.store(&momentum1[idx * D + d]);
1174+
m_t.store(momentum1_ptr);
11411175
11421176
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
11431177
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
@@ -1179,7 +1213,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
11791213
"has_gpu_support": True,
11801214
"has_vbe_support": False,
11811215
"has_global_weight_decay_support": False,
1182-
"has_ssd_support": False,
1216+
"has_ssd_support": True,
11831217
}
11841218

11851219

0 commit comments

Comments (0)