From 558870ccee8080f67fbdbacd290e223d0571a36a Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 30 Jun 2025 19:26:37 +0000 Subject: [PATCH 1/7] Support DeepSeekV3-style block FP8 quantization Signed-off-by: mgoin --- .../quantization/lifecycle/forward.py | 46 ++++++++++++++++- .../quantization/quant_args.py | 39 ++++++++++++--- .../quantization/quant_scheme.py | 24 +++++++++ .../quantization/utils/helpers.py | 9 +++- .../test_bitmask_compression_ipynb.py | 2 +- .../lifecycle/test_forward.py | 50 +++++++++++++++++++ tests/test_quantization/test_quant_args.py | 2 +- .../test_utils/test_helpers.py | 24 ++++++++- 8 files changed, 182 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index b4ca3a82..05a890a9 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import wraps +import math from math import ceil from typing import Optional @@ -189,7 +190,50 @@ def _process_quantization( q_min, q_max = calculate_range(args, x.device) group_size = args.group_size - if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + # blockwise FP8: quantize per 2D block, supports block_structure for static block quant + if args.strategy == QuantizationStrategy.BLOCK: + # x: [*, H, W] or [H, W] + bs = args.block_structure + if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(dim, int) for dim in bs)): + raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].") + b_h, b_w = bs + rows, cols = x.shape[-2], x.shape[-1] + # ensure exact division (assumes shapes match spec) + n_rb = math.ceil(rows / b_h) + n_cb = math.ceil(cols / b_w) + # reshape into blocks + x_blocks = x.reshape( + n_rb, + b_h, + n_cb, + b_w, + ).transpose(1, 2) + # expand scale/zero_point for blocks + sb = scale.unsqueeze(-1).unsqueeze(-1) + zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None + # quantize blocks + qb = _quantize( + x=x_blocks, + scale=sb, + zero_point=zb, + q_min=q_min, + q_max=q_max, + args=args, + dtype=dtype, + global_scale=global_scale, + ) if do_quantize else x_blocks + # dequantize if needed + if do_dequantize: + qb = _dequantize( + x_q=qb, + scale=sb, + zero_point=zb, + global_scale=global_scale, + ) + # restore original shape + out = qb.transpose(1, 2).reshape(rows, cols) + return out + elif args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): n_dims = x.shape if len(n_dims) > 2: x = x.squeeze(0) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index fdf34a28..53dbf88e 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -14,7 +14,7 @@ import warnings from enum import Enum -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from compressed_tensors.utils import Aliasable @@ -153,8 +153,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True): :param symmetric: whether or not quantization scale is symmetric about zero-point :param strategy: string id determining the scope of scale/zero-point to apply :param group_size: group length to use for the group strategy - :param block_structure: 2d block structure to use for the block strategy, 
must be
-        of the format "2x4", "8x16", etc.
+    :param block_structure: 2d block structure to use for the block strategy; must be
+        a list of two ints [rows, cols] like [128, 128].
     :param dynamic: set True to perform dynamic quantization - values will not be
         calibrated during calibration phase, instead during inference new quantization
         ranges will be observed with every sample. Defaults to False for static
@@ -169,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
+    block_structure: Optional[List[int]] = None
     dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
@@ -207,6 +207,28 @@ def validate_group(cls, value) -> Union[int, None]:
 
         return value
 
+    @field_validator("block_structure", mode="before")
+    def validate_block_structure(cls, value) -> Optional[List[int]]:
+        if value is None:
+            return value
+        # For backward compatibility, parse string format "2x4", "8x16", etc.
+        if isinstance(value, str):
+            try:
+                value = [int(x) for x in value.split("x")]
+            except Exception:
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+        if isinstance(value, (list, tuple)):
+            if len(value) != 2 or not all(isinstance(v, int) for v in value):
+                raise ValueError(
+                    f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+                )
+            return list(value)
+        raise ValueError(
+            f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+        )
+
     @field_validator("strategy", mode="before")
     def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
         if isinstance(value, str):
@@ -277,14 +299,15 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
 
         # infer observer w.r.t. 
dynamic if dynamic: - if strategy not in ( + supported_strategies = ( QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP, - ): + QuantizationStrategy.GROUP, + ) + if strategy not in supported_strategies: raise ValueError( - f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} " - "must be used for dynamic quantization", + f"One of {supported_strategies} must be used for dynamic quantization" ) if ( diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 753afc6c..a1bf7883 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -243,6 +243,29 @@ def is_preset_scheme(name: str) -> bool: ), ) +# Block‐wise FP8 (deepseekv3-style quantization): +# static 128x128 per‐block weights and +# dynamic per‐token‐group activations +FP8_BLOCK = dict( + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.BLOCK, + symmetric=True, + dynamic=False, + block_structure=[128, 128], + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.GROUP, + symmetric=True, + dynamic=True, + observer=None, + group_size=128, + ), +) + PRESET_SCHEMES = { # Unquantized (no-op) "UNQUANTIZED": UNQUANTIZED, @@ -257,6 +280,7 @@ def is_preset_scheme(name: str) -> bool: # Float weight and activation schemes "FP8": FP8, "FP8_DYNAMIC": FP8_DYNAMIC, + "FP8_BLOCK": FP8_BLOCK, "NVFP4A16": NVFP4A16, "NVFP4": NVFP4, } diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 5d855f75..5f6aff7d 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -171,7 +171,7 @@ def compute_dynamic_scales_and_zp( reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) elif args.strategy == QuantizationStrategy.TENSOR: reduce_dims = None - elif args.strategy == QuantizationStrategy.TENSOR_GROUP: + elif args.strategy in (QuantizationStrategy.TENSOR_GROUP, QuantizationStrategy.GROUP): if len(value.shape) > 2: value = value.squeeze(0) @@ -189,7 +189,12 @@ def compute_dynamic_scales_and_zp( else: raise ValueError( "Dynamic quantization is only supported for ", - f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}", + f"{( + QuantizationStrategy.TOKEN, + QuantizationStrategy.TENSOR, + QuantizationStrategy.TENSOR_GROUP, + QuantizationStrategy.GROUP, + )}", ) if not reduce_dims: diff --git a/tests/test_examples/test_bitmask_compression_ipynb.py b/tests/test_examples/test_bitmask_compression_ipynb.py index a3ee9d2a..fadd9e44 100644 --- a/tests/test_examples/test_bitmask_compression_ipynb.py +++ b/tests/test_examples/test_bitmask_compression_ipynb.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import nbformat
 import pytest
+nbformat = pytest.importorskip("nbformat")
 from nbconvert.preprocessors import ExecutePreprocessor
 
 
diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py
index deb6cc30..ef10b087 100644
--- a/tests/test_quantization/lifecycle/test_forward.py
+++ b/tests/test_quantization/lifecycle/test_forward.py
@@ -30,6 +30,10 @@
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from torch.nn import Linear
+import math
+
+from compressed_tensors.quantization.lifecycle.forward import _process_quantization
+from compressed_tensors.quantization.utils.helpers import calculate_range
 
 
 def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
@@ -203,3 +207,49 @@ def test_dequantize(num_bits, type, strategy, group_size, scale, zero_point, g_i
         dtype=None,
         g_idx=g_idx,
     )
+
+
+def test_process_quantization_block_static():
+    """
+    Static block quantization (QuantizationStrategy.BLOCK) should split a 2D tensor
+    into blocks, quantize each block, and reassemble without changing shape.
+    """
+    rows, cols = 8, 8
+    bh, bw = 2, 4
+    x = torch.randn(rows, cols)
+    args = QuantizationArgs(
+        num_bits=8,
+        type="float",
+        strategy=QuantizationStrategy.BLOCK,
+        symmetric=True,
+        dynamic=False,
+        block_structure=[bh, bw],
+    )
+    num_rb = math.ceil(rows / bh)
+    num_cb = math.ceil(cols / bw)
+    scale = torch.rand(num_rb, num_cb) + 0.1
+    zp = torch.zeros_like(scale)
+    q_min, q_max = calculate_range(args, x.device)
+    out = _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zp,
+        args=args,
+        do_quantize=True,
+        do_dequantize=False,
+        dtype=None,
+        global_scale=None,
+    )
+    assert out.shape == x.shape
+    # full fake-quantize roundtrip
+    out2 = _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zp,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+        dtype=None,
+        global_scale=None,
+    )
+    assert out2.shape == x.shape
diff --git a/tests/test_quantization/test_quant_args.py b/tests/test_quantization/test_quant_args.py
index 3a42a626..bec01f11 100644
--- a/tests/test_quantization/test_quant_args.py
+++ b/tests/test_quantization/test_quant_args.py
@@ -59,7 +59,7 @@ def test_block():
         block = QuantizationArgs(**kwargs)
 
         assert block.strategy == QuantizationStrategy.BLOCK
-        assert block.block_structure == kwargs["block_structure"]
+        assert block.block_structure == [2, 4]
 
 
 def test_infer_strategy():
diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py
index 4cdc0c8a..8c551f61 100644
--- a/tests/test_quantization/test_utils/test_helpers.py
+++ b/tests/test_quantization/test_utils/test_helpers.py
@@ -20,7 +20,7 @@
     QuantizationArgs,
     QuantizationStrategy,
 )
-from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
+from compressed_tensors.quantization.utils import calculate_qparams, compute_dynamic_scales_and_zp, generate_gparam
 
 
 @pytest.mark.parametrize(
@@ -73,3 +73,27 @@ def test_fused_global_scales():
     assert max_tensor_value.item() == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
     )
+
+@pytest.mark.parametrize(
+    "shape,group_size,exp_shape",
+    [
+        # Only batch size 1 is supported for dynamic GROUP quantization
+        ((1, 4, 8), 4, torch.Size([4, 2])),
+    ],
+)
+def test_compute_dynamic_scales_and_zp_group(shape, group_size, exp_shape):
+    """
+    Dynamic group quantization should reduce activations in groups, producing
+    scales and zero points of shape [num_tokens, num_groups].
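+    For example, a (1, 4, 8) input with group_size=4 is squeezed to (4, 8) and
+    split into 8 // 4 = 2 groups per token row, giving scale and zp of [4, 2].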
+ """ + value = torch.randn(*shape) + args = QuantizationArgs( + strategy=QuantizationStrategy.GROUP, + group_size=group_size, + dynamic=True, + ) + scale, zp = compute_dynamic_scales_and_zp(value, args, module=torch.nn.Module()) + assert scale.shape == exp_shape + assert zp.shape == exp_shape From 759be7afac68b71455f92bdf7ecfe15a8e873009 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 30 Jun 2025 19:39:22 +0000 Subject: [PATCH 2/7] Remove math Signed-off-by: mgoin --- src/compressed_tensors/quantization/lifecycle/forward.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 05a890a9..4e550c6e 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -13,7 +13,6 @@ # limitations under the License. from functools import wraps -import math from math import ceil from typing import Optional @@ -199,8 +198,8 @@ def _process_quantization( b_h, b_w = bs rows, cols = x.shape[-2], x.shape[-1] # ensure exact division (assumes shapes match spec) - n_rb = math.ceil(rows / b_h) - n_cb = math.ceil(cols / b_w) + n_rb = ceil(rows / b_h) + n_cb = ceil(cols / b_w) # reshape into blocks x_blocks = x.reshape( n_rb, From 010e903c3522e7cf3fd79222afaf595d331ea278 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 1 Jul 2025 00:44:56 +0000 Subject: [PATCH 3/7] Enforce divisible shapes Signed-off-by: mgoin --- .../quantization/lifecycle/forward.py | 68 ++++++++++++------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 4e550c6e..c84f0563 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -111,11 +111,15 @@ def dequantize( elif scale.ndim == 2: if scale.shape[1] == 1: args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL) - else: + elif scale.shape[0] == 1: group_size = int(x_q.shape[1] / scale.shape[1]) args = QuantizationArgs( strategy=QuantizationStrategy.GROUP, group_size=group_size ) + else: + args = QuantizationArgs( + strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape + ) else: raise ValueError( f"Could not infer a quantization strategy from scale with {scale.ndim} " @@ -195,43 +199,57 @@ def _process_quantization( bs = args.block_structure if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(dim, int) for dim in bs)): raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].") - b_h, b_w = bs + original_shape = x.shape rows, cols = x.shape[-2], x.shape[-1] - # ensure exact division (assumes shapes match spec) - n_rb = ceil(rows / b_h) - n_cb = ceil(cols / b_w) + block_height, block_width = bs + + # Ensure exact division (tensor dimensions must be divisible by block size) + if rows % block_height != 0: + raise ValueError( + f"Tensor height {rows} is not divisible by block_height {block_height}. " + f"Block quantization requires exact division." + ) + if cols % block_width != 0: + raise ValueError( + f"Tensor width {cols} is not divisible by block_width {block_width}. " + f"Block quantization requires exact division." 
+ ) + # reshape into blocks + num_rows_blocks = rows // block_height + num_cols_blocks = cols // block_width x_blocks = x.reshape( - n_rb, - b_h, - n_cb, - b_w, + num_rows_blocks, + block_height, + num_cols_blocks, + block_width, ).transpose(1, 2) + # expand scale/zero_point for blocks sb = scale.unsqueeze(-1).unsqueeze(-1) zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None - # quantize blocks - qb = _quantize( - x=x_blocks, - scale=sb, - zero_point=zb, - q_min=q_min, - q_max=q_max, - args=args, - dtype=dtype, - global_scale=global_scale, - ) if do_quantize else x_blocks - # dequantize if needed + if do_quantize: + # quantize blocks + x_blocks = _quantize( + x=x_blocks, + scale=sb, + zero_point=zb, + q_min=q_min, + q_max=q_max, + args=args, + dtype=dtype, + global_scale=global_scale, + ) if do_dequantize: - qb = _dequantize( - x_q=qb, + # dequantize blocks + x_blocks = _dequantize( + x_q=x_blocks, scale=sb, zero_point=zb, global_scale=global_scale, ) # restore original shape - out = qb.transpose(1, 2).reshape(rows, cols) - return out + output = x_blocks.transpose(1, 2).reshape(original_shape) elif args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): n_dims = x.shape if len(n_dims) > 2: From 7e642f774e879cd4b0835b4742f0c8f682c71550 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 1 Jul 2025 01:07:01 +0000 Subject: [PATCH 4/7] Format Signed-off-by: mgoin --- .../quantized_compressors/nvfp4_quantized.py | 1 + .../quantization/lifecycle/forward.py | 21 +++++++++++++------ .../quantization/quant_scheme.py | 2 +- .../quantization/utils/helpers.py | 5 ++++- .../test_bitmask_compression_ipynb.py | 2 ++ .../lifecycle/test_forward.py | 8 +++---- .../test_utils/test_helpers.py | 7 ++++++- 7 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py index 5f348e91..5d521fbc 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py @@ -154,6 +154,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 ) + # reference: : https://github.com/vllm-project/vllm/pull/16362 def unpack_fp4_from_uint8( a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16 diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index c84f0563..f73d5596 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -197,12 +197,18 @@ def _process_quantization( if args.strategy == QuantizationStrategy.BLOCK: # x: [*, H, W] or [H, W] bs = args.block_structure - if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(dim, int) for dim in bs)): - raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].") + if not ( + isinstance(bs, (list, tuple)) + and len(bs) == 2 + and all(isinstance(dim, int) for dim in bs) + ): + raise ValueError( + f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols]." 
+ ) original_shape = x.shape rows, cols = x.shape[-2], x.shape[-1] block_height, block_width = bs - + # Ensure exact division (tensor dimensions must be divisible by block size) if rows % block_height != 0: raise ValueError( @@ -214,7 +220,7 @@ def _process_quantization( f"Tensor width {cols} is not divisible by block_width {block_width}. " f"Block quantization requires exact division." ) - + # reshape into blocks num_rows_blocks = rows // block_height num_cols_blocks = cols // block_width @@ -224,7 +230,7 @@ def _process_quantization( num_cols_blocks, block_width, ).transpose(1, 2) - + # expand scale/zero_point for blocks sb = scale.unsqueeze(-1).unsqueeze(-1) zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None @@ -250,7 +256,10 @@ def _process_quantization( ) # restore original shape output = x_blocks.transpose(1, 2).reshape(original_shape) - elif args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): + elif args.strategy in ( + QuantizationStrategy.GROUP, + QuantizationStrategy.TENSOR_GROUP, + ): n_dims = x.shape if len(n_dims) > 2: x = x.squeeze(0) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a1bf7883..1b3b2f87 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -244,7 +244,7 @@ def is_preset_scheme(name: str) -> bool: ) # Block‐wise FP8 (deepseekv3-style quantization): -# static 128x128 per‐block weights and +# static 128x128 per‐block weights and # dynamic per‐token‐group activations FP8_BLOCK = dict( weights=QuantizationArgs( diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 5f6aff7d..9868e35e 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -171,7 +171,10 @@ def compute_dynamic_scales_and_zp( reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim) elif args.strategy == QuantizationStrategy.TENSOR: reduce_dims = None - elif args.strategy in (QuantizationStrategy.TENSOR_GROUP, QuantizationStrategy.GROUP): + elif args.strategy in ( + QuantizationStrategy.TENSOR_GROUP, + QuantizationStrategy.GROUP, + ): if len(value.shape) > 2: value = value.squeeze(0) diff --git a/tests/test_examples/test_bitmask_compression_ipynb.py b/tests/test_examples/test_bitmask_compression_ipynb.py index fadd9e44..12f196f6 100644 --- a/tests/test_examples/test_bitmask_compression_ipynb.py +++ b/tests/test_examples/test_bitmask_compression_ipynb.py @@ -13,6 +13,8 @@ # limitations under the License. import pytest + + nbformat = pytest.importorskip("nbformat") from nbconvert.preprocessors import ExecutePreprocessor diff --git a/tests/test_quantization/lifecycle/test_forward.py b/tests/test_quantization/lifecycle/test_forward.py index ef10b087..a6b20f87 100644 --- a/tests/test_quantization/lifecycle/test_forward.py +++ b/tests/test_quantization/lifecycle/test_forward.py @@ -13,9 +13,12 @@ # limitations under the License. 
+import math + import pytest import torch from compressed_tensors.quantization.lifecycle.forward import ( + _process_quantization, dequantize, forward_quantize, quantize, @@ -29,11 +32,8 @@ QuantizationStrategy, ) from compressed_tensors.quantization.quant_config import QuantizationStatus -from torch.nn import Linear -import math - -from compressed_tensors.quantization.lifecycle.forward import _process_quantization from compressed_tensors.quantization.utils.helpers import calculate_range +from torch.nn import Linear def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor: diff --git a/tests/test_quantization/test_utils/test_helpers.py b/tests/test_quantization/test_utils/test_helpers.py index 8c551f61..2c6b1224 100644 --- a/tests/test_quantization/test_utils/test_helpers.py +++ b/tests/test_quantization/test_utils/test_helpers.py @@ -20,7 +20,11 @@ QuantizationArgs, QuantizationStrategy, ) -from compressed_tensors.quantization.utils import calculate_qparams, compute_dynamic_scales_and_zp, generate_gparam +from compressed_tensors.quantization.utils import ( + calculate_qparams, + compute_dynamic_scales_and_zp, + generate_gparam, +) @pytest.mark.parametrize( @@ -74,6 +78,7 @@ def test_fused_global_scales(): FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001 ) + @pytest.mark.parametrize( "shape,group_size,exp_shape", [ From bc5bea30ff8b5a001ae35404a469019c2c7a7c5f Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 1 Jul 2025 01:08:26 +0000 Subject: [PATCH 5/7] Remove validation Signed-off-by: mgoin --- .../quantization/lifecycle/forward.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index f73d5596..dc775c28 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -195,19 +195,9 @@ def _process_quantization( # blockwise FP8: quantize per 2D block, supports block_structure for static block quant if args.strategy == QuantizationStrategy.BLOCK: - # x: [*, H, W] or [H, W] - bs = args.block_structure - if not ( - isinstance(bs, (list, tuple)) - and len(bs) == 2 - and all(isinstance(dim, int) for dim in bs) - ): - raise ValueError( - f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols]." 
-            )
         original_shape = x.shape
         rows, cols = x.shape[-2], x.shape[-1]
-        block_height, block_width = bs
+        block_height, block_width = args.block_structure
 
         # Ensure exact division (tensor dimensions must be divisible by block size)
         if rows % block_height != 0:

From cda798e88b02d28d5876036161a3be6f11a1da5a Mon Sep 17 00:00:00 2001
From: mgoin
Date: Tue, 1 Jul 2025 01:37:51 +0000
Subject: [PATCH 6/7] Fix string

Signed-off-by: mgoin
---
 .../quantization/utils/helpers.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 9868e35e..c3c2899d 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -190,14 +190,15 @@ def compute_dynamic_scales_and_zp(
             ),
         )
     else:
+        supported_strategies = (
+            QuantizationStrategy.TOKEN,
+            QuantizationStrategy.TENSOR,
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        )
         raise ValueError(
             "Dynamic quantization is only supported for ",
-            f"{(
-                QuantizationStrategy.TOKEN,
-                QuantizationStrategy.TENSOR,
-                QuantizationStrategy.TENSOR_GROUP,
-                QuantizationStrategy.GROUP,
-            )}",
+            f"{supported_strategies}",
         )
 
     if not reduce_dims:

From da55b66663c752cd5249a621956a4879aeb397ca Mon Sep 17 00:00:00 2001
From: shanjiaz
Date: Thu, 10 Jul 2025 21:14:35 -0400
Subject: [PATCH 7/7] fix failed tests

Signed-off-by: shanjiaz
---
 .../quantization/lifecycle/forward.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index dc775c28..01ac770c 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,15 +111,21 @@ def dequantize(
     elif scale.ndim == 2:
         if scale.shape[1] == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-        elif scale.shape[0] == 1:
+        elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
             group_size = int(x_q.shape[1] / scale.shape[1])
             args = QuantizationArgs(
                 strategy=QuantizationStrategy.GROUP, group_size=group_size
             )
         else:
+            # scale carries one entry per 2D block, so the block shape is the
+            # ratio of the quantized tensor's dims to the scale grid's dims
             args = QuantizationArgs(
-                strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+                strategy=QuantizationStrategy.BLOCK,
+                block_structure=[
+                    x_q.shape[0] // scale.shape[0],
+                    x_q.shape[1] // scale.shape[1],
+                ],
             )
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
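
For reference, a minimal end-to-end sketch of the block path this series adds,
written only against APIs the patches themselves exercise (PRESET_SCHEMES,
QuantizationArgs, and the internal _process_quantization helper). The 256x256
weight and the hand-rolled scales below are illustrative stand-ins for
calibrated values, not code from the patches:

    import torch

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationStrategy,
    )
    from compressed_tensors.quantization.lifecycle.forward import (
        _process_quantization,
    )
    from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

    # the preset pairs static [128, 128] weight blocks with dynamic
    # per-token-group (size 128) activations
    scheme = PRESET_SCHEMES["FP8_BLOCK"]
    assert scheme["weights"].block_structure == [128, 128]

    # a 256x256 weight covered by a 2x2 grid of per-block scales
    x = torch.randn(256, 256)
    args = QuantizationArgs(
        num_bits=8,
        type="float",
        strategy=QuantizationStrategy.BLOCK,
        symmetric=True,
        dynamic=False,
        block_structure=[128, 128],
    )
    scale = torch.rand(2, 2) + 0.1  # stand-in for calibrated per-block scales
    zero_point = torch.zeros_like(scale)

    # fake-quantize round trip through the new BLOCK branch
    x_fq = _process_quantization(
        x=x,
        scale=scale,
        zero_point=zero_point,
        args=args,
        do_quantize=True,
        do_dequantize=True,
        dtype=None,
        global_scale=None,
    )
    assert x_fq.shape == x.shape

Calling dequantize(x_q, scale) without explicit args then relies on the
patch-7 inference, which recovers the [128, 128] block structure from the
ratio of the tensor's dims to the 2x2 scale grid.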