@@ -12,16 +12,12 @@ and composable with key systems such as autograd, ```torch.compile``` and distri

# Single GPU User API

- ## float8 linear with dynamic tensorwise scaling
-
- This is the default recipe, with a good balance of performance and accuracy.
-
```python
import time

import torch
import torch.nn as nn
- from torchao.float8 import convert_to_float8_training
+ from torchao.float8 import convert_to_float8_training, Float8LinearConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
@@ -47,8 +43,15 @@ def module_filter_fn(mod: torch.nn.Module, fqn: str):
            return False
    return True

+ # configure float8 recipe
+ # valid recipe names: "tensorwise", "rowwise", "rowwise_with_gw_hp"
+ config = Float8LinearConfig.from_recipe_name("tensorwise")
+
# convert specified `torch.nn.Linear` modules to `Float8Linear`
- convert_to_float8_training(m, module_filter_fn=module_filter_fn)
+ convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)
+
+ # display converted model
+ print(m)

# enable torch.compile for competitive performance
m = torch.compile(m)
@@ -75,55 +78,6 @@ end_time = time.time()
print("Training time:", end_time - start_time)
```

- ## float8 linear with rowwise scaling
-
- This is a more accurate recipe compared to tensorwise, with more granular scaling.
-
- ```python
- import torch
- import torch.nn as nn
- from torchao.float8 import convert_to_float8_training, Float8LinearConfig
- from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-
- if not TORCH_VERSION_AT_LEAST_2_5:
-     raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
-
- # create model and sample input
- m = nn.Sequential(
-     nn.Linear(2048, 4096),
-     nn.Linear(4096, 128),
- ).bfloat16().cuda()
- x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
- optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
-
- # optional: filter modules from being eligible for float8 conversion
- def module_filter_fn(mod: torch.nn.Module, fqn: str):
-     # don't convert the last module
-     if fqn == "1":
-         return False
-     # don't convert linear modules with weight dimensions not divisible by 16
-     if isinstance(mod, torch.nn.Linear):
-         if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-             return False
-     return True
-
- # configure rowwise scaling
- config = Float8LinearConfig.from_recipe_name("rowwise")
-
- # convert specified `torch.nn.Linear` modules to `Float8Linear`
- convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)
-
- # enable torch.compile for competitive performance
- m = torch.compile(m)
-
- # toy training loop
- for _ in range(10):
-     optimizer.zero_grad()
-     y = m(x)
-     y.sum().backward()
-     optimizer.step()
- ```
-
# Multi GPU User API

We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html),