Commit 353dd44

float8 readme: remove duplication (#2447)
We had two near-duplicate example training loops in the float8 README; this removes the duplication and makes a single example work for all recipes.
1 parent 420f782 commit 353dd44


1 file changed: +9 −55 lines


torchao/float8/README.md

Lines changed: 9 additions & 55 deletions
@@ -12,16 +12,12 @@ and composable with key systems such as autograd, ```torch.compile``` and distri
 
 # Single GPU User API
 
-## float8 linear with dynamic tensorwise scaling
-
-This is the default recipe, with a good balance of performance and accuracy.
-
 ```python
 import time
 
 import torch
 import torch.nn as nn
-from torchao.float8 import convert_to_float8_training
+from torchao.float8 import convert_to_float8_training, Float8LinearConfig
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -47,8 +43,15 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):
             return False
     return True
 
+# configure float8 recipe
+# valid recipe names: "tensorwise", "rowwise", "rowwise_with_gw_hp"
+config = Float8LinearConfig.from_recipe_name("tensorwise")
+
 # convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
+convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)
+
+# display converted model
+print(m)
 
 # enable torch.compile for competitive performance
 m = torch.compile(m)
@@ -75,55 +78,6 @@ end_time = time.time()
 print("Training time:", end_time - start_time)
 ```
 
-## float8 linear with rowwise scaling
-
-This is a more accurate recipe compared to tensorwise, with more granular scaling.
-
-```python
-import torch
-import torch.nn as nn
-from torchao.float8 import convert_to_float8_training, Float8LinearConfig
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-
-if not TORCH_VERSION_AT_LEAST_2_5:
-    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
-
-# create model and sample input
-m = nn.Sequential(
-    nn.Linear(2048, 4096),
-    nn.Linear(4096, 128),
-).bfloat16().cuda()
-x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
-optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
-
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# configure rowwise scaling
-config = Float8LinearConfig.from_recipe_name("rowwise")
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
-m = torch.compile(m)
-
-# toy training loop
-for _ in range(10):
-    optimizer.zero_grad()
-    y = m(x)
-    y.sum().backward()
-    optimizer.step()
-```
-
 # Multi GPU User API
 
 We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html),
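For reference, here is a minimal sketch (not part of this commit, and assuming the same toy model and recipe names shown in the diff above) of how the consolidated example can be pointed at a different recipe by changing only the recipe name:

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training, Float8LinearConfig

# same toy model as the README example (requires a CUDA GPU)
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).bfloat16().cuda()

# pick any recipe listed in the README:
# "tensorwise", "rowwise", or "rowwise_with_gw_hp"
config = Float8LinearConfig.from_recipe_name("rowwise")

# convert eligible `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, config=config)

# compile for competitive performance, then train as in the README's toy loop
m = torch.compile(m)
```

This is the consolidation the commit message describes: one training loop, with the recipe selected via `Float8LinearConfig.from_recipe_name` instead of a second copy of the example.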
