Commit b0902b2
[float8] add _auto_filter_for_recipe for float8 training (#1319)
Fixes #1207

## Problem

- float8 rowwise + vanilla TP in torchtitan had flat perf with respect to bfloat16 (see #1207).
- RCA: in #1207 we found the attention.wk and attention.wv layers were so small that float8 rowwise conversion caused a ~40% slowdown for those layers, which nullified the perf benefit of fp8 rowwise conversion on the larger linears.
- This is because the default `filter_fqns` for float8 model conversion are fine for the fp8 tensorwise recipe, but bad for the float8 rowwise recipe.

### Solution

This has been a footgun for various users as well (including Poolside), so I created an "auto filter" (pytorch/ao#2410) which automatically filters Linears for a given float8 recipe by checking the following criteria:

1. dims not divisible by 16 (hardware requirement for float8)
2. dim sizes below thresholds that may result in worse perf **for that given recipe**, using simple heuristics based on the linked recipe perf tables above
3. fqn matches one of the user-defined `filter_fqns`

It prevents users from hitting this common footgun, while also preserving the flexibility to define model-specific fqns (an illustrative sketch of these criteria is shown after this description).

## Results

Benchmarks show a ~10% TPS improvement for TP and a ~15% TPS improvement for async TP (over the bf16 TP baseline).

Llama3 70b on 256 H100s with FSDP=32, TP=8, torch.compile, full AC, local batch size 16:

- [bfloat16 baseline](https://fburl.com/mlhub/ji9smr5u) = ~597 TPS
- [fp8 rowwise WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/cu4o6w5m) = ~600 TPS
- [fp8 rowwise WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/mgzz309o) = ~660 TPS
- [fp8 rowwise + async TP WITH attention.wk, attention.wv converted](https://fburl.com/mlhub/76q4mel9) = ~625 TPS
- [fp8 rowwise + async TP WITHOUT attention.wk, attention.wv converted](https://fburl.com/mlhub/6b07aa4d) = ~695 TPS
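The following is a minimal, illustrative sketch of the three criteria above, not the actual torchao implementation: the helper name `should_convert_to_float8` and the threshold values in `_MIN_DIM_BY_RECIPE` are hypothetical, while the real heuristics live in torchao's `_auto_filter_for_recipe` (pytorch/ao#2410) and are derived from H100 microbenchmarks.

```python
import torch.nn as nn

# Hypothetical per-recipe minimum (K, N) sizes; the real thresholds come from
# torchao's H100 microbenchmarks and differ between tensorwise and rowwise.
_MIN_DIM_BY_RECIPE = {"tensorwise": 512, "rowwise": 2048}


def should_convert_to_float8(
    mod: nn.Module, fqn: str, recipe_name: str, filter_fqns: list[str]
) -> bool:
    """Illustrative filter: return True only for linears likely to benefit from float8."""
    if not isinstance(mod, nn.Linear):
        return False
    # 3. fqn matches one of the user-defined filter_fqns -> skip
    if any(f in fqn for f in filter_fqns):
        return False
    k, n = mod.in_features, mod.out_features
    # 1. dims not divisible by 16 (hardware requirement for float8 GEMMs) -> skip
    if k % 16 != 0 or n % 16 != 0:
        return False
    # 2. dims below a recipe-specific threshold -> the dynamic quantization
    #    overhead would outweigh the fp8 tensor-core speedup, so skip
    min_dim = _MIN_DIM_BY_RECIPE[recipe_name]
    return k >= min_dim and n >= min_dim
```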
1 parent 42b0beb commit b0902b2

File tree

2 files changed: +81 -33 lines changed

docs/float8.md

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
 * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
+* `--float8.filter_fqns="..."` (optional): a comma-separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
+* **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns=...` to enable automatic module filtering, which skips converting linear layers that are not large enough to benefit from float8 training, since the GEMM has to be big enough that the speedup from using FP8 tensor cores is greater than the overhead of creating dynamically quantized inputs. The conversion thresholds are based on microbenchmarks measured on NVIDIA H100 GPUs, where (K, N) represents the linear layer weight shape. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
 * `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
 
 For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files)
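For context on how such a filter is consumed, below is a hedged sketch of wiring a module filter into torchao's `convert_to_float8_training` (the API this converter uses). The toy model, the size cutoff of 1024, and the filter logic are made up for illustration and are not the auto filter's actual heuristics.

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Toy model: one large linear worth converting, one tiny linear to skip.
model = nn.Sequential(nn.Linear(8192, 8192), nn.Linear(128, 128))


def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Return True to swap this module to Float8Linear, False to leave it unconverted.
    if not isinstance(mod, nn.Linear):
        return False
    # Hypothetical size cutoff standing in for the recipe-specific heuristics.
    return min(mod.in_features, mod.out_features) >= 1024


convert_to_float8_training(model, module_filter_fn=module_filter_fn)
```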

torchtitan/components/quantization/float8.py

Lines changed: 79 additions & 33 deletions
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 from functools import partial
 
 import torch
@@ -20,6 +19,8 @@
 
 from .utils import module_filter_fn
 
+AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
+
 
 class Float8Converter(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
@@ -54,14 +55,19 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = True
         self.filter_fqns = float8_config.filter_fqns
         self.moe_fqns = float8_config.moe_fqns_prototype
+        self.filter_fn = self._init_filter_fn(float8_config)
 
         if float8_config.recipe_name is not None:
-            assert (
-                not float8_config.enable_fsdp_float8_all_gather
-            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
-            assert (
-                not float8_config.force_recompute_fp8_weight_in_bwd
-            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+            assert not float8_config.enable_fsdp_float8_all_gather, (
+                "using `float8_config.enable_fsdp_float8_all_gather` together "
+                "with `float8_config.recipe_name` is not supported"
+            )
+
+            assert not float8_config.force_recompute_fp8_weight_in_bwd, (
+                "using `float8_config.force_recompute_fp8_weight_in_bwd` together "
+                "with `float8_config.recipe_name` is not supported"
+            )
+
             self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
             self.precompute_scale = False
             logger.info(
@@ -74,7 +80,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 logger.debug(
                     "Set torch._inductor.config.emulate_precision_casts to True"
                 )
-
         else:
             # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
             enable_fsdp_float8_all_gather = (
@@ -93,6 +98,42 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             )
             logger.info("Float8 tensorwise scaled training active")
 
+    def _init_filter_fn(self, float8_config: Float8):
+        # Use the auto filter if "auto_filter_small_kn" is one of the given filter_fqns.
+        use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns
+        if use_auto_filter:
+            try:
+                from torchao.float8 import _auto_filter_for_recipe
+
+                logger.info(
+                    "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small "
+                    "to benefit from float8 training. See docs/float8.md for more info."
+                )
+
+                recipe_name = (
+                    float8_config.recipe_name
+                    if float8_config.recipe_name
+                    else "tensorwise"
+                )
+
+                # Remove the auto filter flag from filter_fqns before passing to _auto_filter_for_recipe.
+                float8_config.filter_fqns.remove(AUTO_FILTER_SMALL_KN_FLAG)
+
+                return _auto_filter_for_recipe(
+                    recipe_name,
+                    filter_fqns=float8_config.filter_fqns,
+                )
+            except ImportError:
+                logger.warning(
+                    (
+                        "Using default module_filter_fn for float8 model conversion. "
+                        "To use _auto_filter_for_recipe, please install torchao nightly build."
+                    )
+                )
+
+        # Use the default filter func.
+        return partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)
+
     def convert(self, model: nn.Module):
         """
         This function converts the linear layers of `model` to `Float8Linear`.
@@ -102,36 +143,12 @@ def convert(self, model: nn.Module):
         if not self.enabled:
             return
 
-        # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
-        # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
         # MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will
         # be converted back to nn.Linear:
         # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
         # TODO: add warning in torchao when this happens, or find a better way to avoid this.
         if self.moe_fqns:
-            from torchao.quantization.quant_api import quantize_
-
-            try:
-                from torchao.prototype.moe_training.conversion_utils import (
-                    MoETrainingConfig,
-                )
-            except ImportError as e:
-                raise ImportError(
-                    "torchao installation does not have MoE training support. Please install torchao nightly build."
-                ) from e
-
-            def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
-                for target_fqn in self.moe_fqns:
-                    if target_fqn in cur_fqn:
-                        return True
-                return False
-
-            config = MoETrainingConfig()
-            quantize_(model, config=config, filter_fn=moe_module_filter_fn)
-            logger.info(
-                f"Converted MoE layers matching FQNS {self.moe_fqns} "
-                "to use dynamic float8 rowwise quantization with scaled grouped GEMMs"
-            )
+            self._convert_moe_layers(model)
 
         from torchao.float8 import convert_to_float8_training
 
@@ -146,6 +163,35 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
             f"{self.config.enable_fsdp_float8_all_gather}"
         )
 
+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
+        to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig()
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNS {self.moe_fqns} "
+            "to use dynamic float8 rowwise quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         if not self.enabled:
             return
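Because `_init_filter_fn` above silently falls back to the default `module_filter_fn` when the installed torchao predates the auto filter, a quick sanity check like the following sketch (using the same import path the new method tries) can tell you whether your environment will actually honor the `auto_filter_small_kn` flag:

```python
# Check whether the installed torchao build exposes the auto filter used above.
try:
    from torchao.float8 import _auto_filter_for_recipe  # noqa: F401

    print("auto_filter_small_kn is supported by this torchao build")
except ImportError:
    print(
        "auto filter unavailable: install a torchao nightly build, "
        "otherwise the default module_filter_fn will be used"
    )
```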

0 commit comments