From 451219a62f1b9c9903502efb405c36e25899fe98 Mon Sep 17 00:00:00 2001 From: mgoin Date: Mon, 30 Jun 2025 19:51:04 +0000 Subject: [PATCH 1/5] Support DeepSeekV3-style block FP8 quantization Signed-off-by: mgoin --- docs/schemes.md | 3 ++ .../fp8_block_example.py | 35 +++++++++++++++++++ .../modifiers/quantization/calibration.py | 21 +++++++++-- src/llmcompressor/observers/base.py | 30 ++++++++++++---- .../modifiers/quantization/test_base.py | 26 ++++++++++++++ 5 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 examples/quantization_w8a8_fp8/fp8_block_example.py diff --git a/docs/schemes.md b/docs/schemes.md index 19ff746e4..bbfd1f855 100644 --- a/docs/schemes.md +++ b/docs/schemes.md @@ -19,6 +19,9 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la - Useful for speed ups in high QPS regimes or offline serving on vLLM. - Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell). +### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py) +- Uses block-wise quantization to compress weights to FP8 in (commonly 128×128 tiles), and dynamic per-token-group (128) quantization for activations. Does not require calibration dataset. Activation quantization is carried out during inference on vLLM. + ## Sparsification Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include: diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py new file mode 100644 index 000000000..b5d6ca1f9 --- /dev/null +++ b/examples/quantization_w8a8_fp8/fp8_block_example.py @@ -0,0 +1,35 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer + +from llmcompressor import oneshot +from llmcompressor.modifiers.quantization import QuantizationModifier + +MODEL_ID = "Qwen/Qwen3-0.6B" + +# Load model. +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, device_map="auto", torch_dtype="auto" +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +recipe = QuantizationModifier( + targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"] +) + +# Apply quantization. +oneshot(model=model, recipe=recipe) + +# Confirm generations of the quantized model look sane. +print("========== SAMPLE GENERATION ==============") +input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(tokenizer.decode(output[0])) +print("==========================================") + +# Save to disk in compressed-tensors format. 
+SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-BLOCK" +model.save_pretrained(SAVE_DIR) +tokenizer.save_pretrained(SAVE_DIR) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 63e1c2a24..c722ae5a7 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -109,8 +109,25 @@ def call_observer( updated_scale, updated_zero_point = observer( value, g_idx=g_idx, global_scale=global_scale ) - update_parameter_data(module, updated_scale, f"{base_name}_scale") - update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point") + # register or update scale & zero_point parameters (supports block shapes) + scale_name = f"{base_name}_scale" + zp_name = f"{base_name}_zero_point" + if not hasattr(module, scale_name) or getattr(module, scale_name).shape != updated_scale.shape: + if hasattr(module, scale_name): + delattr(module, scale_name) + module.register_parameter( + scale_name, torch.nn.Parameter(updated_scale.clone()) + ) + else: + update_parameter_data(module, updated_scale, scale_name) + if not hasattr(module, zp_name) or getattr(module, zp_name).shape != updated_zero_point.shape: + if hasattr(module, zp_name): + delattr(module, zp_name) + module.register_parameter( + zp_name, torch.nn.Parameter(updated_zero_point.clone()) + ) + else: + update_parameter_data(module, updated_zero_point, zp_name) def update_weight_global_scale(module: Module): diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index b7169d1d6..ce6840186 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -193,12 +193,30 @@ def get_qparams( ) elif self.quantization_args.strategy == QuantizationStrategy.BLOCK: - # TODO (#1475) add support for block-wise quantization - raise NotImplementedError( - "Block-wise quantization is not yet supported, " - "consider group-wise quantization instead. More info at " - "https://github.com/vllm-project/llm-compressor/issues/1475" - ) + # Block-wise quantization: one scale/zero_point per block of shape [block_rows, block_cols] + rows, cols = observed.shape[:2] + bs = self.quantization_args.block_structure + if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(x, int) for x in bs)): + raise ValueError(f"Invalid block_structure '{bs}'. 
Must be a list of two ints [rows, cols].") + block_rows, block_cols = bs + num_br = int(ceil(rows / block_rows)) + num_bc = int(ceil(cols / block_cols)) + # allocate per-block scale and zero_point + self._scale = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device) + self._zero_point = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device) + # compute qparams for each block + for i in range(num_br): + r0 = i * block_rows + r1 = min((i + 1) * block_rows, rows) + for j in range(num_bc): + c0 = j * block_cols + c1 = min((j + 1) * block_cols, cols) + # reduce across both dims to get one scale and zp per block + scale_bp, zp_bp = self.calculate_qparams( + observed[r0:r1, c0:c1], reduce_dims=(0, 1) + ) + self._scale[i, j] = scale_bp + self._zero_point[i, j] = zp_bp return self._scale, self._zero_point diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py index b95ee9c1c..8b8b29784 100644 --- a/tests/llmcompressor/modifiers/quantization/test_base.py +++ b/tests/llmcompressor/modifiers/quantization/test_base.py @@ -34,6 +34,32 @@ def q_config_kwargs(config_0, config_1): ) ) +@pytest.fixture +def block_q_config_kwargs(): + return dict( + config_groups=dict( + group_block=dict( + targets=["Linear"], + input_activations=dict( + num_bits=8, symmetric=True, strategy="group", group_size=128 + ), + weights=dict( + num_bits=8, + symmetric=True, + strategy="block", + block_structure=[128, 128], + ), + ), + ) + ) + +def test_block_strategy_parsing(block_q_config_kwargs): + modifier = GPTQModifier(**block_q_config_kwargs) + resolved = modifier.resolve_quantization_config() + w_scheme = resolved.config_groups["group_block"].weights + assert w_scheme.strategy == "block" + assert w_scheme.block_structure == [128, 128] + @pytest.mark.parametrize( "has_actorder,actorder,config_0,config_1,expected_0,expected_1", From a21ff60560bf045c8934f8ddc37608b8a24226e0 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 30 Jun 2025 14:12:38 -0600 Subject: [PATCH 2/5] Update docs/schemes.md Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/schemes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/schemes.md b/docs/schemes.md index bbfd1f855..b800ec8d9 100644 --- a/docs/schemes.md +++ b/docs/schemes.md @@ -20,7 +20,7 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la - Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell). ### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py) -- Uses block-wise quantization to compress weights to FP8 in (commonly 128×128 tiles), and dynamic per-token-group (128) quantization for activations. Does not require calibration dataset. Activation quantization is carried out during inference on vLLM. +- Uses block-wise quantization to compress weights to FP8 in blocks (commonly 128x128 tiles), and dynamic per-token-group (128) quantization for activations. Does not require calibration dataset. Activation quantization is carried out during inference on vLLM. ## Sparsification Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. 
Supported formats include: From 3bfbcd4181e731e49fe012005b3913ac24146a27 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 30 Jun 2025 14:12:50 -0600 Subject: [PATCH 3/5] Update examples/quantization_w8a8_fp8/fp8_block_example.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- examples/quantization_w8a8_fp8/fp8_block_example.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py index b5d6ca1f9..d525a8d8c 100644 --- a/examples/quantization_w8a8_fp8/fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/fp8_block_example.py @@ -13,8 +13,8 @@ # Configure the quantization algorithm and scheme. # In this case, we: -# * quantize the weights to fp8 with per channel via ptq -# * quantize the activations to fp8 with dynamic per token +# * quantize the weights to fp8 with block-wise quantization +# * quantize the activations to fp8 with dynamic per-token-group quantization recipe = QuantizationModifier( targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"] ) From 02615d90046a771889824acde15a4c26aa1b7b29 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 1 Jul 2025 00:44:17 +0000 Subject: [PATCH 4/5] Fix get_qparams by specifying tensor_id Signed-off-by: mgoin --- .../modifiers/quantization/calibration.py | 25 ++++++---------- src/llmcompressor/observers/base.py | 30 +++++++++++++++---- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 9b99ef2df..09131b671 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -127,22 +127,15 @@ def call_observer( # register or update scale & zero_point parameters (supports block shapes) scale_name = f"{base_name}_scale" zp_name = f"{base_name}_zero_point" - if not hasattr(module, scale_name) or getattr(module, scale_name).shape != updated_scale.shape: - if hasattr(module, scale_name): - delattr(module, scale_name) - module.register_parameter( - scale_name, torch.nn.Parameter(updated_scale.clone()) - ) - else: - update_parameter_data(module, updated_scale, scale_name) - if not hasattr(module, zp_name) or getattr(module, zp_name).shape != updated_zero_point.shape: - if hasattr(module, zp_name): - delattr(module, zp_name) - module.register_parameter( - zp_name, torch.nn.Parameter(updated_zero_point.clone()) - ) - else: - update_parameter_data(module, updated_zero_point, zp_name) + for name, value in [(scale_name, updated_scale), (zp_name, updated_zero_point)]: + if not hasattr(module, name) or getattr(module, name).shape != value.shape: + if hasattr(module, name): + delattr(module, name) + module.register_parameter( + name, torch.nn.Parameter(value.clone()) + ) + else: + update_parameter_data(module, value, name) def update_weight_global_scale(module: Module): diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index ce6840186..b08cc35c8 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -63,12 +63,17 @@ def calculate_qparams( self, observed: Tensor, reduce_dims: Optional[Tuple[int]] = None, + tensor_id: Optional[Any] = None, + global_scale: Optional[torch.Tensor] = None, ) -> Tuple[FloatTensor, IntTensor]: """ :param observed: observed tensor to calculate quantization parameters for :param 
reduce_dims: optional tuple of dimensions to reduce along, returned scale and zero point will be shaped (1,) along the reduced dimensions + :param tensor_id: Optional id if different ranges of observed tensors are + passed, useful for sharding tensors by group_size + :param global_scale: optional scale to further scale local quantization scales :return: tuple of scale and zero point derived from the observed tensor """ raise NotImplementedError(f"{self.__class__} must implement calculate_qparams") @@ -199,21 +204,36 @@ def get_qparams( if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(x, int) for x in bs)): raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].") block_rows, block_cols = bs - num_br = int(ceil(rows / block_rows)) - num_bc = int(ceil(cols / block_cols)) + + # Enforce exact division (tensor dimensions must be divisible by block size) + if rows % block_rows != 0: + raise ValueError( + f"Tensor height {rows} is not divisible by block_rows {block_rows}. " + f"Block quantization requires exact division." + ) + if cols % block_cols != 0: + raise ValueError( + f"Tensor width {cols} is not divisible by block_cols {block_cols}. " + f"Block quantization requires exact division." + ) + + num_br = rows // block_rows + num_bc = cols // block_cols # allocate per-block scale and zero_point self._scale = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device) self._zero_point = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device) # compute qparams for each block for i in range(num_br): r0 = i * block_rows - r1 = min((i + 1) * block_rows, rows) + r1 = (i + 1) * block_rows for j in range(num_bc): c0 = j * block_cols - c1 = min((j + 1) * block_cols, cols) + c1 = (j + 1) * block_cols # reduce across both dims to get one scale and zp per block scale_bp, zp_bp = self.calculate_qparams( - observed[r0:r1, c0:c1], reduce_dims=(0, 1) + observed[r0:r1, c0:c1], + reduce_dims=(0, 1), + tensor_id=i*num_bc+j, ) self._scale[i, j] = scale_bp self._zero_point[i, j] = zp_bp From 674b76e8286fd7d39b37f235c41867e7660d3873 Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 1 Jul 2025 01:04:58 +0000 Subject: [PATCH 5/5] Format Signed-off-by: mgoin --- .../fp8_block_example.py | 4 +- .../modifiers/quantization/calibration.py | 14 ++++--- src/llmcompressor/observers/base.py | 42 ++++++++++++------- .../modifiers/quantization/test_base.py | 2 + 4 files changed, 39 insertions(+), 23 deletions(-) diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py index d525a8d8c..fd496fe15 100644 --- a/examples/quantization_w8a8_fp8/fp8_block_example.py +++ b/examples/quantization_w8a8_fp8/fp8_block_example.py @@ -15,9 +15,7 @@ # In this case, we: # * quantize the weights to fp8 with block-wise quantization # * quantize the activations to fp8 with dynamic per-token-group quantization -recipe = QuantizationModifier( - targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"] -) +recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]) # Apply quantization. 
oneshot(model=model, recipe=recipe) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index 09131b671..8399c67c0 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -127,13 +127,17 @@ def call_observer( # register or update scale & zero_point parameters (supports block shapes) scale_name = f"{base_name}_scale" zp_name = f"{base_name}_zero_point" - for name, value in [(scale_name, updated_scale), (zp_name, updated_zero_point)]: - if not hasattr(module, name) or getattr(module, name).shape != value.shape: + for name, value in [ + (scale_name, updated_scale), + (zp_name, updated_zero_point), + ]: + if ( + not hasattr(module, name) + or getattr(module, name).shape != value.shape + ): if hasattr(module, name): delattr(module, name) - module.register_parameter( - name, torch.nn.Parameter(value.clone()) - ) + module.register_parameter(name, torch.nn.Parameter(value.clone())) else: update_parameter_data(module, value, name) diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py index b08cc35c8..60bfc6d51 100644 --- a/src/llmcompressor/observers/base.py +++ b/src/llmcompressor/observers/base.py @@ -198,30 +198,42 @@ def get_qparams( ) elif self.quantization_args.strategy == QuantizationStrategy.BLOCK: - # Block-wise quantization: one scale/zero_point per block of shape [block_rows, block_cols] + # Block-wise quantization: one scale/zero_point per block of shape + # [block_rows, block_cols] rows, cols = observed.shape[:2] bs = self.quantization_args.block_structure - if not (isinstance(bs, (list, tuple)) and len(bs) == 2 and all(isinstance(x, int) for x in bs)): - raise ValueError(f"Invalid block_structure '{bs}'. Must be a list of two ints [rows, cols].") + if not ( + isinstance(bs, (list, tuple)) + and len(bs) == 2 + and all(isinstance(x, int) for x in bs) + ): + raise ValueError( + f"Invalid block_structure '{bs}'. " + "Must be a list of two ints [rows, cols]." + ) block_rows, block_cols = bs - - # Enforce exact division (tensor dimensions must be divisible by block size) + + # Enforce exact division (dimensions must be divisible by block size) if rows % block_rows != 0: raise ValueError( - f"Tensor height {rows} is not divisible by block_rows {block_rows}. " - f"Block quantization requires exact division." + f"Tensor height {rows} is not divisible by block_rows " + f"{block_rows}. Block quantization requires exact division." ) if cols % block_cols != 0: raise ValueError( - f"Tensor width {cols} is not divisible by block_cols {block_cols}. " - f"Block quantization requires exact division." + f"Tensor width {cols} is not divisible by block_cols " + f"{block_cols}. Block quantization requires exact division." 
) - + num_br = rows // block_rows num_bc = cols // block_cols # allocate per-block scale and zero_point - self._scale = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device) - self._zero_point = torch.empty((num_br, num_bc), dtype=observed.dtype, device=observed.device) + self._scale = torch.empty( + (num_br, num_bc), dtype=observed.dtype, device=observed.device + ) + self._zero_point = torch.empty( + (num_br, num_bc), dtype=observed.dtype, device=observed.device + ) # compute qparams for each block for i in range(num_br): r0 = i * block_rows @@ -231,9 +243,9 @@ def get_qparams( c1 = (j + 1) * block_cols # reduce across both dims to get one scale and zp per block scale_bp, zp_bp = self.calculate_qparams( - observed[r0:r1, c0:c1], - reduce_dims=(0, 1), - tensor_id=i*num_bc+j, + observed[r0:r1, c0:c1], + reduce_dims=(0, 1), + tensor_id=i * num_bc + j, ) self._scale[i, j] = scale_bp self._zero_point[i, j] = zp_bp diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py index 8b8b29784..4842586e3 100644 --- a/tests/llmcompressor/modifiers/quantization/test_base.py +++ b/tests/llmcompressor/modifiers/quantization/test_base.py @@ -34,6 +34,7 @@ def q_config_kwargs(config_0, config_1): ) ) + @pytest.fixture def block_q_config_kwargs(): return dict( @@ -53,6 +54,7 @@ def block_q_config_kwargs(): ) ) + def test_block_strategy_parsing(block_q_config_kwargs): modifier = GPTQModifier(**block_q_config_kwargs) resolved = modifier.resolve_quantization_config()
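
The patches above add the FP8_BLOCK scheme end to end: the docs entry, an example recipe, block-shaped scale/zero-point registration in calibration.py, and a BLOCK branch in Observer.get_qparams that emits one scale per [block_rows, block_cols] tile. For reference, the arithmetic behind the scheme can be sketched in a few lines of plain PyTorch: per-block (128x128) symmetric FP8 scales for the weights, and dynamic per-token-group (128) scales for the activations (the part vLLM computes on the fly at inference time, which is why no calibration dataset is needed). This is a minimal standalone sketch, not code from llm-compressor, compressed-tensors, or vLLM; the helper names are hypothetical, and it assumes a PyTorch build with torch.float8_e4m3fn plus tensor dimensions that divide evenly by the block/group sizes, mirroring the exact-division check added to observers/base.py.

import torch

# Illustrative sketch only -- not part of the patch series above. Assumes a
# PyTorch build with torch.float8_e4m3fn and dims divisible by block/group sizes.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def quantize_weight_blockwise(weight: torch.Tensor, block_structure=(128, 128)):
    # One symmetric scale per [block_rows, block_cols] tile of the weight,
    # i.e. a scale tensor of shape (rows // block_rows, cols // block_cols) --
    # the same layout the observer's BLOCK branch produces.
    rows, cols = weight.shape
    br, bc = block_structure
    assert rows % br == 0 and cols % bc == 0, "dims must divide evenly by the block size"
    # view as (num_row_blocks, br, num_col_blocks, bc) so each tile is one slice
    tiles = weight.reshape(rows // br, br, cols // bc, bc)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True)
    scale = (amax / FP8_MAX).clamp(min=1e-12)
    q = (tiles / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scale.reshape(rows // br, cols // bc)


def quantize_activation_per_token_group(x: torch.Tensor, group_size=128):
    # Dynamic quantization: one scale per contiguous group of `group_size`
    # channels within each token (row) of the activation, computed at runtime.
    tokens, hidden = x.shape
    assert hidden % group_size == 0, "hidden size must divide evenly by group_size"
    groups = x.reshape(tokens, hidden // group_size, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True)
    scale = (amax / FP8_MAX).clamp(min=1e-12)
    q = (groups / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(tokens, hidden), scale.squeeze(-1)


if __name__ == "__main__":
    w = torch.randn(256, 512)          # e.g. a small Linear weight
    x = torch.randn(4, 512)            # 4 tokens of activations
    w_q, w_scales = quantize_weight_blockwise(w)
    x_q, x_scales = quantize_activation_per_token_group(x)
    print(w_q.dtype, w_scales.shape)   # torch.float8_e4m3fn torch.Size([2, 4])
    print(x_q.dtype, x_scales.shape)   # torch.float8_e4m3fn torch.Size([4, 4])

Dequantization is the elementwise inverse (multiply each tile or group by its scale); the per-tile weight scales above have the same (rows // 128, cols // 128) shape that the calibration change registers on the module as the weight_scale parameter.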