
Commit b65e513

float8 training: move module attribute setting to sync function (#1341)
Summary:

This PR moves the setting of the `is_amax_initialized` flag on `Float8Linear` to the `sync_float8_amax_and_scale_history` function. There are two reasons for this:

1. The current logic does not work with torchtitan + delayed scaling + AC, failing with https://gist.github.com/vkuzo/70819a2cffb9346bf44ecd9079b8bf51 .
2. In general, stateful logic such as changing module attributes adds complexity. Even if we fix (1) in compile land, something else could break.

The `sync_float8_amax_and_scale_history` function is already called outside of the main model forward/backward, it is already required to be called at every iteration, and it does not need to know about AC, which makes it a good place to keep logic that is not easily compilable, such as this init code.

After this PR, the `enable_amax_init` and `enable_pre_and_post_forward` config options are no-ops. A future PR should add a deprecation warning and eventually remove them.

Test Plan:

```
// this repo
./test/float8/test_everything.sh

// torchtitan
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.scaling_type_input delayed --float8.scaling_type_weight delayed --float8.scaling_type_grad_output delayed
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 8b1b168 commit b65e513
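
To make the new call pattern concrete, here is a hedged sketch of a delayed-scaling training step along the lines the commit message describes: the sync call already runs once per iteration, in eager mode, outside the (possibly compiled and/or activation-checkpointed) forward and backward, which is what makes it a convenient home for one-time init logic. The import path and the use of `torch.utils.checkpoint` below are illustrative assumptions, not part of this commit.

```python
import torch
from torchao.float8 import sync_float8_amax_and_scale_history  # assumed import path

def train_step(model, optimizer, x):
    # forward + backward; in the failing torchtitan case this region runs under
    # torch.compile and activation checkpointing
    y = torch.utils.checkpoint.checkpoint(model, x, use_reentrant=False)
    y.sum().backward()

    # delayed scaling only: runs once per iteration, in eager mode; after this PR,
    # the first call also sets `is_amax_initialized` on every Float8Linear
    sync_float8_amax_and_scale_history(model)

    optimizer.step()
    optimizer.zero_grad()
```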

File tree

5 files changed, +18 -17 lines


test/float8/test_base.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -265,13 +265,14 @@ def _test_linear_impl(
             config,
         )
         for _ in range(2):
-            if linear_requires_sync(config):
-                sync_float8_amax_and_scale_history(m_fp8)
             if use_ac:
                 y_fp8 = torch.utils.checkpoint.checkpoint(m_fp8, x, use_reentrant=False)
             else:
                 y_fp8 = m_fp8(x)
             y_fp8.sum().backward()
+            if linear_requires_sync(config):
+                sync_float8_amax_and_scale_history(m_fp8)
+
             if use_ac:
                 y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False)
             else:
```

torchao/float8/README.md

Lines changed: 5 additions & 4 deletions
```diff
@@ -95,8 +95,6 @@ config = Float8LinearConfig(
     cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
     cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
     cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
-    # enable_amax_init=False,  # only needed for autocast + compile + FSDP + float8 delayed
-    # enable_pre_and_post_forward=False  # only needed for autocast + compile + FSDP + float8 delayed
 )

 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
@@ -111,8 +109,11 @@ for _ in range(10):
     y = m(x)
     y.sum().backward()

-    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
-    # in the future, this may move to a context manager
+    # Specific to delayed scaling: separate step to sync scales/amaxes.
+    # On the first call, this function also sets the `is_amax_initialized` flag to
+    # mark the amax and scale buffers as initialized.
+    # Make sure you run this after every model forward+backward pass.
+    # In the future, this may move to a context manager.
     sync_float8_amax_and_scale_history(m)

     optimizer.step()
```
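
As a follow-on to the README comment about the first sync call marking buffers as initialized, a hypothetical check (not part of the README or this PR) might look like the sketch below; the `Float8Linear` import path is an assumption.

```python
from torchao.float8.float8_linear import Float8Linear  # assumed import path

def assert_amax_initialized(model):
    # After the first sync_float8_amax_and_scale_history(model) call, every
    # Float8Linear child is expected to report is_amax_initialized == True.
    for module in model.modules():
        if isinstance(module, Float8Linear):
            assert module.is_amax_initialized
```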

torchao/float8/config.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -180,15 +180,12 @@ class Float8LinearConfig:
     # Per-linear configuration
     #

-    # If True, on the first iteration of Float8Linear the amaxes will be
-    # initialized with the incoming data. As of 2023-12-30, this doesn't work
-    # with autocast + torch.compile + FSDP. Enabling this option is nice for
-    # testing, but this is not necessary for real training jobs.
+    # This configuration option is deprecated and no longer has an effect. It may
+    # be removed in a future release.
     enable_amax_init: bool = True

-    # If True, pre-forward and post-forward functions are run. As of 2023-12-30,
-    # this doesn't work with autocast + torch.compile + FSDP. Enabling this
-    # option is useful for safety, but not strictly necessary.
+    # This configuration option is deprecated and no longer has an effect. It may
+    # be removed in a future release.
     enable_pre_and_post_forward: bool = True

     # If True, then uses a tensor subclass for the float8 linear module's weight that
```
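
The commit message notes that a deprecation warning for these two options should land in a future PR. A hypothetical sketch of what that might look like (not part of this change; the dataclass below is a stand-in, not the real `Float8LinearConfig`):

```python
import warnings
from dataclasses import dataclass

@dataclass
class Float8LinearConfigSketch:
    # both options are now no-ops and may be removed in a future release
    enable_amax_init: bool = True
    enable_pre_and_post_forward: bool = True

    def __post_init__(self):
        if not self.enable_amax_init or not self.enable_pre_and_post_forward:
            warnings.warn(
                "enable_amax_init and enable_pre_and_post_forward are deprecated "
                "no-ops and may be removed in a future release",
                FutureWarning,
            )
```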

torchao/float8/float8_linear.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -545,7 +545,6 @@ def float8_post_forward(self):
         # config setting
         if not self.enable_pre_and_post_forward:
             return
-        self.is_amax_initialized = True

     def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
         has_any_axiswise_scaling = (
```

torchao/float8/float8_linear_utils.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -193,6 +193,9 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     and we loop over all fp8_layers to sync and update amax scale histories.
     Users can use get_float8_layers to get all fp8 layers.
     """
+    # TODO(future): consider adding a flag to control setting the `is_amax_initialized`
+    # flag only on the first iteration.
+
     if fp8_layers is None:
         fp8_layers = get_float8_layers(model)

@@ -309,10 +312,10 @@ def inner_func():
             child.fp8_scale_weight.copy_(new_weight_scales[idx])
             child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx])

-    # This allows for the compile to succede on the inner func and fail on the graph breaks
+    # This allows for the compile to succeed on the inner func and fail on the graph breaks
     # at the beginning and and of syncing
    inner_func()

    for child in fp8_layers:
-        # Set a flag to signal amaxes/scales are ready
-        child.amax_and_scale_synced = True
+        # Set a flag to signal that initialization is done
+        child.is_amax_initialized = True
```
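
The comment about compiling `inner_func` reflects a general pattern: keep tensor-only buffer updates inside the compiled region and do Python attribute mutation in eager mode around it. Below is a self-contained toy sketch of that split, not the actual torchao implementation; the buffer name and class are made up for illustration.

```python
import torch

class ToyFp8Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("fp8_scale_weight", torch.ones(1))
        self.is_amax_initialized = False

def toy_sync(fp8_layers, new_weight_scales):
    @torch.compile
    def inner_func():
        # tensor-only work: copying freshly computed scales into buffers
        for layer, new_scale in zip(fp8_layers, new_weight_scales):
            layer.fp8_scale_weight.copy_(new_scale)

    inner_func()

    # stateful Python work stays in eager mode, outside the compiled region
    for layer in fp8_layers:
        layer.is_amax_initialized = True

layers = [ToyFp8Layer() for _ in range(2)]
toy_sync(layers, [torch.tensor([2.0]), torch.tensor([3.0])])
```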
