Commit 9eeb101

[float8] add _auto_filter_for_recipe to float8 (#2410)
* add auto_filter_for_recipe to float8
* lint
* address comments
* add tests
1 parent b963540 commit 9eeb101

2 files changed: +90 -3 lines changed


torchao/float8/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -6,7 +6,10 @@
     ScalingGranularity,
     ScalingType,
 )
-from torchao.float8.float8_linear_utils import convert_to_float8_training
+from torchao.float8.float8_linear_utils import (
+    _auto_filter_for_recipe,
+    convert_to_float8_training,
+)
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -44,6 +47,7 @@
     # top level UX
     "convert_to_float8_training",
     "precompute_float8_dynamic_scale_for_fsdp",
+    "_auto_filter_for_recipe",
     # types
     "FP8Granularity",
     # note: Float8Tensor and Float8Linear are not public APIs
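
With this change the helper is re-exported from the package root (it is added to __all__ above), so it can be imported directly from torchao.float8. A minimal sketch of the new import path, assuming the recipe name is accepted as a lowercase string and using an illustrative filter FQN:

    from torchao.float8 import _auto_filter_for_recipe

    # Returns a Callable[[nn.Module, str], bool] suitable for use as a
    # module_filter_fn; "norm" is just an example FQN substring to always skip.
    filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["norm"])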

torchao/float8/float8_linear_utils.py

Lines changed: 85 additions & 2 deletions
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
-from typing import Callable, Optional
+from functools import partial
+from typing import Callable, List, Optional, Union

 import torch.nn as nn

-from torchao.float8.config import Float8LinearConfig
+from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_linear import Float8Linear

 log = logging.getLogger(__name__)
@@ -113,3 +114,85 @@ def convert_to_float8_training(
         from_float,
         module_filter_fn=module_filter_fn,
     )
+
+
+def _auto_filter_for_recipe(
+    recipe: Union[str, Float8LinearRecipeName], filter_fqns: List[str]
+) -> Callable[[nn.Module, str], bool]:
+    """Returns function which automatically filters nn.Linear modules that meet at least one of the following criteria:
+
+    1. Dims not divisible by 16 (hardware requirement for float8).
+    2. Dim sizes below certain thresholds, which may result in worse performance.
+
+    NOTE: the thresholds are simple heuristics based on performance testing, and may not be optimal
+    for your model. For the best performance, we recommend defining your own module_filter_fn customized for
+    your module, using the performance tables for the given float8 recipe here:
+    https://github.com/pytorch/ao/tree/main/torchao/float8#performance). These benchmarks referenced for
+    auto filtering layers were run on H100 GPUs, and may not be representative of other hardware.
+
+    This is an experimental API, the design may change in the future.
+    """
+    if isinstance(recipe, str):
+        recipe = Float8LinearRecipeName(recipe)
+    if recipe == Float8LinearRecipeName.TENSORWISE:
+        return partial(_auto_filter_for_tensorwise, filter_fqns=filter_fqns)
+    elif recipe == Float8LinearRecipeName.ROWWISE:
+        return partial(_auto_filter_for_rowwise, filter_fqns=filter_fqns)
+    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
+        raise NotImplementedError(f"Unsupported recipe: {recipe}")
+    else:
+        raise ValueError(f"Invalid recipe: {recipe}")
+
+
+def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -> bool:
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+    if is_filtered_fqn:
+        return False
+
+    # All dims must be divisible by 16 due to float8 hardware requirements.
+    N, K = mod.weight.shape
+    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
+    if not dims_multiples_of_16:
+        return False
+
+    # Dims below these thresholds may result in worse performance
+    # (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling)
+    # Note that these benchmarks referenced for auto filtering layers were run on
+    # H100 GPUs, and may not be representative of other hardware.
+    if N <= 2048:
+        return False
+    elif K <= 1024:
+        return False
+    elif N <= 4096 and K <= 2048:
+        return False
+    return True
+
+
+def _auto_filter_for_tensorwise(
+    mod: nn.Module, fqn: str, filter_fqns: List[str]
+) -> bool:
+    if not isinstance(mod, nn.Linear):
+        return False
+
+    # If the fqn matches any filtered fqn, then we should not convert this module.
+    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
+    if is_filtered_fqn:
+        return False
+
+    # All dims must be divisible by 16 due to float8 hardware requirements.
+    N, K = mod.weight.shape
+    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
+    if not dims_multiples_of_16:
+        return False
+
+    # Dims below these thresholds may result in worse performance
+    # (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling)
+    # Note that these benchmarks referenced for auto filtering layers were run on
+    # H100 GPUs, and may not be representative of other hardware.
+    if K <= 4096 and N <= 1024:
+        return False
+    return True
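
For context, a minimal end-to-end usage sketch of the new helper follows. The Float8LinearConfig.from_recipe_name constructor and the config= / module_filter_fn= keyword arguments of convert_to_float8_training are assumed from torchao's existing float8 API and are not part of this diff; the toy model and the "lm_head" filter FQN are purely illustrative.

    import torch.nn as nn

    from torchao.float8 import (
        Float8LinearConfig,
        _auto_filter_for_recipe,
        convert_to_float8_training,
    )

    # Toy model; shapes chosen only to illustrate the rowwise heuristics.
    model = nn.Sequential(
        nn.Linear(4096, 8192),  # large dims, multiples of 16 -> kept
        nn.Linear(4096, 1000),  # out_features not divisible by 16 -> filtered out
    )

    # Build an auto-filter for the rowwise recipe; modules whose FQN contains
    # "lm_head" (illustrative) are always skipped.
    filter_fn = _auto_filter_for_recipe("rowwise", filter_fqns=["lm_head"])

    # The returned callable takes (module, fqn) and returns True if the module
    # should be converted to float8 training.
    assert filter_fn(model[0], "0")
    assert not filter_fn(model[1], "1")

    # Assumed integration point: pass the filter when converting the model.
    config = Float8LinearConfig.from_recipe_name("rowwise")
    convert_to_float8_training(model, config=config, module_filter_fn=filter_fn)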
