Skip to content

Commit 4551ff7

Browse files
q10 authored and facebook-github-bot committed
Back out "Support optimizer state offloading for partial rowwise adam optimizer" (#4478)
Summary: Pull Request resolved: #4478 X-link: facebookresearch/FBGEMM#1535 Original commit changeset: 3ae8ba678d5d Original Phabricator Diff: D76491848 Needs to be backed out bc the code to enable building the backend was not yet available ``` SSD_OPTIMIZERS = [ "rowwise_adagrad", "partial_rowwise_adam", ] ``` needs to be added first to the build. Reviewed By: cthi, bobbyliujb Differential Revision: D78176561 fbshipit-source-id: d997b735ea2b37825eeb26cebaae153e5a697b61
1 parent ed9ea0c commit 4551ff7

File tree

1 file changed

+5
-39
lines changed

1 file changed

+5
-39
lines changed

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,58 +1120,24 @@ def partial_rowwise_adam() -> Dict[str, Any]:
11201120
"""
11211121
)
11221122
split_precomputation += """
1123-
1124-
// Define the optimizer state (for use with optimizer offloading)
1125-
struct OptimizerState {
1126-
// momentum2 is a single value so it will be accessed directly as a struct field
1127-
momentum2_ph_t momentum2;
1128-
1129-
// momentum1 is an array of D values, so a method to return a pointer given the offset is defined instead
1130-
DEVICE_INLINE momentum1_ph_t* momentum1_ptr() {
1131-
// Re-cast the address to momentum1_ph_t* and return
1132-
return reinterpret_cast<momentum1_ph_t *>(
1133-
// Cast the address this to momentum2_t* and increment by 1 to skip over the momentum2 value
1134-
reinterpret_cast<momentum2_ph_t *>(this) + 1
1135-
);
1136-
}
1137-
};
1138-
1139-
// Fetch the pointer to the optimizer state along the cache row
1140-
[[maybe_unused]] auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
1141-
1142-
// Fetch the pointer to the momentum1 value
1143-
// Define the fetch here instead of in split_weight_update to avoid conditionals inside a loop
1144-
auto* momentum1_start = enable_optimizer_offloading ?
1145-
(optimizer->momentum1_ptr()) :
1146-
(&momentum1[idx * D]);
1147-
11481123
const at::acc_type<cache_t, true> g_avg_square =
11491124
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
11501125
11511126
at::acc_type<cache_t, true> v_hat_t;
11521127
v_hat_t = 0.0;
11531128
if (threadIdx.x == 0) {
1154-
auto v_t = g_avg_square * (1.0 - beta2);
1155-
1156-
if (enable_optimizer_offloading) {
1157-
v_t += optimizer->momentum2 * beta2;
1158-
optimizer->momentum2 = v_t;
1159-
} else {
1160-
v_t += momentum2[idx] * beta2;
1161-
momentum2[idx] = v_t;
1162-
}
1163-
1129+
at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
1130+
momentum2[idx] = v_t;
11641131
v_hat_t = v_t / (1.0 - powf(beta2, iter));
11651132
}
11661133
v_hat_t = SHFL_SYNC(v_hat_t, 0);
11671134
"""
11681135

11691136
split_weight_update = """
1170-
auto* momentum1_ptr = momentum1_start + d;
1171-
Vec4T<momentum1_ph_t> m_t(momentum1_ptr);
1137+
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
11721138
m_t.mul_(beta1);
11731139
m_t.fma_(grad, 1.0 - beta1);
1174-
m_t.store(momentum1_ptr);
1140+
m_t.store(&momentum1[idx * D + d]);
11751141
11761142
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
11771143
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
@@ -1213,7 +1179,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
12131179
"has_gpu_support": True,
12141180
"has_vbe_support": False,
12151181
"has_global_weight_decay_support": False,
1216-
"has_ssd_support": True,
1182+
"has_ssd_support": False,
12171183
}
12181184

12191185

0 commit comments

Comments (0)