
Commit f3c1a00

float8 training: fix bug with AC + compile (#1329)
Summary:

In #1306 I accidentally broke torchtitan + float8 + AC + compile. I don't have a non-torchtitan repro yet, so I'm putting up the fix first to ensure torchtitan keeps working; we should follow up later by adding test coverage to torchao to prevent similar breakages in the future.

What broke:
* in the forward of `Float8Linear`, we were setting an attribute on the module
* this is not supported with compile combined with the way torchtitan specifically calls AC

The fix: remove the attribute setting altogether. Unfortunately this breaks an edge-case feature for ensuring scales are representable in `float16`. Since `float16` training is not commonly used with `float8` and this feature was added during very early testing, removing it for now is fine. If we need to add this feature back in the future, I'd advocate for doing it via explicit configuration such as `config.set_scale_upper_bound` and avoiding stateful hacks, which are usually not compiler friendly.

Test Plan:

```
// this repo
./test/float8/test_everything.sh

// torchtitan - broken before this PR, works after this PR
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
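For illustration only, here is a minimal sketch of the pattern at issue. The toy modules below are not torchao code, and this is not a verified repro of the torchtitan failure; it only shows the stateful-versus-stateless contrast that the fix is about.

```python
import torch
import torch.nn as nn


class StatefulToyLinear(nn.Module):
    """Sketch of the removed pattern: forward writes to a module attribute.

    Activation checkpointing re-runs forward during backward, so the write
    happens inside the recomputed region and torch.compile has to trace a
    module mutation; per the summary, this combination is what broke for
    torchtitan.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.last_seen_output_dtype = None

    def forward(self, x):
        out = self.linear(x)
        self.last_seen_output_dtype = out.dtype  # stateful write in forward
        return out


class StatelessToyLinear(nn.Module):
    """Sketch of the pattern after the fix: forward only computes, so
    recomputation under AC is side-effect free and easier to compile."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)
```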
1 parent 8f73e84 commit f3c1a00

File tree: 4 files changed, +8 -36 lines changed


torchao/float8/float8_linear.py

Lines changed: 0 additions & 5 deletions
@@ -335,10 +335,6 @@ def __init__(self, *args, **kwargs):
         # TODO(future PR): add serialization for this flag
         self.is_amax_initialized = not self.config.enable_amax_init
 
-        # This is needed to properly handle autocast in the amax/scale
-        # update function for torch.float16
-        self.last_seen_output_dtype = None
-
         # pre_forward and post_forward are currently broken with FSDP
         # and torch.compile, this option can disable them
         # Note that when using `self.config.enable_pre_and_post_forward = False`,
@@ -628,7 +624,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        self.last_seen_output_dtype = output.dtype
         return output
 
     def extra_repr(self):

torchao/float8/float8_linear_utils.py

Lines changed: 3 additions & 12 deletions
@@ -224,7 +224,6 @@ def inner_func():
         fp8_weight_amax_history_stack = [None] * len(fp8_layers)
         fp8_grad_output_amax_history_stack = [None] * len(fp8_layers)
 
-        x_dtypes = set()
         scale_fn_recipes = set()
 
         for idx, child in enumerate(fp8_layers):
@@ -236,16 +235,8 @@ def inner_func():
             fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output
 
-            x_dtypes.add(child.last_seen_output_dtype)
             scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)
 
-        # TODO This way to get the activation dtype is not ideal
-        if len(x_dtypes) != 1:
-            raise ValueError(
-                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-            )
-        x_dtype = next(iter(x_dtypes))
-
         if len(scale_fn_recipes) != 1:
             raise ValueError(
                 f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
@@ -303,13 +294,13 @@ def inner_func():
 
         # Calculate the new scales from the updated history stacks
         new_input_scales = amax_history_to_scale_stack(
-            fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_input_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_weight_scales = amax_history_to_scale_stack(
-            fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+            fp8_weight_amax_history_stack, e4m3_dtype, scale_fn_recipe
         )
         new_grad_output_scales = amax_history_to_scale_stack(
-            fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe
+            fp8_grad_output_amax_history_stack, e5m2_dtype, scale_fn_recipe
         )
 
         # Iterate through the layers and update the scales
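As a usage note for the simplified call above: after this change, `amax_history_to_scale_stack` no longer takes the activation dtype. The snippet below is a standalone sketch, assuming `torch.float8_e4m3fn` as the e4m3 dtype and using synthetic amax histories rather than real training state.

```python
import torch
from torchao.float8.float8_utils import amax_history_to_scale_stack

# Synthetic stack of amax histories: one row per Float8Linear layer,
# one column per history entry (values here are placeholders).
fp8_input_amax_history_stack = torch.rand(3, 16)

# Post-change call shape: (histories, float8 dtype, scale_fn recipe), no x_dtype.
new_input_scales = amax_history_to_scale_stack(
    fp8_input_amax_history_stack, torch.float8_e4m3fn, "max"
)
print(new_input_scales.shape)  # one scale per layer -> torch.Size([3])
```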

torchao/float8/float8_scaling_utils.py

Lines changed: 1 addition & 3 deletions
@@ -177,9 +177,7 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
         cur_amax.fill_(new_amax)
         amax_history[0] = new_amax
-        new_scale = amax_history_to_scale(
-            amax_history, float8_dtype, x.dtype, scale_fn_name
-        )
+        new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name)
         scale.copy_(new_scale)

torchao/float8/float8_utils.py

Lines changed: 4 additions & 16 deletions
@@ -34,14 +34,11 @@
 
 
 @torch.no_grad()
-def amax_to_scale(
-    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
-):
+def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
     """
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
@@ -51,51 +48,42 @@ def amax_to_scale(
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
-    # Ensure that the scale is representable in float16,
-    # this helps when amax is small. We are assuming that we don't need
-    # to care about this for float32/bfloat16.
-    if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     return res.to(torch.float32)
 
 
 @torch.no_grad()
 def amax_history_to_scale(
     amax_history: torch.Tensor,
     float8_dtype: torch.Tensor,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ):
     """Takes in a history of amax values and returns a scale tensor.
     Args:
         amax_history: A tensor containing the history of amax values.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
-        return amax_to_scale(amax, float8_dtype, orig_dtype)
+        return amax_to_scale(amax, float8_dtype)
     raise NotImplementedError()
 
 
 @torch.no_grad()
 def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
-    orig_dtype: torch.dtype,
     history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
     """Takes in a stack of amax_history tensors and returns a scale tensor.
     Args:
         amax_history: A 2D tensor containing a stack of amax histories.
         float8_dtype: The float8 dtype.
-        orig_dtype: The original dtype of the tensor.
         history_to_scale_fn_type: The type of function to use to convert the history to a scale.
     """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
-        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+        return amax_to_scale(amax_stack, float8_dtype)
     raise NotImplementedError(
         f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
     )
@@ -142,7 +130,7 @@ def tensor_to_scale(
         scaling_granularity,
         axiswise_dim,
     )
-    return amax_to_scale(amax, float8_dtype, x.dtype)
+    return amax_to_scale(amax, float8_dtype)
 
 
 def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
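If the float16 representability guard ever comes back, the summary suggests an explicit config knob rather than an `orig_dtype` check. The sketch below is a hypothetical illustration of that idea: `amax_to_scale_with_upper_bound` and `scale_upper_bound` are made-up names, and the core math only approximates torchao's `amax_to_scale` (the EPS clamp value is an assumed placeholder).

```python
from typing import Optional

import torch


@torch.no_grad()
def amax_to_scale_with_upper_bound(
    amax: torch.Tensor,
    float8_dtype: torch.dtype,
    scale_upper_bound: Optional[float] = None,
) -> torch.Tensor:
    """Hypothetical sketch: scale = float8_max / amax, with an optional,
    explicitly configured upper bound replacing the removed orig_dtype check."""
    # Upcast to float64 for eager/compile numeric parity, mirroring the comment
    # in torchao's amax_to_scale; the min clamp (EPS) here is assumed.
    res = torch.finfo(float8_dtype).max / torch.clamp(amax.to(torch.float64), min=1e-12)
    if scale_upper_bound is not None:
        # The caller opts in explicitly, e.g. with torch.finfo(torch.float16).max,
        # instead of the removed `if orig_dtype is torch.float16` branch.
        res = torch.clamp(res, max=scale_upper_bound)
    return res.to(torch.float32)


# Example: clamp the scale so it stays representable in float16.
amax = torch.tensor(1e-8)
scale = amax_to_scale_with_upper_bound(
    amax, torch.float8_e4m3fn, scale_upper_bound=torch.finfo(torch.float16).max
)
```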
