Support DeepSeekV3-style block FP8 quantization #372

Open · wants to merge 7 commits into base: main (diff shows changes from 2 commits)

src/compressed_tensors/quantization/lifecycle/forward.py (44 additions, 1 deletion)

@@ -189,7 +189,50 @@ def _process_quantization(
q_min, q_max = calculate_range(args, x.device)
group_size = args.group_size

# blockwise FP8: quantize per 2D block; supports block_structure for static block quantization
if args.strategy == QuantizationStrategy.BLOCK:
# x: [H, W]; the BLOCK strategy operates on 2D weight tensors
bs = args.block_structure
if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(dim, int) for dim in bs)):
raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].")
b_h, b_w = bs
rows, cols = x.shape[-2], x.shape[-1]
# blocks must tile the tensor exactly; the reshape below fails otherwise
if rows % b_h != 0 or cols % b_w != 0:
raise ValueError(
f"Tensor shape {rows}x{cols} is not divisible by block_structure {bs}"
)
n_rb = rows // b_h
n_cb = cols // b_w
# reshape into blocks
x_blocks = x.reshape(
n_rb,
b_h,
n_cb,
b_w,
).transpose(1, 2)
# expand scale/zero_point for blocks
sb = scale.unsqueeze(-1).unsqueeze(-1)
zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
# quantize blocks
qb = _quantize(
x=x_blocks,
scale=sb,
zero_point=zb,
q_min=q_min,
q_max=q_max,
args=args,
dtype=dtype,
global_scale=global_scale,
) if do_quantize else x_blocks
# dequantize if needed
if do_dequantize:
qb = _dequantize(
x_q=qb,
scale=sb,
zero_point=zb,
global_scale=global_scale,
)
# restore original shape
out = qb.transpose(1, 2).reshape(rows, cols)
return out
elif args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
n_dims = x.shape
if len(n_dims) > 2:
x = x.squeeze(0)
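For intuition, the blocking above is just a reshape/transpose round trip; here is a minimal standalone sketch (hypothetical 4x6 tensor with [2, 3] blocks, not part of the PR):

import torch

# hypothetical shapes: rows=4, cols=6 with block_structure=[2, 3]
x = torch.arange(24.0).reshape(4, 6)
b_h, b_w = 2, 3
n_rb, n_cb = 4 // b_h, 6 // b_w

# [n_rb, b_h, n_cb, b_w] -> [n_rb, n_cb, b_h, b_w]; blocks[i, j] is the (i, j) block
blocks = x.reshape(n_rb, b_h, n_cb, b_w).transpose(1, 2)

# a per-block scale of shape [n_rb, n_cb] broadcasts after two trailing unsqueezes
scale = torch.full((n_rb, n_cb), 2.0)
scaled = blocks / scale.unsqueeze(-1).unsqueeze(-1)

# the inverse transpose/reshape restores the original layout
restored = (scaled * scale.unsqueeze(-1).unsqueeze(-1)).transpose(1, 2).reshape(4, 6)
assert torch.equal(x, restored)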

src/compressed_tensors/quantization/quant_args.py (31 additions, 8 deletions)

@@ -14,7 +14,7 @@

import warnings
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch
from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
:param symmetric: whether or not quantization scale is symmetric about zero-point
:param strategy: string id determining the scope of scale/zero-point to apply
:param group_size: group length to use for the group strategy
:param block_structure: 2d block structure to use for the block strategy; must be
a list of two ints [rows, cols] like [128, 128].
:param dynamic: set True to perform dynamic quantization - values will not be
calibrated during calibration phase, instead during inference new quantization
ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@
symmetric: bool = True
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[List[int]] = None
dynamic: Union[DynamicType, bool] = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]:

return value

@field_validator("block_structure", mode="before")
def validate_block_structure(cls, value) -> Optional[List[int]]:
if value is None:
return value
# For backward compatibility, allow string format "2x4", "8x16", etc.
if isinstance(value, str):
try:
parsed = [int(x) for x in value.split("x")]
except ValueError:
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
if len(parsed) != 2:
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
return parsed
if isinstance(value, (list, tuple)):
if len(value) != 2 or not all(isinstance(v, int) for v in value):
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
return list(value)
raise ValueError(
f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
)
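
As a sanity check, this is how the validator is intended to behave for both accepted formats (a sketch; values are illustrative):

from compressed_tensors.quantization import QuantizationArgs

# canonical form: a list of two ints
args = QuantizationArgs(num_bits=8, type="float", strategy="block", block_structure=[128, 128])
assert args.block_structure == [128, 128]

# legacy string form is normalized to the list form
legacy = QuantizationArgs(num_bits=8, type="float", strategy="block", block_structure="128x128")
assert legacy.block_structure == [128, 128]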

@field_validator("strategy", mode="before")
def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
if isinstance(value, str):
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":

# infer observer w.r.t. dynamic
if dynamic:
supported_strategies = (
QuantizationStrategy.TOKEN,
QuantizationStrategy.TENSOR,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.GROUP,
)
if strategy not in supported_strategies:
raise ValueError(
f"One of {supported_strategies} must be used for dynamic quantization"
)

Contributor (review comment on the added QuantizationStrategy.GROUP entry): This is mostly an aesthetic choice, but it might have aesthetic consequences if vllm wants to support fused input-weight quantization. Ex if input_quant_strategy == group and weight_quant_strategy == group

Contributor (review comment): Might want to add some validation on quant_scheme related to this as well
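
With GROUP added to the supported dynamic strategies, an activation config like the one used by the FP8_BLOCK preset below should now validate; a quick sketch:

from compressed_tensors.quantization import QuantizationArgs

# dynamic per-token-group activation quantization (as in the FP8_BLOCK preset)
acts = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy="group",
    group_size=128,
    symmetric=True,
    dynamic=True,
    observer=None,
)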


src/compressed_tensors/quantization/quant_scheme.py (24 additions, 0 deletions)

@@ -243,6 +243,29 @@ def is_preset_scheme(name: str) -> bool:
),
)

# Block-wise FP8 (DeepSeekV3-style quantization):
# static 128x128 per-block weights and
# dynamic per-token-group activations
FP8_BLOCK = dict(
weights=QuantizationArgs(
num_bits=8,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.BLOCK,
symmetric=True,
dynamic=False,
block_structure=[128, 128],
),
input_activations=QuantizationArgs(
num_bits=8,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.GROUP,
symmetric=True,
dynamic=True,
observer=None,
group_size=128,
),
)

PRESET_SCHEMES = {
# Unquantized (no-op)
"UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +280,7 @@ def is_preset_scheme(name: str) -> bool:
# Float weight and activation schemes
"FP8": FP8,
"FP8_DYNAMIC": FP8_DYNAMIC,
"FP8_BLOCK": FP8_BLOCK,
"NVFP4A16": NVFP4A16,
"NVFP4": NVFP4,
}
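
Once registered, the new preset can be resolved by name like the existing FP8 schemes; a minimal sketch, assuming the package's existing preset_name_to_scheme helper:

from compressed_tensors.quantization import preset_name_to_scheme

# resolve the preset for Linear layers, as with "FP8" or "FP8_DYNAMIC"
scheme = preset_name_to_scheme("FP8_BLOCK", targets=["Linear"])
assert scheme.weights.block_structure == [128, 128]
assert scheme.input_activations.group_size == 128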

src/compressed_tensors/quantization/utils/helpers.py (7 additions, 2 deletions)

@@ -171,7 +171,7 @@ def compute_dynamic_scales_and_zp(
reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
elif args.strategy == QuantizationStrategy.TENSOR:
reduce_dims = None
elif args.strategy in (QuantizationStrategy.TENSOR_GROUP, QuantizationStrategy.GROUP):
if len(value.shape) > 2:
value = value.squeeze(0)

@@ -189,7 +189,12 @@
else:
raise ValueError(
"Dynamic quantization is only supported for "
f"{(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP, QuantizationStrategy.GROUP)}"
)

if not reduce_dims:

tests/test_examples/test_bitmask_compression_ipynb.py (1 addition, 1 deletion)

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
nbformat = pytest.importorskip("nbformat")
from nbconvert.preprocessors import ExecutePreprocessor


tests/test_quantization/lifecycle/test_forward.py (50 additions, 0 deletions)

@@ -30,6 +30,10 @@
)
from compressed_tensors.quantization.quant_config import QuantizationStatus
from torch.nn import Linear
import math

from compressed_tensors.quantization.lifecycle.forward import _process_quantization
from compressed_tensors.quantization.utils.helpers import calculate_range


def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
@@ -203,3 +207,49 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx):
dtype=None,
g_idx=g_idx,
)


def test_process_quantization_block_static():
"""
Static block quantization (QuantizationStrategy.BLOCK) should split a 2D tensor
into blocks, quantize each block, and reassemble without changing shape.
"""
rows, cols = 8, 8
bh, bw = 2, 4
x = torch.randn(rows, cols)
args = QuantizationArgs(
num_bits=8,
type="float",
strategy=QuantizationStrategy.BLOCK,
symmetric=True,
dynamic=False,
block_structure=[bh, bw],
)
num_rb = math.ceil(rows / bh)
num_cb = math.ceil(cols / bw)
scale = torch.rand(num_rb, num_cb) + 0.1
zp = torch.zeros_like(scale)
q_min, q_max = calculate_range(args, x.device)
out = _process_quantization(
x=x,
scale=scale,
zero_point=zp,
args=args,
do_quantize=True,
do_dequantize=False,
dtype=None,
global_scale=None,
)
assert out.shape == x.shape
# full fake-quantize roundtrip
out2 = _process_quantization(
x=x,
scale=scale,
zero_point=zp,
args=args,
do_quantize=True,
do_dequantize=True,
dtype=None,
global_scale=None,
)
assert out2.shape == x.shape

tests/test_quantization/test_quant_args.py (1 addition, 1 deletion)

@@ -59,7 +59,7 @@ def test_block():

block = QuantizationArgs(**kwargs)
assert block.strategy == QuantizationStrategy.BLOCK
assert block.block_structure == [2, 4]


def test_infer_strategy():

tests/test_quantization/test_utils/test_helpers.py (23 additions, 1 deletion)

@@ -20,7 +20,7 @@
QuantizationArgs,
QuantizationStrategy,
)
from compressed_tensors.quantization.utils import calculate_qparams, compute_dynamic_scales_and_zp, generate_gparam


@pytest.mark.parametrize(
@@ -73,3 +73,25 @@ def test_fused_global_scales():
assert max_tensor_value.item() == pytest.approx(
FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
)

@pytest.mark.parametrize(
"shape,group_size,exp_shape",
[
# Only batch size 1 is supported for dynamic GROUP quantization
((1, 4, 8), 4, torch.Size([4, 2])),
],
)
def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):
"""
Dynamic group quantization should reduce activations in groups, producing
scales and zero points of shape [batch, num_groups].
"""
value = torch.randn(*shape)
args = QuantizationArgs(
strategy=QuantizationStrategy.GROUP,
group_size=group_size,
dynamic=True,
)
scale, zp = compute_dynamic_scales_and_zp(value, args, module=torch.nn.Module())
assert scale.shape == exp_shape
assert zp.shape == exp_shape