[float] document e2e training -> inference flow (#2190)
* document e2e training -> inference flow
* add save/load checkpoint
* update to how we load checkpoint
* remove debugging
* add more detail
* remove unused import
* lower lr to prevent large optimizer step into weight territory which produces inf
* use actual loss function
See the float8 training benchmarking [guide](.torchao/float8/benchmarking/README.md) for more details.

# E2E training + inference flow

The E2E flow has two steps: first, train your model and save a checkpoint; second, load the checkpoint and optionally apply inference quantization before serving the model.

#### 1. Train model and save checkpoint

```python
import torch
from torch import nn
import torch.nn.functional as F

from torchao.float8.float8_linear import Float8Linear
from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
    nn.Linear(128, 1),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)

# ... (lines omitted from this diff excerpt; `optimizer` is defined there) ...

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    output = m(x)
    # use fake labels for demonstration purposes
    fake_labels = torch.ones_like(output)
    loss = F.mse_loss(output, fake_labels)
    loss.backward()
    optimizer.step()

# save the model
torch.save({
    'model': m,
    'model_state_dict': m.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```
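
The training example uses `optimizer` and imports `convert_to_float8_training`, but the lines that define the optimizer and convert the model are not visible in this excerpt. The following is a minimal sketch of what that setup step might look like, not the code from the commit. The helper names `module_filter_fn` and `setup_float8_training`, the choice of AdamW, and the `lr=1e-4` default are all assumptions made for illustration. The filter also assumes that float8 matmuls need both weight dimensions to be divisible by 16.

```python
def module_filter_fn(mod, fqn: str) -> bool:
    """Decide whether the linear layer `mod` (named `fqn`) is converted to float8."""
    # Assumption: float8 matmuls need both weight dimensions divisible by 16,
    # so a layer such as nn.Linear(128, 1) stays in its original precision.
    return mod.in_features % 16 == 0 and mod.out_features % 16 == 0


def setup_float8_training(model, lr: float = 1e-4):
    """Swap eligible nn.Linear layers in place, then build an optimizer."""
    # Imported lazily so the filter above can be tried without torch/torchao installed.
    import torch
    from torchao.float8 import convert_to_float8_training

    convert_to_float8_training(model, module_filter_fn=module_filter_fn)
    # Hypothetical choice: AdamW with a placeholder learning rate.
    return torch.optim.AdamW(model.parameters(), lr=lr)
```

Under these assumptions, calling `optimizer = setup_float8_training(m)` before `m = torch.compile(m)` would supply the `optimizer` that the training loop expects.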

#### 2. Load checkpoint and optionally apply inference quantization

There are 3 float8 inference quantization strategies that can be used after training with float8: 1) weight-only quantization, 2) dynamic activation and weight quantization, and 3) static quantization.
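
As a rough illustration of how a strategy might be selected in code, here is a small sketch, not the README's own example. The helper `apply_float8_inference_quantization` and the strategy strings are hypothetical. The config class names `Float8WeightOnlyConfig` and `Float8DynamicActivationFloat8WeightConfig` are assumed to be exported by the installed torchao release, and older releases may expose different names. Static quantization is left out because it needs an activation scale obtained from calibration.

```python
SUPPORTED_STRATEGIES = ("weight_only", "dynamic")


def apply_float8_inference_quantization(model, strategy: str = "dynamic"):
    """Quantize `model` in place for inference using the named float8 strategy."""
    if strategy not in SUPPORTED_STRATEGIES:
        # Static quantization needs a calibrated activation scale, so it is not covered here.
        raise ValueError(
            f"unsupported strategy {strategy!r}; expected one of {SUPPORTED_STRATEGIES}"
        )

    # Imported lazily; the config class names are an assumption about the torchao version.
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        Float8WeightOnlyConfig,
    )
    from torchao.quantization.granularity import PerTensor
    from torchao.quantization.quant_api import quantize_

    if strategy == "weight_only":
        config = Float8WeightOnlyConfig()
    else:
        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
    quantize_(model, config)
    return model
```

With the checkpoint from step 1 loaded via `torch.load('checkpoint.pth', weights_only=False)`, the `'model'` entry could then be passed to this helper before running inference.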

Below is an example of dynamic activation and weight quantization. For more details, examples, and inference benchmarks, see the [torchao inference docs](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md).

```python
import torch

from torchao.float8.float8_linear import Float8Linear
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_api import quantize_

# ... (remainder of the example is not shown in this diff excerpt) ...
```