
Commit fdddb2c

[BE/docs] Add fp8 rowwise perf table to float8 training readme (#2312)
* add fp8 rowwise perf table
* add fp8 rowwise perf table to readme
1 parent d6cfdad commit fdddb2c

File tree

2 files changed: +8 -1 lines changed


docs/static/fp8-rowwise-perf.png (336 KB)

torchao/float8/README.md

Lines changed: 8 additions & 1 deletion
````diff
@@ -132,7 +132,9 @@ on using `torchao.float8` in a distributed setting.
 
 # Performance
 
-A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark based speedup estimate on NVIDIA H100:
+A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the tables below for a microbenchmark based speedup estimate on NVIDIA H100:
+
+### Tensorwise scaling
 
 <img width="805" alt="float8_speedup" src="https://github.com/user-attachments/assets/5c5f2817-7eb7-4cab-bd03-49fe70cd31a8">
 
````
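For context on what the tensorwise table measures: it corresponds to the default conversion path in `torchao.float8`. Below is a minimal sketch of enabling it, assuming the `convert_to_float8_training` entry point from `torchao.float8`; exact names and requirements may differ across torchao versions.

```python
# Minimal sketch (not part of this commit): enable float8 training with the
# default (tensorwise) scaling. Assumes an H100-class GPU, a recent torchao,
# and linear dims that are multiples of 16; API names may differ by version.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 8192),
    nn.ReLU(),
    nn.Linear(8192, 4096),
).to(torch.bfloat16).to("cuda")

# Swaps eligible nn.Linear modules for float8 linears in place; whether this
# beats bf16 depends on the M, K, N of each linear, per the table above.
convert_to_float8_training(model)

# torch.compile is generally needed to fuse the casting overhead and
# approach the microbenchmark speedups.
model = torch.compile(model)

x = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda")
model(x).sum().backward()
```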

````diff
@@ -152,6 +154,11 @@ To reproduce the raw data for table above, you can run the following script
 python benchmarks/float8/float8_roofline.py your_output_filename.csv --shape_gen_name sweep
 ```
 
+### Rowwise scaling
+
+<img width="805" alt="float8_rowwise_speedup" src="../../docs/static/fp8-rowwise-perf.png" />
+
+
 ## Derivation
 
 In a bf16 linear, assume all of the time is spent in gemms. In a float8 linear, account for max_abs and casting overhead. We want to know when
````
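The rowwise table added here uses a finer scaling granularity than the tensorwise default, so the break-even shapes can differ between the two tables. A hedged sketch of opting into rowwise scaling, assuming the installed torchao exposes `Float8LinearConfig.from_recipe_name` with a `"rowwise"` recipe (if not, the equivalent config lives in `torchao/float8/config.py`):

```python
# Sketch under the assumption that Float8LinearConfig.from_recipe_name("rowwise")
# exists in the installed torchao version; otherwise construct the rowwise
# config manually via torchao/float8/config.py.
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

config = Float8LinearConfig.from_recipe_name("rowwise")

model = nn.Sequential(
    nn.Linear(4096, 8192),
    nn.Linear(8192, 4096),
).to(torch.bfloat16).to("cuda")

# Same entry point as tensorwise, but with the rowwise config; the
# fp8-rowwise-perf.png table added in this commit estimates when this
# beats bf16 on H100 for a given M, K, N.
convert_to_float8_training(model, config=config)
```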
