add float8 support #1378


Open · wants to merge 9 commits into autoparallel

Conversation

@bdhirsh commented Jul 10, 2025

repro command:

CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --model.converters="float8"

My integration uses an existing float8 util that simply replaces any linear layers in the user model with Float8Linear layers.
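For reference, a minimal sketch of what that conversion step looks like, assuming torchao's float8 training utilities (the exact module path and filter signature may differ across versions):

```python
import torch.nn as nn

# Assumption: torchao's float8 training entry point; name/signature may vary by version.
from torchao.float8 import convert_to_float8_training


def convert_linears_to_float8(model: nn.Module) -> nn.Module:
    # Swap eligible nn.Linear modules for Float8Linear in place. The weights
    # stay in high precision; quantization happens dynamically in the forward.
    convert_to_float8_training(
        model,
        module_filter_fn=lambda mod, fqn: isinstance(mod, nn.Linear),
    )
    return model
```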

The good news

We don't actually need special subclass handling in autoparallel yet! The model weights are stored in high precision and get quantized dynamically during the forward pass, so we have no subclass params in the state dict. After talking to Driss/Daniel, this is true for both float8 compute and float8 all-gathers.
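A quick way to sanity-check that claim (sketch only; `model` here stands for the converted model from the snippet above):

```python
import torch

# Every entry in the state dict should still be a plain high-precision tensor,
# i.e. no float8 tensor subclass that autoparallel would need to special-case.
for name, tensor in model.state_dict().items():
    assert tensor.dtype in (torch.float32, torch.bfloat16), (name, tensor.dtype)
```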

The bad news

The first thing that breaks: _scaled_mm raises an error because it requires its second argument to be column-major, but autoparallel accidentally makes all intermediate tensors contiguous during FlopCounter estimation. Fix here: https://github.com/pytorch-labs/autoparallel/pull/37
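For context, a small sketch of the layout issue (nothing autoparallel-specific): `_scaled_mm` wants its second operand laid out column-major, and an extra `.contiguous()` silently turns it back into row-major:

```python
import torch

w = torch.randn(64, 128)

# Column-major layout is what you get by transposing a contiguous tensor:
w_col_major = w.t().contiguous().t()
print(w_col_major.stride())               # (1, 64)  -> column-major, the layout _scaled_mm expects

# An accidental .contiguous() (e.g. during FLOP estimation) makes it row-major again:
print(w_col_major.contiguous().stride())  # (128, 1) -> row-major, rejected
```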

After that, there are three problems. (I ran the above repro on Francisco's deepseek branch, https://github.com/pytorch-labs/autoparallel/pull/29, with my fix above patched in.)

(1) I needed to patch this:

diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py
index f9b2de1..d4d347b 100644
--- a/autoparallel/compute_estimation.py
+++ b/autoparallel/compute_estimation.py
@@ -43,6 +43,7 @@ DEVICE_LIMITS: Tuple[DeviceLimit, ...] = (
             # but we want the full GEMM numbers
             torch.float32: 989 // 2,
             torch.float16: 1979 // 2,
+            torch.float8_e4m3fn: 3958 // 2,
             torch.bfloat16: 1979 // 2,
             torch.int8: 3958 // 2,
         },

(2) It fails with `RuntimeError: Expected both dimensions of mat2 to be divisble by 16 but got torch.Size([256, 564])`. This looks like the same issue that @wconstab has been looking into, where we end up trying to use an invalid sharding.
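To make the constraint concrete (illustrative check only): the fp8 GEMM wants both dimensions of mat2 to be multiples of 16, so a sharding whose local shard comes out like (256, 564) is invalid even if the global shape is fine:

```python
def shard_ok_for_scaled_mm(local_shape, multiple=16):
    # Both dims of mat2 must be divisible by 16 for the fp8 GEMM to run.
    return all(dim % multiple == 0 for dim in local_shape)

print(shard_ok_for_scaled_mm((256, 564)))  # False: 564 % 16 == 4
print(shard_ok_for_scaled_mm((256, 576)))  # True
```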

(3) there are some ops without sharding rules (see the sketch after the list for a guess at where these come from):

aten.clamp_max.default
aten.clamp_min.default
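My guess at where those come from (a sketch, not the exact torchao implementation): dynamic float8 scaling clamps the absmax before dividing, and the saturating cast clamps values into the representable range, which shows up in the traced graph as `aten.clamp_min` / `aten.clamp_max`:

```python
import torch

EPS = 1e-12
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def fp8_scale(x: torch.Tensor) -> torch.Tensor:
    amax = x.abs().max().float()
    # torch.clamp_min -> aten.clamp_min.default: avoid dividing by a zero amax
    return FP8_MAX / torch.clamp_min(amax, EPS)


def to_fp8_saturated(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # torch.clamp_min / torch.clamp_max: saturate before casting to float8
    y = torch.clamp_max(torch.clamp_min(x * scale, -FP8_MAX), FP8_MAX)
    return y.to(torch.float8_e4m3fn)
```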

wconstab and others added 9 commits June 16, 2025 12:32
TODO
- try converting model params into fake tensors
- figure out init fn
- integrate torchtitan configs for DP/TP to control autop
"""
[rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step:  1  loss:  8.1880  memory:  4.88GiB(6.16%)  tps: 28
[rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step:  2  loss:  8.1610  memory:  4.90GiB(6.20%)  tps: 13,785
[rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step:  3  loss:  8.0871  memory:  4.90GiB(6.20%)  tps: 14,006
[rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step:  4  loss:  7.9516  memory:  4.90GiB(6.20%)  tps: 13,770
[rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step:  5  loss:  7.8552  memory:  4.90GiB(6.20%)  tps: 13,959
[rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step:  6  loss:  7.7732  memory:  4.90GiB(6.20%)  tps: 13,859
[rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step:  7  loss:  7.6987  memory:  4.90GiB(6.20%)  tps: 13,664
[rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step:  8  loss:  7.6779  memory:  4.90GiB(6.20%)  tps: 13,985
[rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step:  9  loss:  7.6043  memory:  4.90GiB(6.20%)  tps: 13,962
[rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10  loss:  7.5778  memory:  4.90GiB(6.20%)  tps: 13,891
"""
Allows reverting a lot of the hacks in the original integration that
were caused by not creating a model object in train.py (due to passing a
model_fn builder to autop).
Basically, this is an annoying workaround for debugging iteratively:

1- you run the model, it compiles, but something weird happens
2- you enable some logging or tlparse and rerun, but inductor decides not
to run your pass anymore because its results are cached.

Since (2) has confused me horribly on more than one occasion, I just
disable caching for now.
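For anyone reproducing this, a sketch of one way to force-disable the caches while iterating (config and env-var names may differ across PyTorch versions):

```python
# Turn off inductor's caching so a custom pass actually re-runs on every compile.
import torch._inductor.config as inductor_config

inductor_config.force_disable_caches = True
# Roughly equivalent from the shell (assumption: spelling matches your PyTorch build):
#   TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ./run_train.sh ...
```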
Relying on pytorch-labs/autoparallel#20, this
lets us automatically apply a user's init_weights fn to the autoparallel
model.

Verified this works with

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step:  1  loss:  8.1848  memory:  1.09GiB(1.14%)  tps: 77  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step:  2  loss:  8.1619  memory:  1.15GiB(1.21%)  tps: 48,138  tflops: 3.46  mfu: 0.35%
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step:  3  loss:  8.1140  memory:  1.15GiB(1.21%)  tps: 88,440  tflops: 6.36  mfu: 0.64%
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step:  4  loss:  8.0099  memory:  1.15GiB(1.21%)  tps: 82,626  tflops: 5.94  mfu: 0.60%
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step:  5  loss:  7.8928  memory:  1.15GiB(1.21%)  tps: 81,594  tflops: 5.87  mfu: 0.59%
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step:  6  loss:  7.7758  memory:  1.15GiB(1.21%)  tps: 79,607  tflops: 5.72  mfu: 0.58%
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step:  7  loss:  7.6221  memory:  1.15GiB(1.21%)  tps: 81,448  tflops: 5.86  mfu: 0.59%
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step:  8  loss:  7.5578  memory:  1.15GiB(1.21%)  tps: 79,732  tflops: 5.73  mfu: 0.58%
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step:  9  loss:  7.3851  memory:  1.15GiB(1.21%)  tps: 85,655  tflops: 6.16  mfu: 0.62%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10  loss:  7.3361  memory:  1.15GiB(1.21%)  tps: 81,855  tflops: 5.89  mfu: 0.60%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```
@tianyu-l (Contributor) commented:
Is this considering FSDP at all?

Previously we tried this for SimpleFSDP, but found the loss wasn't close to the run without Float8 applied. Could it be that extra code is needed when integrating Float8 with FSDP?
Anyway, @mori360 and I are looking at SimpleFSDP + float8. We can chat more when you are back.
