
Commit 03c850a

make float8 training's force_recompute_fp8_weight_in_bwd flag do nothing (#2356)

Summary: This PR makes the `Float8LinearConfig.force_recompute_fp8_weight_in_bwd` flag do nothing and marks it for future deprecation. Now that PyTorch Core can handle this logic automatically, we no longer need the workaround. Please see #2251 for more context.

Test Plan:

```
./test/float8/test_everything.sh
```
1 parent dd22777 · commit 03c850a

3 files changed: +15 −42 lines
benchmarks/float8/training/torchtitan_benchmark.sh

Lines changed: 1 addition & 1 deletion
```
@@ -29,7 +29,7 @@ fi
 # validate recipe name
 if [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ]; then
     if [ "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" == "tensorwise" ]; then
-        FLOAT8_ARGS="--model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd"
+        FLOAT8_ARGS="--model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp"
     else
         FLOAT8_ARGS="--model.converters="float8" --float8.recipe_name=${FLOAT8_RECIPE_WITH_BEST_SETTINGS}"
     fi
```

torchao/float8/config.py

Lines changed: 5 additions & 20 deletions
```
@@ -192,20 +192,9 @@ class Float8LinearConfig:
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False
 
-    # If the option is enabled, fp8_weight will always be re-computed in backward.
-    # It's recommended to enable this flag when using FSDP.
-    # Otherwise, the entire fp8_weight, instead of the sharded weight may be saved.
-    # If using outer activation checkpointing context or SAC, you may disable this option
-    # and handle the recomputation of fp8 weight in your customized AC context.
-    #
-    # Details:
-    # When using float8 training with FSDP, the original weight is sharded; fp8_weight (in forward) and fp8_weight_transpose (in backward) are used by the model.
-    # However, when partitioning the forward_backward graph, torch.compile may decide to
-    # save the fp8_weight_transpose for backward, which is an un-sharded weight and costs a high memory utilization.
-    # The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
-    # For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
-    # TODO(future PR): either enable by default or have a warning and set up the
-    # tests so that the warning does not spam the CI stdout.
+    # This flag is deprecated and currently has no effect. It will be removed
+    # in a future release. Please see https://github.com/pytorch/ao/issues/2251
+    # for more context.
     force_recompute_fp8_weight_in_bwd: bool = False
 
     # If this option is enabled, the scaling factor used for float8 quantization
@@ -278,13 +267,9 @@ def __post_init__(self):
                 f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
             )
 
-        # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
-        if (
-            self.enable_fsdp_float8_all_gather
-            and not self.force_recompute_fp8_weight_in_bwd
-        ):
+        if self.force_recompute_fp8_weight_in_bwd:
             logger.warning(
-                "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
+                "`config.force_recompute_fp8_weight_in_bwd` is deprecated and will be removed in a future release. Please see https://github.com/pytorch/ao/issues/2251 for more details."
             )
 
     @staticmethod
```
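To illustrate what the deprecated flag now does (a minimal sketch based on the hunk above, not code from this commit): constructing a config with the flag set only logs the new deprecation warning and changes nothing else.

```
# Minimal sketch (not part of this commit): after this change, setting the
# deprecated flag only triggers the warning added in __post_init__ above.
from torchao.float8.config import Float8LinearConfig

config = Float8LinearConfig(force_recompute_fp8_weight_in_bwd=True)
# Expected log output, per the warning added in this commit:
#   `config.force_recompute_fp8_weight_in_bwd` is deprecated and will be
#   removed in a future release. Please see
#   https://github.com/pytorch/ao/issues/2251 for more details.
```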

torchao/float8/float8_linear.py

Lines changed: 9 additions & 21 deletions
```
@@ -10,7 +10,6 @@
 from typing import Optional
 
 import torch
-import torch.utils.checkpoint as checkpoint
 
 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
@@ -325,29 +324,18 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
         if not has_any_axiswise_scaling:
-            # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
-            # weight_scale should be saved.
+            # TODO(future PR): now that `force_recompute_fp8_weight_in_bwd` is
+            # deprecated, we can simplify the below code and unify the per-tensor
+            # and per-axis paths further.
             weight_scale = _get_weight_scale(
                 self.weight, self.scaling_type_weight, self.config
             )
-
-            if self.config.force_recompute_fp8_weight_in_bwd:
-                weight_fp8_t = checkpoint.checkpoint(
-                    _cast_weight_to_float8_t,
-                    self.weight,
-                    self.config,
-                    self.linear_mm_config,
-                    weight_scale,
-                )
-            else:
-                weight_fp8_t = _cast_weight_to_float8_t(
-                    self.weight,
-                    self.config,
-                    self.linear_mm_config,
-                    weight_scale,
-                )
-
-            weight_maybe_fp8_t = weight_fp8_t
+            weight_maybe_fp8_t = _cast_weight_to_float8_t(
+                self.weight,
+                self.config,
+                self.linear_mm_config,
+                weight_scale,
+            )
 
         output = matmul_with_hp_or_float8_args.apply(
             input,
```
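For context on the workaround being deleted above (a generic, self-contained sketch, not torchao code): wrapping a cast in `torch.utils.checkpoint.checkpoint` tells autograd to recompute the cast output during backward instead of saving it, which is how the old `force_recompute_fp8_weight_in_bwd=True` path kept the casted weight from being stashed for backward.

```
# Generic sketch of the removed pattern; `cast_weight` is a stand-in for
# `_cast_weight_to_float8_t` and is not real torchao code.
import torch
import torch.utils.checkpoint as checkpoint

def cast_weight(w: torch.Tensor) -> torch.Tensor:
    # Some cast whose output we do not want autograd to save for backward.
    return w.to(torch.bfloat16).t()

w = torch.randn(128, 128, requires_grad=True)
x = torch.randn(4, 128, dtype=torch.bfloat16)

# Old-style workaround: the cast output is recomputed during backward rather
# than saved, trading a small amount of compute for lower peak memory.
w_cast_t = checkpoint.checkpoint(cast_weight, w, use_reentrant=False)
out = x @ w_cast_t
out.sum().backward()
```

Per the commit summary, this manual recompute is no longer needed because PyTorch Core now handles the decision automatically (see #2251).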
