diff --git a/docs/schemes.md b/docs/schemes.md
index 19ff746e4..b800ec8d9 100644
--- a/docs/schemes.md
+++ b/docs/schemes.md
@@ -19,6 +19,9 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la
 - Useful for speed ups in high QPS regimes or offline serving on vLLM.
 - Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell).
 
+### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py)
+- Uses block-wise quantization to compress weights to FP8 in blocks (commonly 128x128 tiles) and dynamic per-token-group (group size 128) quantization for activations. Does not require a calibration dataset. Activation quantization is carried out during inference on vLLM.
+
 ## Sparsification
 
 Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:
diff --git a/examples/quantization_w8a8_fp8/fp8_block_example.py b/examples/quantization_w8a8_fp8/fp8_block_example.py
new file mode 100644
index 000000000..fd496fe15
--- /dev/null
+++ b/examples/quantization_w8a8_fp8/fp8_block_example.py
@@ -0,0 +1,33 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+MODEL_ID = "Qwen/Qwen3-0.6B"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+#   * quantize the weights to fp8 with block-wise quantization
+#   * quantize the activations to fp8 with dynamic per-token-group quantization
+recipe = QuantizationModifier(targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"])
+
+# Apply quantization.
+oneshot(model=model, recipe=recipe)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=20)
+print(tokenizer.decode(output[0]))
+print("==========================================")
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-BLOCK"
+model.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index b10a4cb31..8399c67c0 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -124,8 +124,22 @@ def call_observer(
         updated_scale, updated_zero_point = observer(
             value, g_idx=g_idx, global_scale=global_scale
         )
-        update_parameter_data(module, updated_scale, f"{base_name}_scale")
-        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+        # register or update scale & zero_point parameters (supports block shapes)
+        scale_name = f"{base_name}_scale"
+        zp_name = f"{base_name}_zero_point"
+        for name, value in [
+            (scale_name, updated_scale),
+            (zp_name, updated_zero_point),
+        ]:
+            if (
+                not hasattr(module, name)
+                or getattr(module, name).shape != value.shape
+            ):
+                if hasattr(module, name):
+                    delattr(module, name)
+                module.register_parameter(name, torch.nn.Parameter(value.clone()))
+            else:
+                update_parameter_data(module, value, name)
 
 
 def update_weight_global_scale(module: Module):
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index b7169d1d6..60bfc6d51 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -63,12 +63,17 @@ def calculate_qparams(
         self,
         observed: Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+        global_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
         :param reduce_dims: optional tuple of dimensions to reduce along,
             returned scale and zero point will be shaped (1,) along the
             reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
+        :param global_scale: optional scale to further scale local quantization scales
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -193,12 +198,57 @@ def get_qparams(
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-                # TODO (#1475) add support for block-wise quantization
-                raise NotImplementedError(
-                    "Block-wise quantization is not yet supported, "
-                    "consider group-wise quantization instead. More info at "
-                    "https://github.com/vllm-project/llm-compressor/issues/1475"
+                # Block-wise quantization: one scale/zero_point per block of shape
+                # [block_rows, block_cols]
+                rows, cols = observed.shape[:2]
+                bs = self.quantization_args.block_structure
+                if not (
+                    isinstance(bs, (list, tuple))
+                    and len(bs) == 2
+                    and all(isinstance(x, int) for x in bs)
+                ):
+                    raise ValueError(
+                        f"Invalid block_structure '{bs}'. "
+                        "Must be a list of two ints [rows, cols]."
+                    )
+                block_rows, block_cols = bs
+
+                # Enforce exact division (dimensions must be divisible by block size)
+                if rows % block_rows != 0:
+                    raise ValueError(
+                        f"Tensor height {rows} is not divisible by block_rows "
+                        f"{block_rows}. Block quantization requires exact division."
+                    )
+                if cols % block_cols != 0:
+                    raise ValueError(
+                        f"Tensor width {cols} is not divisible by block_cols "
+                        f"{block_cols}. Block quantization requires exact division."
+                    )
+
+                num_br = rows // block_rows
+                num_bc = cols // block_cols
+                # allocate per-block scale and zero_point
+                self._scale = torch.empty(
+                    (num_br, num_bc), dtype=observed.dtype, device=observed.device
+                )
+                self._zero_point = torch.empty(
+                    (num_br, num_bc), dtype=observed.dtype, device=observed.device
                 )
+                # compute qparams for each block
+                for i in range(num_br):
+                    r0 = i * block_rows
+                    r1 = (i + 1) * block_rows
+                    for j in range(num_bc):
+                        c0 = j * block_cols
+                        c1 = (j + 1) * block_cols
+                        # reduce across both dims to get one scale and zp per block
+                        scale_bp, zp_bp = self.calculate_qparams(
+                            observed[r0:r1, c0:c1],
+                            reduce_dims=(0, 1),
+                            tensor_id=i * num_bc + j,
+                        )
+                        self._scale[i, j] = scale_bp
+                        self._zero_point[i, j] = zp_bp
 
         return self._scale, self._zero_point
diff --git a/tests/llmcompressor/modifiers/quantization/test_base.py b/tests/llmcompressor/modifiers/quantization/test_base.py
index b95ee9c1c..4842586e3 100644
--- a/tests/llmcompressor/modifiers/quantization/test_base.py
+++ b/tests/llmcompressor/modifiers/quantization/test_base.py
@@ -35,6 +35,34 @@ def q_config_kwargs(config_0, config_1):
     )
 
 
+@pytest.fixture
+def block_q_config_kwargs():
+    return dict(
+        config_groups=dict(
+            group_block=dict(
+                targets=["Linear"],
+                input_activations=dict(
+                    num_bits=8, symmetric=True, strategy="group", group_size=128
+                ),
+                weights=dict(
+                    num_bits=8,
+                    symmetric=True,
+                    strategy="block",
+                    block_structure=[128, 128],
+                ),
+            ),
+        )
+    )
+
+
+def test_block_strategy_parsing(block_q_config_kwargs):
+    modifier = GPTQModifier(**block_q_config_kwargs)
+    resolved = modifier.resolve_quantization_config()
+    w_scheme = resolved.config_groups["group_block"].weights
+    assert w_scheme.strategy == "block"
+    assert w_scheme.block_structure == [128, 128]
+
+
 @pytest.mark.parametrize(
     "has_actorder,actorder,config_0,config_1,expected_0,expected_1",
     [
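
Below is a minimal, self-contained sketch (not part of the patch above) of the numerics the W8A8-FP8_BLOCK scheme produces: one weight scale per [128, 128] block computed from the block's absolute maximum, mirroring the per-block calculate_qparams calls added in observers/base.py, plus one dynamic activation scale per token per group of 128 channels. It uses plain PyTorch only; the tensor shapes and the 448.0 E4M3 max value are illustrative assumptions, not llm-compressor or vLLM APIs.

# Sketch only: plain PyTorch, requires a build with float8_e4m3fn support (>= 2.1).
# Shapes and FP8_MAX are illustrative assumptions, not library constants.
import torch

FP8_MAX = 448.0  # largest finite magnitude representable by torch.float8_e4m3fn
BLOCK = 128

# --- block-wise weight quantization: one scale per 128x128 tile -------------
weight = torch.randn(512, 1024)  # both dims divisible by the block size
rows, cols = weight.shape
num_br, num_bc = rows // BLOCK, cols // BLOCK

# view as (num_br, BLOCK, num_bc, BLOCK) so tile (i, j) covers
# rows i*128:(i+1)*128 and cols j*128:(j+1)*128, as in the observer loop
tiles = weight.reshape(num_br, BLOCK, num_bc, BLOCK)
scales = tiles.abs().amax(dim=(1, 3)) / FP8_MAX  # shape (num_br, num_bc)
scales = scales.clamp(min=1e-12)                 # guard all-zero tiles

# quantize-dequantize each tile with its own scale and measure the error
q = (tiles / scales[:, None, :, None]).clamp(-FP8_MAX, FP8_MAX)
q = q.to(torch.float8_e4m3fn)
deq = (q.to(weight.dtype) * scales[:, None, :, None]).reshape(rows, cols)
print("per-block weight scale shape:", tuple(scales.shape))  # (4, 8)
print("max reconstruction error:", (weight - deq).abs().max().item())

# --- dynamic per-token-group activation scales (group size 128) -------------
acts = torch.randn(16, 1024)                     # (tokens, hidden)
groups = acts.reshape(acts.shape[0], -1, BLOCK)  # (tokens, num_groups, 128)
act_scales = groups.abs().amax(dim=-1) / FP8_MAX  # one scale per token per group
print("per-token-group activation scale shape:", tuple(act_scales.shape))  # (16, 8)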