Commit 070068f

Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference
Summary:

Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in torchao inference. In detail:

1. bring the 128x128 gemm triton kernel we have out of prototype and wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test which checks e2e numerical correctness via SQNR comparison against a high-precision baseline

For now I added a fallback which only requires triton and is numerically correct but may not reach optimal performance. Performance optimization is left for future PRs:

1. we should map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. we should enable an fbgemm_gpu_genai path, if available in the user's environment
3. we should map to a triton kernel for quantizing the weights, since `torch.compile` is currently known to be slow for 128x128 block quantization

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: db464e1
ghstack-comment-id: 3460951962
Pull-Request: #3257
1 parent 841d104 commit 070068f
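As context for the summary above, here is a minimal usage sketch of the new recipe. This is not part of the commit's diff; the config name and `granularity` argument follow the existing torchao inference API as exercised by the test below, and the toy model is an illustrative assumption.

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

# hypothetical toy model; any module with nn.Linear layers whose dims are
# multiples of 128 works the same way
model = torch.nn.Sequential(torch.nn.Linear(256, 128)).to(torch.bfloat16).to("cuda")

# activations scaled per (1, 128) block, weights scaled per (128, 128) block
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
quantize_(model, config)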

File tree

3 files changed, +110 -31 lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 24 additions & 6 deletions
@@ -18,6 +18,7 @@
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
+    PerBlock,
     PerRow,
     PerTensor,
     quantize_,
@@ -61,20 +62,37 @@ def setUp(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
-    @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    # @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize(
+        "dtype",
+        [
+            torch.bfloat16,
+        ],
+    )
+    # @common_utils.parametrize("mode", ["dynamic", "weight-only"])
+    @common_utils.parametrize(
+        "mode",
+        [
+            "dynamic",
+        ],
+    )
+    # @common_utils.parametrize("compile", [True, False])
+    @common_utils.parametrize("compile", [False])
+    # @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "granularity", [(PerBlock((1, 128)), PerBlock((128, 128)))]
+    )
     @common_utils.parametrize(
         "kernel_preference",
-        [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+        # [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
+        [KernelPreference.TORCH],
     )
     # Inputs are (M,..), K, N
     @common_utils.parametrize(
         "sizes",
         [
             ((128,), 256, 128),
-            ((32, 128), 64, 256),
+            # ((32, 128), 64, 256),
         ],
     )
     def test_fp8_linear_variants(
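For reference, the SQNR comparison mentioned in the summary can be thought of as follows. This is a hedged sketch rather than the actual test body; `baseline_model`, `quantized_model`, and the 25 dB threshold are illustrative assumptions.

import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in decibels
    signal = torch.linalg.vector_norm(reference.float())
    noise = torch.linalg.vector_norm(reference.float() - candidate.float())
    return 20 * torch.log10(signal / noise)

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
baseline_out = baseline_model(x)    # high-precision copy of the model
quantized_out = quantized_model(x)  # same model quantized with the new granularity
assert sqnr_db(baseline_out, quantized_out) > 25.0  # threshold chosen for illustration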

torchao/float8/inference.py

Lines changed: 52 additions & 16 deletions
@@ -14,6 +14,7 @@
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
 from torchao.float8.types import FP8Granularity
 from torchao.quantization.granularity import (
+    PerBlock,
     PerRow,
     PerTensor,
 )
@@ -196,6 +197,26 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
     )
 
 
+def _is_1_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 1x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) == 2 and b[0] == 1 and b[1] == 128
+
+
+def _is_128_128_scaled(x: torch.Tensor) -> bool:
+    """Checks if a quantized tensor is scaled with a block size of 128x128
+    Args:
+        x: quantized tensor (should have `block_size` attribute)
+    """
+    assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
+    b = x.block_size
+    return len(b) == 2 and b[0] == 128 and b[1] == 128
+
+
 def _normalize_granularity(
     granularity: Optional[
         Union[
@@ -211,22 +232,25 @@ def _normalize_granularity(
     elif isinstance(granularity, (PerTensor, PerRow)):
         processed_granularity = (granularity, granularity)
     elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
-        if not (
-            isinstance(granularity[0], (PerTensor, PerRow))
-            and isinstance(granularity[1], (PerTensor, PerRow))
-        ):
-            raise ValueError(
-                f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
-            )
+        is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
+            granularity[1], PerTensor
+        )
+        is_per_row = isinstance(granularity[0], PerRow) and isinstance(
+            granularity[1], PerRow
+        )
+        is_a_1_128_w_128_128 = granularity[0] == PerBlock((1, 128)) and granularity[
+            1
+        ] == PerBlock((128, 128))
+
+        if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
+            raise ValueError(f"Unsupported granularity types: {granularity}.")
         if not isinstance(granularity[0], type(granularity[1])):
             raise ValueError(
-                f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
+                f"Different granularities for activation and weight are not supported: {granularity}."
             )
         processed_granularity = tuple(granularity)
     else:
-        raise ValueError(
-            f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
-        )
+        raise ValueError(f"Invalid granularity specification: {granularity}.")
     return processed_granularity
 
 
@@ -243,12 +267,24 @@ def _check_hardware_support(
         AssertionError: If hardware doesn't support the requested granularity
         ValueError: If invalid granularity type is provided
     """
-    for _granularity in granularities:
-        if not isinstance(_granularity, (PerTensor, PerRow)):
-            raise ValueError(
-                f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
-            )
+    is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
+        granularities[1], PerTensor
+    )
+    is_per_row = isinstance(granularities[0], PerRow) and isinstance(
+        granularities[1], PerRow
+    )
+    is_a_1_128_w_128_128 = granularities[0] == PerBlock((1, 128)) and granularities[
+        1
+    ] == PerBlock((128, 128))
 
+    if is_per_tensor or is_per_row:
         assert is_sm_at_least_89() or is_MI300(), (
             "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
         )
+    elif is_a_1_128_w_128_128:
+        # TODO(future PR): look into AMD support
+        assert is_sm_at_least_89(), (
+            "Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
+        )
+    else:
+        raise ValueError(f"Invalid granularities {granularities}.")
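A short sketch of how the relaxed validation above behaves (assuming the granularity classes compare by value, as torchao's dataclass-based granularities do; `_normalize_granularity` is the private helper changed in this diff):

from torchao.float8.inference import _normalize_granularity
from torchao.quantization.granularity import PerBlock, PerRow

# a single granularity is still broadcast to (activation, weight)
assert _normalize_granularity(PerRow()) == (PerRow(), PerRow())

# the new DeepSeek-style pair now passes validation unchanged
pair = (PerBlock((1, 128)), PerBlock((128, 128)))
assert _normalize_granularity(pair) == pair

# block shapes other than (1, 128) / (128, 128) are still rejected
try:
    _normalize_granularity((PerBlock((1, 64)), PerBlock((128, 128))))
except ValueError:
    pass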

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 34 additions & 9 deletions
@@ -15,13 +15,18 @@
 from torchao.float8.inference import (
     Float8MMConfig,
     FP8Granularity,
+    _is_1_128_scaled,
+    _is_128_128_scaled,
     _is_rowwise_scaled,
     _is_tensorwise_scaled,
     _slice_scale_for_dimension,
     addmm_float8_unwrapped_inference,
     preprocess_data,
     preprocess_scale,
 )
+from torchao.kernel.blockwise_quantization import (
+    blockwise_fp8_gemm,
+)
 from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -337,19 +342,39 @@ def _(func, types, args, kwargs):
                 "Input tensor must be rowwise block size"
             )
             w_scale = w_scale.transpose(-1, -2)
+        elif _is_128_128_scaled(weight_tensor):
+            assert _is_1_128_scaled(input_tensor), (
+                "input_tensor must be 1x128 scaled"
+            )
+            w_scale = w_scale.transpose(-1, -2)
 
         input_scale = preprocess_scale(input_scale, input_tensor.shape)
         inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
 
-        return addmm_float8_unwrapped_inference(
-            inpt_data,
-            input_scale,
-            w_data,
-            w_scale,
-            output_dtype=input_tensor.dtype,
-            bias=bias,
-            use_fast_accum=scaled_mm_config.use_fast_accum,
-        ).reshape(out_shape)
+        if _is_128_128_scaled(weight_tensor):
+            # TODO(before land): ensure fast_accum is False for blockwise
+            # TODO(future PR): add testing for torch._scaled_mm with
+            # blockwise scaling on CUDA 12.9
+            # TODO(future PR): add fbgemm_gpu_genai path if available
+            assert _is_1_128_scaled(input_tensor), "unsupported"
+            res = blockwise_fp8_gemm(
+                inpt_data,
+                input_scale,
+                w_data.t(),
+                w_scale,
+                block_size=128,
+            )
+        else:
+            res = addmm_float8_unwrapped_inference(
+                inpt_data,
+                input_scale,
+                w_data,
+                w_scale,
+                output_dtype=input_tensor.dtype,
+                bias=bias,
+                use_fast_accum=scaled_mm_config.use_fast_accum,
+            )
+        return res.reshape(out_shape)
     else:
         assert not isinstance(input_tensor, TorchAOBaseTensor), (
             "Expecting input_tensor to be unquantized"
