Skip to content

update float8 readme with more recent performance numbers #2580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions torchao/float8/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,30 +97,32 @@ on using `torchao.float8` in a distributed setting.

A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the tables below for a microbenchmark based speedup estimate on NVIDIA H100:

### Tensorwise scaling
### tensorwise scaling

<img width="805" alt="float8_speedup" src="https://github.com/user-attachments/assets/5c5f2817-7eb7-4cab-bd03-49fe70cd31a8">
<img width="753" height="773" alt="Image" src="https://github.com/user-attachments/assets/e46c671a-ed35-41b4-b17c-50caf1629ecb" />

Example 1 (small shapes):
* forward input tensor size 1024x2048, linear weight size 2048x1024; M, K, N = 1024, 2048, 1024
* benchmark speedup is 0.80
* recommendation: leave this linear in bfloat16, the shapes are too small to benefit from float8 compute
```lang=shell
# reproduction: run the script below
python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep
```

Example 2 (large shapes):
* forward input tensor size 4096x8192, linear weight size 8192x16384; M, K, N = 4096, 8192, 16384
* benchmark speedup is 1.39
* recommendation: enable float8 for this linear to get a speedup
### rowwise scaling

To reproduce the raw data for table above, you can run the following script
<img width="755" height="778" alt="Image" src="https://github.com/user-attachments/assets/7d70ba36-f480-459f-b5c0-797895332631" />

```lang=shell
python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep
# reproduction: run the script below
python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep --float8_recipe_name rowwise
```

### Rowwise scaling
### rowwise_with_gw_hp scaling

<img width="805" alt="float8_rowwise_speedup" src="../../docs/static/fp8-rowwise-perf.png" />
<img width="750" height="797" alt="Image" src="https://github.com/user-attachments/assets/e4479abc-1aca-436d-a142-60e5e804ff10" />

```lang=shell
# reproduction: run the script below
python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep --float8_recipe_name rowwise_with_gw_hp
```

## Derivation

Expand Down
Loading