
Commit a0a0969

[float] document e2e training -> inference flow (#2190)
* document e2e training -> inference flow
* add save/load checkpoint
* update to how we load checkpoint
* remove debugging
* add more detail
* remove unused import
* lower lr to prevent large optimizer step into weight territory which produces inf
* use actual loss function
1 parent c2d2d13 commit a0a0969

File tree

1 file changed: +92 -0 lines changed

torchao/float8/README.md

Lines changed: 92 additions & 0 deletions
@@ -230,3 +230,95 @@ including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=re
- float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./float8_training_benchmark.sh`

See the float8 training benchmarking [guide](benchmarking/README.md) for more details.

# E2E training + inference flow

The first step in the E2E flow is to train your model and save a checkpoint. The second step is to load the checkpoint and optionally apply inference quantization before serving the model.

#### 1. Train model and save checkpoint
```python
import torch
from torch import nn
import torch.nn.functional as F

from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
    nn.Linear(128, 1),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the second linear layer (fqn "1")
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    output = m(x)
    # use fake labels for demonstration purposes
    fake_labels = torch.ones_like(output)
    loss = F.mse_loss(output, fake_labels)
    loss.backward()
    optimizer.step()

# save the model
torch.save({
    'model': m,
    'model_state_dict': m.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```
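
As a quick sanity check (not part of the original example), you can list which layers were actually swapped to `Float8Linear`; the import path below is the same one torchao uses for the module. Running this before `torch.compile` gives the cleanest module names:

```python
from torchao.float8.float8_linear import Float8Linear

# list layers converted by convert_to_float8_training; with the filter above,
# only layer "0" should appear: fqn "1" is excluded by name, and layer "2"
# has out_features=1, which is not divisible by 16
for name, mod in m.named_modules():
    if isinstance(mod, Float8Linear):
        print(f"converted: {name}")
```
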
#### 2. Load checkpoint and optionally apply inference quantization

There are three float8 inference quantization strategies that can be used after training with float8: 1) weight-only quantization, 2) dynamic activation and weight quantization, and 3) static quantization.

Below is an example of dynamic activation and weight quantization. For more details, examples, and inference benchmarks, see the [torchao inference docs](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md).

```python
import torch

from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_api import quantize_
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
)

# load checkpoint
checkpoint = torch.load('checkpoint.pth', weights_only=False)
model = checkpoint['model']
model.load_state_dict(checkpoint['model_state_dict'])

# optional: apply dynamic float8 quantization on both activations and weights for inference
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))

# run inference
x = torch.randn(1, 4096, 2048, device="cuda", dtype=torch.bfloat16)
with torch.inference_mode():
    out = model(x)
    print(out)
```
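
The weight-only strategy follows the same pattern. A minimal sketch, assuming your torchao version exposes `Float8WeightOnlyConfig` alongside the dynamic config in `torchao.quantization`:

```python
import torch

from torchao.quantization.quant_api import quantize_
from torchao.quantization import Float8WeightOnlyConfig

# load the trained model as before
checkpoint = torch.load('checkpoint.pth', weights_only=False)
model = checkpoint['model']
model.load_state_dict(checkpoint['model_state_dict'])

# quantize only the weights to float8; activations remain in bf16
quantize_(model, Float8WeightOnlyConfig())

# run inference
x = torch.randn(1, 4096, 2048, device="cuda", dtype=torch.bfloat16)
with torch.inference_mode():
    out = model(x)
```

Static quantization additionally requires calibrated activation scales; see the torchao inference docs linked above for that flow.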
