make float8 training's force_recompute_fp8_weight_in_bwd flag do nothing (#2356)
Summary:
This PR makes the `Float8LinearConfig.force_recompute_fp8_weight_in_bwd`
flag do nothing and marks it for deprecation, to be removed in a future
release. Now that PyTorch Core can handle this logic automatically, the
workaround is no longer needed. Please see #2251 for more context.
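For illustration, a minimal sketch (assuming a torchao build that includes this change, with `Float8LinearConfig` imported from `torchao.float8`) of what users who still set the flag will see:

```python
from torchao.float8 import Float8LinearConfig

# The flag is still accepted for backwards compatibility, but it no longer
# changes training behavior.
config = Float8LinearConfig(force_recompute_fp8_weight_in_bwd=True)
# __post_init__ now emits a deprecation warning pointing to
# https://github.com/pytorch/ao/issues/2251 instead of altering how the
# fp8 weight is saved/recomputed for backward.
```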
Test Plan:
```
./test/float8/test_everything.sh
```
Reviewers:
Subscribers:
Tasks:
Tags:
torchao/float8/config.py: 5 additions & 20 deletions
```diff
@@ -192,20 +192,9 @@ class Float8LinearConfig:
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False
 
-    # If the option is enabled, fp8_weight will always be re-computed in backward.
-    # It's recommended to enable this flag when using FSDP.
-    # Otherwise, the entire fp8_weight, instead of the sharded weight may be saved.
-    # If using outer activation checkpointing context or SAC, you may disable this option
-    # and handle the recomputation of fp8 weight in your customized AC context.
-    #
-    # Details:
-    # When using float8 training with FSDP, the original weight is sharded; fp8_weight (in forward) and fp8_weight_transpose (in backward) are used by the model.
-    # However, when partitioning the forward_backward graph, torch.compile may decide to
-    # save the fp8_weight_transpose for backward, which is an un-sahrded weight and costs a high memory utilization.
-    # The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
-    # For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
-    # TODO(future PR): either enable by default or have a warning and set up the
-    # tests so that the warning does not spam the CI stdout.
+    # This flag is deprecated and currently has no effect. It will be removed
+    # in a future release. Please see https://github.com/pytorch/ao/issues/2251
+    # for more context.
     force_recompute_fp8_weight_in_bwd: bool = False
 
     # If this option is enabled, the scaling factor used for float8 quantization
@@ -278,13 +267,9 @@ def __post_init__(self):
                     f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
                 )
 
-        # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
-        if (
-            self.enable_fsdp_float8_all_gather
-            and not self.force_recompute_fp8_weight_in_bwd
-        ):
+        if self.force_recompute_fp8_weight_in_bwd:
             logger.warning(
-                "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
+                "`config.force_recompute_fp8_weight_in_bwd` is deprecated and will be removed in a future release. Please see https://github.com/pytorch/ao/issues/2251 for more details."
```
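As a hedged migration sketch (not part of this PR; the model is a placeholder, and `convert_to_float8_training(module, config=...)` is assumed to be the standard torchao conversion entry point): existing scripts can simply stop passing the flag and leave the save-vs-recompute decision to PyTorch Core.

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Placeholder model; real float8 training needs supported hardware (e.g. H100),
# or emulate=True for functional testing.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024))

# No need to pass force_recompute_fp8_weight_in_bwd anymore; whether the fp8
# weight is saved for backward or recomputed is now decided by PyTorch Core.
config = Float8LinearConfig()
convert_to_float8_training(model, config=config)

# torch.compile's partitioner handles graph partitioning for backward.
model = torch.compile(model)
```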