Support DeepSeekV3-style block FP8 quantization #372

Open · wants to merge 6 commits into main
@@ -154,6 +154,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)


Contributor: nit: remove

# reference: https://github.com/vllm-project/vllm/pull/16362
def unpack_fp4_from_uint8(
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
64 changes: 62 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,11 +111,15 @@ def dequantize(
elif scale.ndim == 2:
if scale.shape[1] == 1:
args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
else:
elif scale.shape[0] == 1:
group_size = int(x_q.shape[1] / scale.shape[1])
args = QuantizationArgs(
strategy=QuantizationStrategy.GROUP, group_size=group_size
)
Collaborator: Can we add a docstring explaining why this falls into the else condition? Otherwise, I think this has grown complicated enough to easily fall prey to a bug.

Contributor: @dsikka Do you have a sense of why this logic exists at all / what cases it's used in? I'm not sure if inferring the quant strategy is really safe practice.

else:
args = QuantizationArgs(
strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
)
else:
raise ValueError(
f"Could not infer a quantization strategy from scale with {scale.ndim} "
@@ -189,7 +193,63 @@ def _process_quantization(
q_min, q_max = calculate_range(args, x.device)
group_size = args.group_size

if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
# blockwise FP8: quantize per 2D block, supports block_structure for static block quant
if args.strategy == QuantizationStrategy.BLOCK:
original_shape = x.shape
rows, cols = x.shape[-2], x.shape[-1]
block_height, block_width = args.block_structure

# Ensure exact division (tensor dimensions must be divisible by block size)
if rows % block_height != 0:
raise ValueError(
f"Tensor height {rows} is not divisible by block_height {block_height}. "
f"Block quantization requires exact division."
)
if cols % block_width != 0:
raise ValueError(
f"Tensor width {cols} is not divisible by block_width {block_width}. "
f"Block quantization requires exact division."
)

# reshape into blocks
num_rows_blocks = rows // block_height
num_cols_blocks = cols // block_width
x_blocks = x.reshape(
num_rows_blocks,
block_height,
num_cols_blocks,
block_width,
).transpose(1, 2)

# expand scale/zero_point for blocks
sb = scale.unsqueeze(-1).unsqueeze(-1)
zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
if do_quantize:
# quantize blocks
x_blocks = _quantize(
x=x_blocks,
scale=sb,
zero_point=zb,
q_min=q_min,
q_max=q_max,
args=args,
dtype=dtype,
global_scale=global_scale,
)
if do_dequantize:
# dequantize blocks
x_blocks = _dequantize(
x_q=x_blocks,
scale=sb,
zero_point=zb,
global_scale=global_scale,
)
# restore original shape
output = x_blocks.transpose(1, 2).reshape(original_shape)
elif args.strategy in (
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
):
n_dims = x.shape
if len(n_dims) > 2:
x = x.squeeze(0)
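
As a rough usage sketch of the new BLOCK branch above (not part of the diff; it mirrors the test added in test_forward.py below and assumes this PR's _process_quantization signature), one scale is expected per [block_height, block_width] tile, so a 256x512 weight with 128x128 blocks carries a [2, 4] scale grid:

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import _process_quantization

weight = torch.randn(256, 512)                    # both dims divisible by 128
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy=QuantizationStrategy.BLOCK,
    symmetric=True,
    dynamic=False,
    block_structure=[128, 128],
)
scale = torch.rand(256 // 128, 512 // 128) + 0.1  # one scale per 128x128 block -> shape [2, 4]
zero_point = torch.zeros_like(scale)

# fake-quantize round trip: quantize each block, then dequantize; shape is preserved
fake_quant = _process_quantization(
    x=weight,
    scale=scale,
    zero_point=zero_point,
    args=args,
    do_quantize=True,
    do_dequantize=True,
    dtype=None,
    global_scale=None,
)
assert fake_quant.shape == weight.shape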
39 changes: 31 additions & 8 deletions src/compressed_tensors/quantization/quant_args.py
@@ -14,7 +14,7 @@

import warnings
from enum import Enum
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
:param symmetric: whether or not quantization scale is symmetric about zero-point
:param strategy: string id determining the scope of scale/zero-point to apply
:param group_size: group length to use for the group strategy
:param block_structure: 2d block structure to use for the block strategy, must be
of the format "2x4", "8x16", etc.
:param block_structure: 2d block structure to use for the block strategy; must be
a list of two ints [rows, cols] like [128, 128].
:param dynamic: set True to perform dynamic quantization - values will not be
calibrated during calibration phase, instead during inference new quantization
ranges will be observed with every sample. Defaults to False for static
Expand All @@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
symmetric: bool = True
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
block_structure: Optional[List[int]] = None
Contributor suggested changing
    block_structure: Optional[List[int]] = None
to
    block_structure: Optional[Tuple[int, int]] = None

Member (Author): I lean towards keeping it a list, since I don't think we can distinguish between a list and a tuple in the final JSON config.

dynamic: Union[DynamicType, bool] = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]:

return value

@field_validator("block_structure", mode="before")
def validate_block_structure(cls, value) -> Optional[List[int]]:
if value is None:
return value
# For backward compatibility, allow string format "2x4", "8x16", etc.
if isinstance(value, str):
try:
return [int(x) for x in value.split("x")]
except Exception:
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
if isinstance(value, (list, tuple)):
if len(value) != 2 or not all(isinstance(v, int) for v in value):
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
return list(value)
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
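
For illustration (not part of the diff; assumes the validator above is wired into QuantizationArgs as shown), both the legacy string form and the new list form normalize to a two-int list:

from compressed_tensors.quantization import QuantizationArgs

# legacy "RxC" string and new [rows, cols] list should validate to the same value
legacy = QuantizationArgs(strategy="block", block_structure="128x128")
updated = QuantizationArgs(strategy="block", block_structure=[128, 128])
assert legacy.block_structure == updated.block_structure == [128, 128]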

@field_validator("strategy", mode="before")
def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
if isinstance(value, str):
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":

# infer observer w.r.t. dynamic
if dynamic:
if strategy not in (
supported_strategies = (
QuantizationStrategy.TOKEN,
QuantizationStrategy.TENSOR,
QuantizationStrategy.TENSOR_GROUP,
):
QuantizationStrategy.GROUP,
Contributor: This is mostly an aesthetic choice, but it might have aesthetic consequences if vllm wants to support fused input-weight quantization, e.g. if input_quant_strategy == group and weight_quant_strategy == group.

Contributor: Might want to add some validation on quant_scheme related to this as well.

)
if strategy not in supported_strategies:
raise ValueError(
f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
"must be used for dynamic quantization",
f"One of {supported_strategies} must be used for dynamic quantization"
)

if (
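
A quick sketch of what the relaxed validation above enables (not part of the diff): dynamic per-group activation args, like those used by the FP8_BLOCK preset in quant_scheme.py below, now construct without raising:

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

activation_args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy=QuantizationStrategy.GROUP,
    symmetric=True,
    dynamic=True,   # previously rejected: GROUP was not among the supported dynamic strategies
    observer=None,
    group_size=128,
)
assert activation_args.strategy == QuantizationStrategy.GROUP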
24 changes: 24 additions & 0 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -243,6 +243,29 @@ def is_preset_scheme(name: str) -> bool:
),
)

# Block-wise FP8 (DeepSeekV3-style quantization):
# static 128x128 per-block weights and
# dynamic per-token-group activations
FP8_BLOCK = dict(
weights=QuantizationArgs(
num_bits=8,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.BLOCK,
symmetric=True,
dynamic=False,
block_structure=[128, 128],
),
input_activations=QuantizationArgs(
num_bits=8,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.GROUP,
symmetric=True,
dynamic=True,
observer=None,
group_size=128,
),
)

PRESET_SCHEMES = {
# Unquantized (no-op)
"UNQUANTIZED": UNQUANTIZED,
Expand All @@ -257,6 +280,7 @@ def is_preset_scheme(name: str) -> bool:
# Float weight and activation schemes
"FP8": FP8,
"FP8_DYNAMIC": FP8_DYNAMIC,
"FP8_BLOCK": FP8_BLOCK,
"NVFP4A16": NVFP4A16,
"NVFP4": NVFP4,
}
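
For reference (not part of the diff), the new preset is looked up through PRESET_SCHEMES like any other scheme; its weight args carry the 128x128 block structure and its activation args are dynamic per-group:

from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

fp8_block = PRESET_SCHEMES["FP8_BLOCK"]
assert fp8_block["weights"].block_structure == [128, 128]
assert fp8_block["input_activations"].dynamic
assert fp8_block["input_activations"].group_size == 128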
13 changes: 11 additions & 2 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
elif args.strategy == QuantizationStrategy.TENSOR:
reduce_dims = None
elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
elif args.strategy in (
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
):
if len(value.shape) > 2:
value = value.squeeze(0)

@@ -187,9 +190,15 @@
),
)
else:
supported_strategies = (
QuantizationStrategy.TOKEN,
QuantizationStrategy.TENSOR,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
)
raise ValueError(
"Dynamic quantization is only supported for ",
f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
f"{supported_strategies}",
)

if not reduce_dims:
4 changes: 3 additions & 1 deletion tests/test_examples/test_bitmask_compression_ipynb.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import nbformat
import pytest


nbformat = pytest.importorskip("nbformat")
from nbconvert.preprocessors import ExecutePreprocessor


50 changes: 50 additions & 0 deletions tests/test_quantization/lifecycle/test_forward.py
@@ -13,9 +13,12 @@
# limitations under the License.


import math

import pytest
import torch
from compressed_tensors.quantization.lifecycle.forward import (
_process_quantization,
dequantize,
forward_quantize,
quantize,
Expand All @@ -29,6 +32,7 @@
QuantizationStrategy,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.utils.helpers import calculate_range
from torch.nn import Linear


@@ -203,3 +207,49 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_i
dtype=None,
g_idx=g_idx,
)


def test_process_quantization_block_static():
"""
Static block quantization (QuantizationStrategy.BLOCK) should split a 2D tensor
into blocks, quantize each block, and reassemble without changing shape.
"""
rows, cols = 8, 8
bh, bw = 2, 4
x = torch.randn(rows, cols)
args = QuantizationArgs(
num_bits=8,
type="float",
strategy=QuantizationStrategy.BLOCK,
symmetric=True,
dynamic=False,
block_structure=[bh, bw],
)
num_rb = math.ceil(rows / bh)
num_cb = math.ceil(cols / bw)
scale = torch.rand(num_rb, num_cb) + 0.1
zp = torch.zeros_like(scale)
q_min, q_max = calculate_range(args, x.device)
out = _process_quantization(
x=x,
scale=scale,
zero_point=zp,
args=args,
do_quantize=True,
do_dequantize=False,
dtype=None,
global_scale=None,
)
assert out.shape == x.shape
# full fake-quantize roundtrip
out2 = _process_quantization(
x=x,
scale=scale,
zero_point=zp,
args=args,
do_quantize=True,
do_dequantize=True,
dtype=None,
global_scale=None,
)
assert out2.shape == x.shape
2 changes: 1 addition & 1 deletion tests/test_quantization/test_quant_args.py
@@ -59,7 +59,7 @@ def test_block():

block = QuantizationArgs(**kwargs)
assert block.strategy == QuantizationStrategy.BLOCK
assert block.block_structure == kwargs["block_structure"]
assert block.block_structure == [2, 4]


def test_infer_strategy():
29 changes: 28 additions & 1 deletion tests/test_quantization/test_utils/test_helpers.py
@@ -20,7 +20,11 @@
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
from compressed_tensors.quantization.utils import (
calculate_qparams,
compute_dynamic_scales_and_zp,
generate_gparam,
)


@pytest.mark.parametrize(
@@ -73,3 +77,26 @@ def test_fused_global_scales():
assert max_tensor_value.item() == pytest.approx(
FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
)


@pytest.mark.parametrize(
"shape,group_size,exp_shape",
[
# Only batch size =1 is supported for dynamic GROUP quantization
((1, 4, 8), 4, torch.Size([4, 2])),
],
)
def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):
"""
Dynamic group quantization should reduce activations in groups, producing
scales and zero points of shape [batch, num_groups].
"""
value = torch.randn(*shape)
args = QuantizationArgs(
strategy=QuantizationStrategy.GROUP,
group_size=group_size,
dynamic=True,
)
scale, zp = compute_dynamic_scales_and_zp(value, args, module=torch.nn.Module())
assert scale.shape == exp_shape
assert zp.shape == exp_shape