Commit c4a4523

Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference
Summary:

Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in torchao inference. In detail:

1. bring the 128x128 gemm triton kernel we have out of prototype and wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test which checks e2e numerical correctness via SQNR comparison vs a high precision baseline

For now I added a fallback which only requires triton and is numerically correct but may not reach optimal performance. Performance optimization is left for future PRs:

1. we should map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. we should enable an fbgemm_gpu_genai path, if available in the user env
3. we should map to a triton kernel for quantizing the weights, as `torch.compile` is currently known to be slow for 128x128 block quantization

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 8d58dfe
ghstack-comment-id: 3460951962
Pull-Request: #3257
1 parent 1e473ed commit c4a4523
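
For context, here is a minimal usage sketch of the new recipe. It is not part of this commit; it mirrors the granularity tuple exercised in the new tests below and assumes a CUDA device with compute capability >= 8.9 and a bfloat16 model, per the hardware check in torchao/float8/inference.py.

```python
# Hypothetical usage sketch, not from this PR: enable DeepSeek-style
# a_1_128_w_128_128 float8 scaling via the inference config, using the
# same granularity tuple as the new tests.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False, dtype=torch.bfloat16, device="cuda")
)

# activations: one scale per 1x128 block; weights: one scale per 128x128 block
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
quantize_(model, config)

x = torch.randn(32, 512, dtype=torch.bfloat16, device="cuda")
y = model(x)  # runs through the float8 blockwise-scaled path
```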

File tree

8 files changed: +504 / -307 lines


test/dtypes/test_affine_quantized_float.py

Lines changed: 2 additions & 2 deletions
@@ -152,7 +152,7 @@ def test_invalid_granularity(self):
     def test_mismatched_granularity(self):
         with pytest.raises(
             ValueError,
-            match="Different granularities for activation and weight are not supported",
+            match="Unsupported granularity types",
         ):
             Float8DynamicActivationFloat8WeightConfig(
                 granularity=(PerTensor(), PerRow())
@@ -165,7 +165,7 @@ def test_unsupported_granularity(self):
         class UnsupportedGranularity:
             pass
 
-        with pytest.raises(ValueError, match="Invalid granularity types"):
+        with pytest.raises(ValueError, match="Unsupported granularity types"):
            Float8DynamicActivationFloat8WeightConfig(
                granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
            )

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 35 additions & 9 deletions
@@ -18,6 +18,7 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     quantize_,
@@ -64,7 +65,10 @@ def setUp(self):
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "granularity",
+        [PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
+    )
     @common_utils.parametrize(
         "kernel_preference",
         [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
@@ -74,7 +78,7 @@ def setUp(self):
         "sizes",
         [
             ((128,), 256, 128),
-            ((32, 128), 64, 256),
+            ((32, 128), 256, 512),
         ],
     )
     def test_fp8_linear_variants(
@@ -86,13 +90,21 @@ def test_fp8_linear_variants(
         kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
-        if (
-            isinstance(granularity, PerTensor)
-            and kernel_preference == KernelPreference.FBGEMM
-        ):
-            return unittest.skip(
-                "per tensor with fbgemm kernel preferece does not work yet"
-            )
+        if isinstance(granularity, PerTensor):
+            if kernel_preference is KernelPreference.FBGEMM:
+                return unittest.skip(
+                    "per tensor with fbgemm kernel preference does not work yet"
+                )
+            elif mode == "weight-only":
+                return unittest.skip("unimplemented")
+
+        elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
+            if dtype is torch.float32:
+                return unittest.skip("unimplemented")
+            elif mode == "weight-only":
+                return unittest.skip("unimplemented")
+            elif kernel_preference is KernelPreference.FBGEMM:
+                return unittest.skip("unimplemented")
 
         error_message = None
         if isinstance(granularity, PerRow):
@@ -137,6 +149,20 @@ def test_fp8_linear_variants(
 
         quantize_(quantized_model, config)
 
+        # ensure weight scaling is what we expect
+        qs1 = quantized_model.linear1.weight.scale
+        qs2 = quantized_model.linear2.weight.scale
+        if granularity == PerTensor():
+            assert qs1.shape == (1, 1)
+            assert qs2.shape == (1, 1)
+        elif granularity == PerRow():
+            assert qs1.shape == (N, 1)
+            assert qs2.shape == (K, 1)
+        else:
+            assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
+            assert qs1.shape == (N // 128, K // 128)
+            assert qs2.shape == (K // 128, N // 128)
+
         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)

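The scale-shape assertions above follow from dividing each weight dimension by its 128x128 block size; a small standalone arithmetic check (illustrative only, using the (N, K) = (256, 512) size added to the test's `sizes` parametrization):

```python
# Illustrative arithmetic only: expected weight-scale shapes for 128x128
# blockwise scaling with N, K = 256, 512, matching the assertions above.
N, K = 256, 512
block = (128, 128)
qs1_shape = (N // block[0], K // block[1])  # linear1 weight is (N, K)
qs2_shape = (K // block[0], N // block[1])  # linear2 weight is (K, N)
assert qs1_shape == (2, 4)
assert qs2_shape == (4, 2)
```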
test/quantization/test_quant_primitives.py

Lines changed: 46 additions & 0 deletions
@@ -14,9 +14,11 @@
     MappingType,
     ZeroPointDomain,
     _choose_qparams_affine_tinygemm,
+    _choose_scale_float8,
     _fake_quantize_affine,
     _fake_quantize_affine_cachemask,
     _maybe_expand_scale_to_tensor_shape,
+    _quantize_affine_float8,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -55,6 +57,23 @@ def check_idempotent(self, fn, *args, **kwargs):
     return output1
 
 
+# from https://github.com/pytorch/pytorch/blob/7563f61cc8a40a5ba21a498a2d98895b4eec3f39/test/test_scaled_matmul_cuda.py#L100
+# with scale modified to be the inverse of the version in PT core
+def _tensor_to_scale_block(
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    block_outer: int,
+    block_inner: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
+    amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
+    scale = amax / torch.finfo(float8_dtype).max
+    x = x.div(scale).to(float8_dtype)
+    x = x.flatten(2, 3).flatten(0, 1)
+    scale = scale.flatten(2, 3).flatten(0, 1)
+    return x, scale
+
+
 # Legacy tinygemm ops
 def _get_groupwise_affine_qparams(
     w,
@@ -798,6 +817,33 @@ def test_maybe_expand_scale_to_tensor_shape(self):
         self.assertEqual(new_scale5.shape, torch.Size([3, 2, 8]))
         self.assertEqual(new_scale5.unique(dim=-1).shape, torch.Size([3, 2, 2]))
 
+    def test_float8_blockwise_scaling(self):
+        M, K = 512, 1024
+        hp_tensor = torch.randn(M, K, dtype=torch.float)
+        # make the scales from some of the blocks obviously different
+        hp_tensor[0:128, 0:128] *= 3.0
+        hp_tensor[0:128, 128:256] *= 7.0
+        hp_tensor[128:256, 0:128] *= 2.0
+        hp_tensor[128:256, 128:256] *= 100.0
+
+        block_size = (128, 128)
+
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        ref_data, ref_scale = _tensor_to_scale_block(
+            hp_tensor, torch.float8_e4m3fn, 128, 128
+        )
+
+        torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
+        torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()

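For intuition about what the 128x128 blockwise scaling above computes, here is a self-contained sketch of the quantize/dequantize round trip written with plain PyTorch ops; it mirrors the reference `_tensor_to_scale_block` helper and is not part of the test file.

```python
import torch

# Standalone sketch of 128x128 blockwise float8 scaling: one scale per block,
# computed as block_amax / float8_max, then divide-and-cast to quantize.
M, K, B = 512, 1024, 128
hp = torch.randn(M, K)

# view as (M // B, B, K // B, B) and reduce amax over each block
blocks = hp.unflatten(1, (-1, B)).unflatten(0, (-1, B))
scale = blocks.abs().amax(dim=(1, 3)) / torch.finfo(torch.float8_e4m3fn).max  # (M//B, K//B)

# broadcast the per-block scale back to the tensor shape
scale_full = scale.repeat_interleave(B, dim=0).repeat_interleave(B, dim=1)  # (M, K)

# quantize and dequantize
q = (hp / scale_full).to(torch.float8_e4m3fn)
dq = q.float() * scale_full

# error is bounded by float8 e4m3 precision within each block
max_rel_err = ((dq - hp).abs().max() / hp.abs().max()).item()
print(f"max abs error relative to tensor max: {max_rel_err:.4f}")
```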
torchao/float8/inference.py

Lines changed: 59 additions & 16 deletions
@@ -7,13 +7,15 @@
 Defines an nn module designed to be used during inference
 """
 
+import math
 from typing import List, NamedTuple, Optional, Tuple, Union
 
 import torch
 
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerBlock,
     PerRow,
     PerTensor,
 )
@@ -196,6 +198,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
     )
 
 
+def _is_1_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 1x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128
+
+
+def _is_128_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 128x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) == 2 and b[0] == 128 and b[1] == 128
+
+
+def _granularity_is_a_1_128_w_128_128(
+    g: Union[
+        FP8Granularity,
+        Tuple[FP8Granularity, FP8Granularity],
+        list[FP8Granularity],
+    ],
+) -> bool:
+    return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))
+
+
 def _normalize_granularity(
     granularity: Optional[
         Union[
@@ -211,22 +243,23 @@ def _normalize_granularity(
     elif isinstance(granularity, (PerTensor, PerRow)):
         processed_granularity = (granularity, granularity)
     elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
-        if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
-        ):
-            raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
-            )
+        is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
+            granularity[1], PerTensor
+        )
+        is_per_row = isinstance(granularity[0], PerRow) and isinstance(
+            granularity[1], PerRow
+        )
+        is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)
+
+        if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
+            raise ValueError(f"Unsupported granularity types: {granularity}.")
         if not isinstance(granularity[0], type(granularity[1])):
             raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+                f"Different granularities for activation and weight are not supported: {granularity}."
            )
         processed_granularity = tuple(granularity)
     else:
-        raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
-        )
+        raise ValueError(f"Invalid granularity specification: {granularity}.")
     return processed_granularity
 
 
@@ -243,12 +276,22 @@ def _check_hardware_support(
         AssertionError: If hardware doesn't support the requested granularity
         ValueError: If invalid granularity type is provided
     """
-    for _granularity in granularities:
-        if not isinstance(_granularity, (PerTensor, PerRow)):
-            raise ValueError(
-                f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
-            )
+    is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
+        granularities[1], PerTensor
+    )
+    is_per_row = isinstance(granularities[0], PerRow) and isinstance(
+        granularities[1], PerRow
+    )
+    is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)
 
+    if is_per_tensor or is_per_row:
         assert is_sm_at_least_89() or is_MI300(), (
             "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
         )
+    elif is_a_1_128_w_128_128:
+        # TODO(future PR): look into AMD support
+        assert is_sm_at_least_89(), (
+            "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
+        )
+    else:
+        raise ValueError(f"Invalid granularities {granularities}.")

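To see how the updated validation behaves end to end, here is a hedged sketch (not from the PR) that assumes the private helpers keep the signatures shown in the diff above.

```python
# Hedged sketch, assuming the private helpers below keep the signatures
# shown in the diff; behavior follows the branches of _normalize_granularity.
from torchao.float8.inference import (
    _granularity_is_a_1_128_w_128_128,
    _normalize_granularity,
)
from torchao.quantization.granularity import PerBlock, PerRow, PerTensor

# a single granularity is broadcast to (activation, weight)
assert _normalize_granularity(PerRow()) == (PerRow(), PerRow())

# the DeepSeek-style recipe: 1x128 activation blocks, 128x128 weight blocks
g = (PerBlock((1, 128)), PerBlock((128, 128)))
assert _granularity_is_a_1_128_w_128_128(g)
assert _normalize_granularity(g) == g

# mixing unrelated granularities now raises "Unsupported granularity types"
try:
    _normalize_granularity((PerTensor(), PerRow()))
except ValueError as e:
    print(e)
```
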
0 commit comments