float8 training: fix bug with AC + compile (#1329)
Summary:
In #1306 I accidentally broke
torchtitan + float8 + AC + compile.
I don't have a non-torchtitan repro yet, so I'm putting up the fix first
to ensure torchtitan still works; we should follow up later
by adding test coverage to torchao to prevent similar breakages in the
future.
What broke:
* in the forward of `Float8Linear`, we were setting an attribute on
the module
* ^ is not supported by compile combined with the way torchtitan
specifically calls AC
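
To illustrate the anti-pattern (a minimal Python sketch, not the actual `Float8Linear` code; the class and attribute names below are hypothetical):

```python
class StatefulScaler:
    """Sketch of the compile-unfriendly pattern: module state is written
    during forward, so recompilation or AC recomputation may observe
    inconsistent state."""

    def __init__(self):
        self.scale_upper_bound = None  # mutated in forward (anti-pattern)

    def forward(self, scale, dtype="float32"):
        # Setting an attribute on self during forward is the hack that
        # broke under compile + AC in torchtitan.
        self.scale_upper_bound = 65504.0 if dtype == "float16" else None
        if self.scale_upper_bound is not None:
            scale = min(scale, self.scale_upper_bound)
        return scale


class StatelessScaler:
    """Compile-friendly alternative: the bound is fixed at init, and
    forward is a pure function of its inputs and static configuration."""

    def __init__(self, scale_upper_bound=None):
        self.scale_upper_bound = scale_upper_bound  # never mutated later

    def forward(self, scale):
        if self.scale_upper_bound is not None:
            scale = min(scale, self.scale_upper_bound)
        return scale
```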
The fix: remove this attribute setting altogether. Unfortunately this
breaks an edge case feature for ensuring scales are representable in
`float16`. Since `float16` training is not commonly used with `float8`
and this feature was added during very early testing, removing it for
now is fine.
If we need to add this feature back in the future, I'd advocate for
doing it via explicit configuration such as `config.set_scale_upper_bound`
and avoiding the stateful hacks, which are usually not compiler
friendly.
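
A minimal sketch of what that explicit configuration could look like (`Float8LinearConfig`, `scale_upper_bound`, and `clamp_scale` are illustrative names, not existing torchao APIs):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Float8LinearConfig:
    # When set (e.g. to torch.finfo(torch.float16).max == 65504.0),
    # scales are clamped so they stay representable in float16.
    # None disables the clamp (the common float32/bfloat16 case).
    scale_upper_bound: Optional[float] = None

def clamp_scale(scale: float, config: Float8LinearConfig) -> float:
    # Pure function of its inputs plus static config: no state is
    # written during forward, so there is nothing for compile + AC
    # to observe inconsistently.
    if config.scale_upper_bound is None:
        return scale
    return min(scale, config.scale_upper_bound)
```

Because the config is frozen and fixed at construction time, the clamp becomes a static branch rather than a stateful side effect in forward.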
Test Plan:
```
# this repo
./test/float8/test_everything.sh
# torchtitan - broken before this PR, works after this PR
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
```