Commit 584e30b

Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference
Summary:

Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in torchao inference. In detail:

1. bring the 128x128 gemm triton kernel we have out of prototype and wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test which tests for e2e numerical correctness via SQNR comparison vs a high precision baseline

For now I added a fallback which only requires triton and is numerically correct but may not reach optimal performance. Performance optimization is left for future PRs:

1. we should map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. we should enable an fbgemm_gpu_genai path, if available in the user env
3. we should map to a triton kernel for quantizing the weights, as `torch.compile` is currently known to be slow for 128x128 block quantization

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ae35a4f
ghstack-comment-id: 3460951962
Pull-Request: #3257
1 parent 6c70630 commit 584e30b
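
Note (editor): as a quick orientation, here is a minimal usage sketch of the new recipe, mirroring the configuration exercised in the new test below. The toy model, shapes, and device setup are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# toy model; feature dims are multiples of 128 so the 128x128 weight blocks tile evenly
# (requires a CUDA device with compute capability >= 8.9 per this commit)
model = nn.Sequential(nn.Linear(256, 512)).cuda().to(torch.bfloat16)

# DeepSeek-style scaling: 1x128 blocks for activations, 128x128 blocks for weights
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
quantize_(model, config)

# model = torch.compile(model)  # the new gemm is wrapped in a custom op for compile compatibility
out = model(torch.randn(32, 256, device="cuda", dtype=torch.bfloat16))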

File tree

7 files changed: +459 -307 lines changed


test/dtypes/test_affine_quantized_float.py

Lines changed: 2 additions & 2 deletions
@@ -152,7 +152,7 @@ def test_invalid_granularity(self):
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
-            match="Different granularities for activation and weight are not supported",
+            match="Unsupported granularity types",
         ):
             Float8DynamicActivationFloat8WeightConfig(
                 granularity=(PerTensor(), PerRow())
@@ -165,7 +165,7 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
 
-        with pytest.raises(ValueError, match="Invalid granularity types"):
+        with pytest.raises(ValueError, match="Unsupported granularity types"):
             Float8DynamicActivationFloat8WeightConfig(
                 granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
             )
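
Note (editor): for context, the validation behavior these updated tests expect can be sketched as follows. This is illustrative only; it simply calls the public config constructor with a mismatched granularity pair and relies on the new error message wired up in torchao/float8/inference.py below.

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    PerTensor,
)

# Mixing PerTensor and PerRow is still rejected at config construction time,
# but the error text is now the generic "Unsupported granularity types: ...".
try:
    Float8DynamicActivationFloat8WeightConfig(granularity=(PerTensor(), PerRow()))
except ValueError as e:
    print(e)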

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 35 additions & 9 deletions
@@ -18,6 +18,7 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     quantize_,
@@ -64,7 +65,10 @@ def setUp(self):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "granularity",
+        [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
+    )
     @common_utils.parametrize(
         "kernel_preference",
         [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
@@ -74,7 +78,7 @@ def setUp(self):
         "sizes",
         [
             ((128,), 256, 128),
-            ((32, 128), 64, 256),
+            ((32, 128), 256, 512),
         ],
     )
     def test_fp8_linear_variants(
@@ -86,13 +90,21 @@ def test_fp8_linear_variants(
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
-        if (
-            isinstance(granularity, PerTensor)
-            and kernel_preference == KernelPreference.FBGEMM
-        ):
-            return unittest.skip(
-                "per tensor with fbgemm kernel preferece does not work yet"
-            )
+        if isinstance(granularity, PerTensor):
+            if kernel_preference is KernelPreference.FBGEMM:
+                return unittest.skip(
+                    "per tensor with fbgemm kernel preference does not work yet"
+                )
+            elif mode == "weight-only":
+                return unittest.skip("unimplemented")
+
+        elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
+            if dtype is torch.float32:
+                return unittest.skip("unimplemented")
+            elif mode == "weight-only":
+                return unittest.skip("unimplemented")
+            elif kernel_preference is KernelPreference.FBGEMM:
+                return unittest.skip("unimplemented")
 
         error_message = None
         if isinstance(granularity, PerRow):
@@ -137,6 +149,20 @@ def test_fp8_linear_variants(
 
         quantize_(quantized_model, config)
 
+        # ensure weight scaling is what we expect
+        qs1 = quantized_model.linear1.weight.scale
+        qs2 = quantized_model.linear2.weight.scale
+        if granularity == PerTensor():
+            assert qs1.shape == (1, 1)
+            assert qs2.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs1.shape == (N, 1)
+            assert qs2.shape == (K, 1)
+        else:
+            assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
+            assert qs1.shape == (N // 128, K // 128)
+            assert qs2.shape == (K // 128, N // 128)
+
         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
 
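
Note (editor): the new scale-shape assertions follow directly from the block sizes: there is one scale element per block along each weight dimension. A standalone sketch of that arithmetic (not part of the commit; it assumes dimensions that are exact multiples of the block size, as in the tested shapes):

import math

def expected_scale_shape(weight_shape, block_size):
    # one scale value per block along each dimension
    return tuple(math.ceil(d / b) for d, b in zip(weight_shape, block_size))

# linear1 weight is (N, K) = (256, 512) with the updated test sizes
print(expected_scale_shape((256, 512), (128, 128)))  # (2, 4) == (N // 128, K // 128)
print(expected_scale_shape((256, 512), (1, 512)))    # (256, 1), i.e. per-row
print(expected_scale_shape((256, 512), (256, 512)))  # (1, 1), i.e. per-tensor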

torchao/float8/inference.py

Lines changed: 59 additions & 16 deletions
@@ -7,13 +7,15 @@
 Defines an nn module designed to be used during inference
 """
 
+import math
 from typing import List, NamedTuple, Optional, Tuple, Union
 
 import torch
 
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerBlock,
     PerRow,
     PerTensor,
 )
@@ -196,6 +198,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
     )
 
 
+def _is_1_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 1x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128
+
+
+def _is_128_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 128x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) == 2 and b[0] == 128 and b[1] == 128
+
+
+def _granularity_is_a_1_128_w_128_128(
+    g: Union[
+        FP8Granularity,
+        Tuple[FP8Granularity, FP8Granularity],
+        list[FP8Granularity],
+    ],
+) -> bool:
+    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
+
+
 def _normalize_granularity(
     granularity: Optional[
         Union[
@@ -211,22 +243,23 @@ def _normalize_granularity(
     elif isinstance(granularity, (PerTensor, PerRow)):
         processed_granularity = (granularity, granularity)
     elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
-        if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
-        ):
-            raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
-            )
+        is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
+            granularity[1], PerTensor
+        )
+        is_per_row = isinstance(granularity[0], PerRow) and isinstance(
+            granularity[1], PerRow
+        )
+        is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
+
+        if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
+            raise ValueError(f"Unsupported granularity types: {granularity}.")
         if not isinstance(granularity[0], type(granularity[1])):
             raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+                f"Different granularities for activation and weight are not supported: {granularity}."
            )
         processed_granularity = tuple(granularity)
     else:
-        raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
-        )
+        raise ValueError(f"Invalid granularity specification: {granularity}.")
     return processed_granularity
 
 
@@ -243,12 +276,22 @@ def _check_hardware_support(
         AssertionError: If hardware doesn't support the requested granularity
         ValueError: If invalid granularity type is provided
     """
-    for _granularity in granularities:
-        if not isinstance(_granularity, (PerTensor, PerRow)):
-            raise ValueError(
-                f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
-            )
+    is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
+        granularities[1], PerTensor
+    )
+    is_per_row = isinstance(granularities[0], PerRow) and isinstance(
+        granularities[1], PerRow
+    )
+    is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
 
+    if is_per_tensor or is_per_row:
         assert is_sm_at_least_89() or is_MI300(), (
             "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
         )
+    elif is_a_1_128_w_128_128:
+        # TODO(future PR): look into AMD support
+        assert is_sm_at_least_89(), (
+            "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
+        )
+    else:
+        raise ValueError(f"Invalid granularities {granularities}.")
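
Note (editor): to make the new helpers concrete, here is a small sketch of how the block-size checks behave on representative shapes. It operates on plain tuples, whereas the real helpers read `x.block_size` off a quantized tensor; the logic mirrors what this commit adds.

import math

def is_1_128_scaled(block_size):
    # same predicate as _is_1_128_scaled, applied to a plain tuple
    b = block_size
    return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128

def is_128_128_scaled(block_size):
    # same predicate as _is_128_128_scaled, applied to a plain tuple
    b = block_size
    return len(b) == 2 and b[0] == 128 and b[1] == 128

print(is_1_128_scaled((1, 128)))      # True: 2D activation, 1x128 blocks
print(is_1_128_scaled((1, 1, 128)))   # True: leading dims collapse to 1
print(is_128_128_scaled((128, 128)))  # True: 2D weight, 128x128 blocks
print(is_128_128_scaled((1, 128)))    # False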
