diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
index 5f348e91..5d521fbc 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -154,6 +154,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
 )
 
+# reference: https://github.com/vllm-project/vllm/pull/16362
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index b4ca3a82..01ac770c 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,11 +111,15 @@ def dequantize(
     elif scale.ndim == 2:
         if scale.shape[1] == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-        else:
+        elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
             group_size = int(x_q.shape[1] / scale.shape[1])
             args = QuantizationArgs(
                 strategy=QuantizationStrategy.GROUP, group_size=group_size
             )
+        else:
+            args = QuantizationArgs(
+                strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+            )
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
@@ -189,7 +193,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                "Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                "Block quantization requires exact division."
+            )
+
+        # reshape into blocks
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index fdf34a28..53dbf88e 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -14,7 +14,7 @@
 
 import warnings
 from enum import Enum
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from compressed_tensors.utils import Aliasable
@@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     :param symmetric: whether or not quantization scale is symmetric about zero-point
     :param strategy: string id determining the scope of scale/zero-point to apply
     :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block strategy, must be
-        of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols], e.g. [128, 128]
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
    dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
@@ -207,6 +207,34 @@ def validate_group(cls, value) -> Union[int, None]:
 
         return value
 
+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, allow string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                values = [int(x) for x in value.split("x")]
+            except Exception:
+                values = []
+            if len(values) != 2:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. "
+                    "Must be a list of two ints [rows, cols]."
+                )
+            return values
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. "
+                    "Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. "
+            "Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):
@@ -277,14 +305,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
 
         # infer observer w.r.t. dynamic
         if dynamic:
-            if strategy not in (
+            supported_strategies = (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
                 QuantizationStrategy.TENSOR_GROUP,
-            ):
+                QuantizationStrategy.GROUP,
+            )
+            if strategy not in supported_strategies:
                 raise ValueError(
-                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
-                    "must be used for dynamic quantization",
+                    f"One of {supported_strategies} must be used for dynamic quantization"
                 )
 
             if (
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 753afc6c..1b3b2f87 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -243,6 +243,29 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )
 
+# Block-wise FP8 (DeepSeek-V3-style quantization):
+# static 128x128 per-block weights and
+# dynamic per-token-group activations
+FP8_BLOCK = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[128, 128],
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=True,
+        observer=None,
+        group_size=128,
+    ),
+)
+
 PRESET_SCHEMES = {
     # Unquantized (no-op)
     "UNQUANTIZED": UNQUANTIZED,
@@ -257,6 +280,7 @@ def is_preset_scheme(name: str) -> bool:
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
 }
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 5d855f75..c3c2899d 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp(
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
-    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    elif args.strategy in (
+        QuantizationStrategy.TENSOR_GROUP,
+        QuantizationStrategy.GROUP,
+    ):
         if len(value.shape) > 2:
             value = value.squeeze(0)
 
@@ -187,9 +190,16 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
-            "Dynamic quantization is only supported for ",
-            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
+            "Dynamic quantization is only supported for "
+            f"{supported_strategies}"
         )
 
     if not reduce_dims:
diff --git a/tests/test_examples/test_bitmask_compression_ipynb.py b/tests/test_examples/test_bitmask_compression_ipynb.py
index a3ee9d2a..12f196f6 100644
--- a/tests/test_examples/test_bitmask_compression_ipynb.py
+++ b/tests/test_examples/test_bitmask_compression_ipynb.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import nbformat
 import pytest
+
+
+nbformat = pytest.importorskip("nbformat")
 from nbconvert.preprocessors import ExecutePreprocessor
diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py
index deb6cc30..a6b20f87 100644
--- a/tests/test_quantization/lifecycle/test_forward.py
+++ b/tests/test_quantization/lifecycle/test_forward.py
@@ -13,9 +13,12 @@
 # limitations under the License.
 
 
+import math
+
 import pytest
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
+    _process_quantization,
     dequantize,
     forward_quantize,
     quantize,
@@ -29,6 +32,7 @@
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.utils.helpers import calculate_range
 from torch.nn import Linear
 
 
@@ -203,3 +207,49 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx):
         dtype=None,
         g_idx=g_idx,
     )
+
+
+def test_process_quantization_block_static():
+    """
+    Static block quantization (QuantizationStrategy.BLOCK) should split a 2D tensor
+    into blocks, quantize each block, and reassemble without changing shape.
+    """
+    rows, cols = 8, 8
+    bh, bw = 2, 4
+    x = torch.randn(rows, cols)
+    args = QuantizationArgs(
+        num_bits=8,
+        type="float",
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[bh, bw],
+    )
+    num_rb = math.ceil(rows / bh)
+    num_cb = math.ceil(cols / bw)
+    scale = torch.rand(num_rb, num_cb) + 0.1
+    zp = torch.zeros_like(scale)
+    q_min, q_max = calculate_range(args, x.device)
+    out = _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zp,
+        args=args,
+        do_quantize=True,
+        do_dequantize=False,
+        dtype=None,
+        global_scale=None,
+    )
+    assert out.shape == x.shape
+    # full fake-quantize roundtrip
+    out2 = _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zp,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+        dtype=None,
+        global_scale=None,
+    )
+    assert out2.shape == x.shape
diff --git a/tests/test_quantization/test_quant_args.py b/tests/test_quantization/test_quant_args.py
index 3a42a626..bec01f11 100644
--- a/tests/test_quantization/test_quant_args.py
+++ b/tests/test_quantization/test_quant_args.py
@@ -59,7 +59,7 @@ def test_block():
     block = QuantizationArgs(**kwargs)
 
     assert block.strategy == QuantizationStrategy.BLOCK
-    assert block.block_structure == kwargs["block_structure"]
+    assert block.block_structure == [2, 4]
 
 
 def test_infer_strategy():
diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py
index 4cdc0c8a..2c6b1224 100644
--- a/tests/test_quantization/test_utils/test_helpers.py
+++ b/tests/test_quantization/test_utils/test_helpers.py
@@ -20,7 +20,11 @@
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
+from compressed_tensors.quantization.utils import (
+    calculate_qparams,
+    compute_dynamic_scales_and_zp,
+    generate_gparam,
+)
 
 
 @pytest.mark.parametrize(
@@ -73,3 +77,26 @@ def test_fused_global_scales():
     assert max_tensor_value.item() == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
     )
+
+
+@pytest.mark.parametrize(
+    "shape,group_size,exp_shape",
+    [
+        # Only batch size 1 is supported for dynamic GROUP quantization
+        ((1, 4, 8), 4, torch.Size([4, 2])),
+    ],
+)
+def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):
+    """
+    Dynamic group quantization should reduce activations in groups, producing
+    scales and zero points of shape [num_tokens, num_groups].
+    """
+    value = torch.randn(*shape)
+    args = QuantizationArgs(
+        strategy=QuantizationStrategy.GROUP,
+        group_size=group_size,
+        dynamic=True,
+    )
+    scale, zp = compute_dynamic_scales_and_zp(value, args, module=torch.nn.Module())
+    assert scale.shape == exp_shape
+    assert zp.shape == exp_shape
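
Reviewer note: a minimal standalone sketch of the block round trip that `_process_quantization` implements above. `fake_quant_block` is a hypothetical helper, and the round/clamp to the FP8-E4M3 range (±448) stands in for the cast that `_quantize` actually performs; only the reshape/transpose bookkeeping mirrors the PR.

```python
import torch

def fake_quant_block(x: torch.Tensor, scale: torch.Tensor, block: tuple) -> torch.Tensor:
    # view (rows, cols) as (num_row_blocks, num_col_blocks, bh, bw)
    bh, bw = block
    rows, cols = x.shape
    xb = x.reshape(rows // bh, bh, cols // bw, bw).transpose(1, 2)
    # one scale per block, broadcast over the trailing block dims
    sb = scale.unsqueeze(-1).unsqueeze(-1)
    q = torch.clamp(torch.round(xb / sb), -448.0, 448.0)  # stand-in for the FP8 cast
    dq = q * sb
    # undo the block view; the tensor shape is unchanged end to end
    return dq.transpose(1, 2).reshape(rows, cols)

x = torch.randn(8, 8)
scale = torch.rand(4, 2) + 0.1  # 2x4 blocks over an 8x8 tensor -> 4x2 scale grid
assert fake_quant_block(x, scale, (2, 4)).shape == x.shape
```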
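The `block_structure` validator keeps the legacy string form working, and the new preset pairs static 128x128 block weights with dynamic per-token-group activations. A quick sanity check, assuming `QuantizationArgs`/`QuantizationStrategy` are re-exported from `compressed_tensors.quantization` (as the tests above import them) and `PRESET_SCHEMES` from `quant_scheme`:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

# legacy "RxC" string is normalized to a list of two ints
args = QuantizationArgs(
    num_bits=8,
    type="float",
    strategy=QuantizationStrategy.BLOCK,
    block_structure="128x128",
)
assert args.block_structure == [128, 128]

# the new preset combines static block weights with dynamic group activations
fp8_block = PRESET_SCHEMES["FP8_BLOCK"]
assert fp8_block["weights"].block_structure == [128, 128]
assert fp8_block["input_activations"].dynamic
```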
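For the dynamic GROUP path added to `compute_dynamic_scales_and_zp`, a sketch of the reduction shape, mirroring the new test; the absmax/448 symmetric-scale formula is illustrative, not the library's exact observer math:

```python
import torch

# batch-1 activations: (1, tokens, hidden) -> (tokens, hidden), as squeeze(0) does
value = torch.randn(1, 4, 8).squeeze(0)
group_size = 4

# split the hidden dim into contiguous groups and reduce within each group
groups = value.reshape(value.shape[0], -1, group_size)  # (tokens, num_groups, group_size)
absmax = groups.abs().amax(dim=-1)                      # (tokens, num_groups)
scale = absmax / 448.0                                  # symmetric FP8 E4M3 scale
assert scale.shape == (4, 2)                            # matches exp_shape in the test
```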