[float8] add _auto_filter_for_recipe for float8 training (#1319)
Fixes #1207
## Problem
- float8 rowwise + vanilla TP in torchtitan had flat perf with respect
to bfloat16 (see #1207).
- Root cause analysis in #1207 found that the attention.wk and attention.wv
layers were so small that float8 rowwise conversion resulted in a ~40%
slowdown for those layers, which nullified the perf benefits of fp8 rowwise
conversion on the larger linears.
- This is because the default `filter_fqns` for float8 model conversion
are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise
recipe.
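A quick way to spot such layers is to list every Linear weight shape in the model. The snippet below is purely illustrative (the `min_dim` cutoff is an arbitrary placeholder, not a torchao or torchtitan value):

```python
import torch.nn as nn

def report_linear_shapes(model: nn.Module, min_dim: int = 2048) -> None:
    """Print the (K, N) weight shape of every Linear layer, flagging small ones."""
    for fqn, mod in model.named_modules():
        if isinstance(mod, nn.Linear):
            # nn.Linear.weight has shape (out_features, in_features) == (N, K)
            n, k = mod.weight.shape
            tag = "small (may be slower in float8 rowwise)" if min(n, k) < min_dim else "ok"
            print(f"{fqn}: K={k}, N={n} -> {tag}")
```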
## Solution
This has been a footgun for various users as well (including Poolside),
so I created an "auto filter" (pytorch/ao#2410)
which automatically excludes a Linear from float8 conversion for a given
float8 recipe if any of the following criteria are met:
1. dims not divisible by 16 (hardware requirement for float8)
2. dim sizes below thresholds that may result in worse perf **for that
given recipe**, using simple heuristics based on the linked recipe perf
tables above.
3. the fqn matches one of the user-defined `filter_fqns`
It prevents users from hitting this common footgun, while also
preserving the flexibility to define their model-specific fqns.
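For illustration, here is a minimal sketch of these checks as a `module_filter_fn` (the kind of callable torchao's `convert_to_float8_training` accepts). The size threshold below is a placeholder, not the tuned per-recipe values in torchao:

```python
import torch.nn as nn

def sketch_filter_fn(mod: nn.Module, fqn: str, filter_fqns: tuple[str, ...] = ()) -> bool:
    """Return True only if this module should be converted to float8."""
    if not isinstance(mod, nn.Linear):
        return False
    # 3. user-defined fqn filters always exclude the layer
    if any(f in fqn for f in filter_fqns):
        return False
    n, k = mod.weight.shape  # (out_features, in_features) == (N, K)
    # 1. float8 kernels require dims divisible by 16
    if n % 16 != 0 or k % 16 != 0:
        return False
    # 2. recipe-specific size heuristic (placeholder cutoff, not torchao's)
    if min(n, k) < 2048:
        return False
    return True
```

The actual `_auto_filter_for_recipe` derives the dimension thresholds from recipe-specific microbenchmarks rather than a single cutoff, so rowwise and tensorwise get different filtering behavior.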
## Results
Benchmarks show a ~10% TPS improvement for TP and ~15% TPS improvement
for async TP (over bf16 TP baseline).
Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC,
local batch size 16:
- [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597 TPS
- [fp8 rowwise WITH attention.wk, attention.wv
converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS
- [fp8 rowwise WITHOUT attention.wk, attention.wv
converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS
- [fp8 rowwise + async TP WITH attention.wk, attention.wv
converted](https://fburl.com/mlhub/76q4mel9) = ~625 TPS
- [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv
converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
*`--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
*`--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.
*`--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
*`--float8.filter_fqns="..."` (optional): a comma-separated list of fully qualified names (FQNs) of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
***Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering (see the sketch after this list), which skips converting linear layers that are not large enough to benefit from float8 training: the GEMM has to be big enough that the speedup from using FP8 tensorcores outweighs the overhead of creating dynamically quantized inputs. The conversion thresholds are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K, N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
*`--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
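To show how the auto filter fits into model conversion, here is a hedged sketch of torchtitan-style conversion code built on torchao's APIs; the exact import path and signature of `_auto_filter_for_recipe` (from pytorch/ao#2410) are assumptions and may differ:

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Assumed import path/signature for the auto filter added in pytorch/ao#2410.
from torchao.float8 import _auto_filter_for_recipe

def convert_model_for_rowwise(model: nn.Module, user_filter_fqns: list[str]) -> nn.Module:
    config = Float8LinearConfig.from_recipe_name("rowwise")
    # Builds a module_filter_fn that applies the (K, N) heuristics for the
    # "rowwise" recipe on top of the user-provided filter_fqns.
    filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=user_filter_fqns)
    convert_to_float8_training(model, config=config, module_filter_fn=filter_fn)
    return model
```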
For float8 with rowwise scaling, launch the training job with the following command (or alternatively set the configs in toml files).