diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index 56dd28449..f6b78874c 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -58,13 +58,21 @@ def calculate_qparams(
         self,
         observed: Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+        global_scale: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :return: tuple of scale and zero point derived from the observed tensor
+        """Calculate quantization parameters for the observed tensor.
+
+        Args:
+            observed: Tensor to calculate quantization parameters for
+            reduce_dims: Optional tuple of dimensions to reduce along.
+                Returned scale and zero point will be shaped (1,) along the
+                reduced dimensions
+            tensor_id: Optional identifier for the tensor/block being quantized
+            global_scale: Optional scale to further scale local quantization scales
+
+        Returns:
+            Tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
 
@@ -79,16 +87,15 @@ def get_qparams(
         g_idx: Optional[Tensor] = None,
         global_scale: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Convenience function to wrap overwritten calculate_qparams
-        adds support to make observed tensor optional and support for tracking latest
-        calculated scale and zero point
+        """Get quantization parameters for the observed tensor.
 
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
-        :param g_idx: optional mapping from column index to group index
-        :param global_scale: optional scale to further scale local quantization scales
-        :return: tuple of scale and zero point based on last observed value
+        Args:
+            observed: Optional tensor to calculate quantization parameters from
+            g_idx: Optional mapping from column index to group index
+            global_scale: Optional scale to further scale local quantization scales
+
+        Returns:
+            Tuple of scale and zero point based on last observed value
         """
         if observed is not None:
             group_size = self.quantization_args.group_size
@@ -165,11 +172,62 @@ def get_qparams(
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-                # TODO (#1475) add support for block-wise quantization
-                raise NotImplementedError(
-                    "Block-wise quantization is not yet supported, "
-                    "consider group-wise quantization instead. More info at "
-                    "https://github.com/vllm-project/llm-compressor/issues/1475"
+                # Get block size from quantization arguments
+                block_size = self.quantization_args.block_size
+                if block_size is None:
+                    raise ValueError(
+                        "Block size must be specified for block-wise quantization"
+                    )
+
+                # Get tensor dimensions
+                rows, columns = observed.shape[0], observed.shape[1]
+
+                # Calculate number of blocks in each dimension
+                num_row_blocks = (rows + block_size - 1) // block_size
+                num_col_blocks = (columns + block_size - 1) // block_size
+
+                # Create tensors to store scales and zero points for each block
+                self._scale = torch.empty(
+                    (num_row_blocks, num_col_blocks),
+                    dtype=observed.dtype,
+                    device=observed.device,
+                )
+
+                if is_fp4(quantization_args=self.quantization_args):
+                    zp_dtype = FP8_E4M3_DATA.dtype
+                else:
+                    zp_dtype = self.quantization_args.pytorch_dtype()
+
+                self._zero_point = torch.empty(
+                    (num_row_blocks, num_col_blocks),
+                    dtype=zp_dtype,
+                    device=observed.device,
+                )
+
+                # Process each block
+                for i in range(num_row_blocks):
+                    row_start = i * block_size
+                    row_end = min(row_start + block_size, rows)
+
+                    for j in range(num_col_blocks):
+                        col_start = j * block_size
+                        col_end = min(col_start + block_size, columns)
+
+                        # Extract the block
+                        block = observed[row_start:row_end, col_start:col_end]
+
+                        # Calculate scale and zero point for this block
+                        scale, zero_point = self.calculate_qparams(
+                            block, tensor_id=(i, j), global_scale=global_scale
+                        )
+
+                        # Store the results
+                        self._scale[i, j] = scale.item()
+                        self._zero_point[i, j] = zero_point.item()
+            else:
+                raise ValueError(
+                    f"Unsupported quantization strategy: "
+                    f"{self.quantization_args.strategy}"
                 )
 
         return self._scale, self._zero_point
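
For reference, below is a minimal standalone sketch (not part of the diff) of the block partitioning that the new BLOCK branch performs: a weight of shape (rows, columns) is split into a ceil-divided grid of blocks and one scale is computed per block. It assumes a plain symmetric int8 min-max rule for the per-block scale purely for illustration; in the observer itself the scale and zero point come from calculate_qparams and depend on quantization_args. The helper name blockwise_minmax_scales is hypothetical.

# Standalone sketch of the block-wise partitioning above, assuming a simple
# symmetric int8 min-max scheme (the real per-block scale/zero-point math is
# delegated to the observer's calculate_qparams and depends on quantization_args).
import torch


def blockwise_minmax_scales(weight: torch.Tensor, block_size: int) -> torch.Tensor:
    """Return a (num_row_blocks, num_col_blocks) tensor of per-block scales."""
    rows, columns = weight.shape
    # Ceiling division, same as (rows + block_size - 1) // block_size in the diff
    num_row_blocks = (rows + block_size - 1) // block_size
    num_col_blocks = (columns + block_size - 1) // block_size

    scales = torch.empty((num_row_blocks, num_col_blocks), dtype=weight.dtype)
    for i in range(num_row_blocks):
        row_start, row_end = i * block_size, min((i + 1) * block_size, rows)
        for j in range(num_col_blocks):
            col_start, col_end = j * block_size, min((j + 1) * block_size, columns)
            block = weight[row_start:row_end, col_start:col_end]
            # Symmetric int8: map the largest absolute value in the block to 127
            scales[i, j] = block.abs().max() / 127.0
    return scales


# Example: a 10x20 weight with block_size=8 yields a 2x3 grid of scales,
# with the edge blocks smaller than 8x8 (same ceil-division behavior as the diff).
scales = blockwise_minmax_scales(torch.randn(10, 20), block_size=8)
print(scales.shape)  # torch.Size([2, 3])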