From 07e46b4d1cc9135c82165cca91ef90eaf88ca04a Mon Sep 17 00:00:00 2001
From: V-E-D
Date: Sun, 1 Jun 2025 17:53:11 +0530
Subject: [PATCH 1/2] block-wise quantization support

---
 src/llmcompressor/observers/base.py | 187 ++++++++++++++++++++++++----
 1 file changed, 166 insertions(+), 21 deletions(-)

diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index 56dd28449..ba84e26aa 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -58,15 +58,159 @@ def calculate_qparams(
         self,
         observed: Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+        global_scale: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
+        """Calculate quantization parameters for the observed tensor.
+
+        Args:
+            observed: Tensor to calculate quantization parameters for
+            reduce_dims: Optional tuple of dimensions to reduce along.
+                Returned scale and zero point will be shaped (1,) along the
+                reduced dimensions
+            tensor_id: Optional identifier for the tensor or tensor part being
+                quantized
+            global_scale: Optional scale to further scale local quantization scales
+
+        Returns:
+            Tuple of scale and zero point derived from the observed tensor
         """
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :return: tuple of scale and zero point derived from the observed tensor
-        """
-        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
+        if self.quantization_args.strategy == QuantizationStrategy.BLOCK:
+            # Parse block structure, but only on the top-level call
+            if tensor_id is None:
+                # Top-level call: handle the full block structure
+                block_structure = self.quantization_args.block_structure
+                if block_structure is None:
+                    raise ValueError(
+                        "block_structure must be specified for block-wise quantization"
+                    )
+
+                try:
+                    block_rows, block_cols = map(int, block_structure.split("x"))
+                except (ValueError, AttributeError):
+                    raise ValueError(
+                        f"Invalid block_structure: {block_structure}, "
+                        "expected format 'AxB' (e.g. '128x128')"
+                    )
+                if block_rows <= 0 or block_cols <= 0:
+                    raise ValueError(
+                        f"Block dimensions must be positive integers, "
+                        f"got {block_rows}x{block_cols}"
+                    )
+
+                if observed.ndim != 2:
+                    raise ValueError(
+                        f"Block-wise quantization expects 2D tensors, "
+                        f"got tensor with {observed.ndim} dimensions"
+                    )
+                rows, columns = observed.shape
+
+                num_row_blocks = ceil(rows / block_rows)
+                num_col_blocks = ceil(columns / block_cols)
+
+                # Check if dimensions are multiples of block size
+                if (
+                    num_row_blocks * block_rows != rows
+                    or num_col_blocks * block_cols != columns
+                ):
+                    logger.bind(log_once=True).warning(
+                        f"Tensor dimensions ({rows}x{columns}) are not divisible by "
+                        f"block_structure ({block_structure}); edge blocks will be smaller."
+                    )
+
+                # Create tensors to hold scales and zero points
+                scale_tensor = torch.zeros_like(observed)
+                zero_point_tensor = torch.zeros_like(observed, dtype=torch.int32)
+
+                # Process each block
+                for row_block in range(num_row_blocks):
+                    row_start = row_block * block_rows
+                    row_end = min(row_start + block_rows, rows)
+
+                    for col_block in range(num_col_blocks):
+                        col_start = col_block * block_cols
+                        col_end = min(col_start + block_cols, columns)
+
+                        # Get block data
+                        block_data = observed[row_start:row_end, col_start:col_end]
+
+                        # Calculate min/max for this block
+                        block_min = block_data.min()
+                        block_max = block_data.max()
+
+                        # Calculate scale and zero point for this block
+                        if block_max == block_min:
+                            block_scale = torch.tensor(1.0, device=observed.device)
+                            block_zero_point = torch.tensor(
+                                0, dtype=torch.int32, device=observed.device
+                            )
+                        else:
+                            # Default quantization range for INT8
+                            qmin, qmax = -128, 127
+                            block_scale = (block_max - block_min) / (qmax - qmin)
+                            block_zero_point = torch.round(
+                                qmin - block_min / block_scale
+                            ).to(torch.int32)
+
+                        # Extract scalar values if needed
+                        if hasattr(block_scale, "item"):
+                            block_scale = block_scale.item()
+                        if hasattr(block_zero_point, "item"):
+                            block_zero_point = block_zero_point.item()
+
+                        # Fill the corresponding region in the tensors
+                        scale_tensor[row_start:row_end, col_start:col_end].fill_(
+                            block_scale
+                        )
+                        zero_point_tensor[row_start:row_end, col_start:col_end].fill_(
+                            block_zero_point
+                        )
+
+                # Store the full tensors
+                self._scale = scale_tensor
+                self._zero_point = zero_point_tensor.to(
+                    dtype=(
+                        FP8_E4M3_DATA.dtype
+                        if is_fp4(quantization_args=self.quantization_args)
+                        else self.quantization_args.pytorch_dtype()
+                    )
+                )
+
+                return self._scale, self._zero_point
+            else:
+                # Per-block call (tensor_id identifies the block)
+                min_val = observed.min()
+                max_val = observed.max()
+
+                # Default quantization range for INT8
+                qmin, qmax = -128, 127
+
+                if max_val == min_val:
+                    scale = torch.tensor(1.0, device=observed.device)
+                    zero_point = torch.tensor(
+                        0, dtype=torch.int32, device=observed.device
+                    )
+                else:
+                    scale = (max_val - min_val) / (qmax - qmin)
+                    zero_point = torch.round(qmin - min_val / scale).to(torch.int32)
+
+                return scale, zero_point
+        else:
+            # For non-block quantization, use global min/max
+            min_val = observed.min()
+            max_val = observed.max()
+
+            # Default quantization range for INT8
+            qmin, qmax = -128, 127
+
+            if max_val == min_val:
+                scale = torch.tensor(1.0, device=min_val.device)
+                zero_point = torch.tensor(0, dtype=torch.int32, device=min_val.device)
+            else:
+                scale = (max_val - min_val) / (qmax - qmin)
+                zero_point = torch.round(qmin - min_val / scale).to(torch.int32)
+
+            return scale, zero_point
 
     def post_calculate_qparams(self) -> None:
         """
@@ -79,16 +223,15 @@ def get_qparams(
         g_idx: Optional[Tensor] = None,
         global_scale: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Convenience function to wrap overwritten calculate_qparams
-        adds support to make observed tensor optional and support for tracking latest
-        calculated scale and zero point
+        """Get quantization parameters for the observed tensor.
 
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
-        :param g_idx: optional mapping from column index to group index
-        :param global_scale: optional scale to further scale local quantization scales
-        :return: tuple of scale and zero point based on last observed value
+        Args:
+            observed: Optional tensor to calculate quantization parameters from
+            g_idx: Optional mapping from column index to group index
+            global_scale: Optional scale to further scale local quantization scales
+
+        Returns:
+            Tuple of scale and zero point based on last observed value
         """
         if observed is not None:
             group_size = self.quantization_args.group_size
@@ -165,11 +308,13 @@ def get_qparams(
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-                # TODO (#1475) add support for block-wise quantization
-                raise NotImplementedError(
-                    "Block-wise quantization is not yet supported, "
-                    "consider group-wise quantization instead. More info at "
-                    "https://github.com/vllm-project/llm-compressor/issues/1475"
+                self._scale, self._zero_point = self.calculate_qparams(
+                    observed, tensor_id=None, global_scale=global_scale
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported quantization strategy: "
+                    f"{self.quantization_args.strategy}"
                 )
 
         return self._scale, self._zero_point
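Note: the per-block computation added above is standard asymmetric min/max
quantization over the INT8 range the patch hard-codes. A minimal standalone
sketch of what a single block computes, with an illustrative round-trip
check (the helper name is hypothetical and not part of the patch):

    import torch

    def block_qparams(block: torch.Tensor, qmin: int = -128, qmax: int = 127):
        """Asymmetric min/max scale and zero point for one block (sketch)."""
        min_val, max_val = block.min(), block.max()
        if max_val == min_val:
            # Constant block: fall back to scale=1, zero_point=0, as the patch does
            return torch.tensor(1.0), torch.tensor(0, dtype=torch.int32)
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = torch.round(qmin - min_val / scale).to(torch.int32)
        return scale, zero_point

    # Round-trip one 128x128 block to sanity-check the formulas
    block = torch.randn(128, 128)
    scale, zp = block_qparams(block)
    q = torch.clamp(torch.round(block / scale) + zp, -128, 127)
    dequant = (q - zp) * scale
    print(f"max reconstruction error: {(block - dequant).abs().max().item():.5f}")

The maximum reconstruction error should be on the order of the scale,
i.e. roughly (max - min) / 255 for the block.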
From 523de8d0b0e3049a630e3280cb39f328daeb89cd Mon Sep 17 00:00:00 2001
From: V-E-D
Date: Wed, 11 Jun 2025 10:50:03 +0530
Subject: [PATCH 2/2] moved block quantization logic to get_qparams

---
 src/llmcompressor/observers/base.py | 193 ++++++++--------------------
 1 file changed, 53 insertions(+), 140 deletions(-)

diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index ba84e26aa..f6b78874c 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -68,149 +68,13 @@ def calculate_qparams(
             reduce_dims: Optional tuple of dimensions to reduce along.
                 Returned scale and zero point will be shaped (1,) along the
                 reduced dimensions
-            tensor_id: Optional identifier for the tensor or tensor part being
-                quantized
+            tensor_id: Optional identifier for the tensor/block being quantized
             global_scale: Optional scale to further scale local quantization scales
 
         Returns:
             Tuple of scale and zero point derived from the observed tensor
         """
-        if self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # Parse block structure, but only on the top-level call
-            if tensor_id is None:
-                # Top-level call: handle the full block structure
-                block_structure = self.quantization_args.block_structure
-                if block_structure is None:
-                    raise ValueError(
-                        "block_structure must be specified for block-wise quantization"
-                    )
-
-                try:
-                    block_rows, block_cols = map(int, block_structure.split("x"))
-                except (ValueError, AttributeError):
-                    raise ValueError(
-                        f"Invalid block_structure: {block_structure}, "
-                        "expected format 'AxB' (e.g. '128x128')"
-                    )
-                if block_rows <= 0 or block_cols <= 0:
-                    raise ValueError(
-                        f"Block dimensions must be positive integers, "
-                        f"got {block_rows}x{block_cols}"
-                    )
-
-                if observed.ndim != 2:
-                    raise ValueError(
-                        f"Block-wise quantization expects 2D tensors, "
-                        f"got tensor with {observed.ndim} dimensions"
-                    )
-                rows, columns = observed.shape
-
-                num_row_blocks = ceil(rows / block_rows)
-                num_col_blocks = ceil(columns / block_cols)
-
-                # Check if dimensions are multiples of block size
-                if (
-                    num_row_blocks * block_rows != rows
-                    or num_col_blocks * block_cols != columns
-                ):
-                    logger.bind(log_once=True).warning(
-                        f"Tensor dimensions ({rows}x{columns}) are not divisible by "
-                        f"block_structure ({block_structure}); edge blocks will be smaller."
-                    )
-
-                # Create tensors to hold scales and zero points
-                scale_tensor = torch.zeros_like(observed)
-                zero_point_tensor = torch.zeros_like(observed, dtype=torch.int32)
-
-                # Process each block
-                for row_block in range(num_row_blocks):
-                    row_start = row_block * block_rows
-                    row_end = min(row_start + block_rows, rows)
-
-                    for col_block in range(num_col_blocks):
-                        col_start = col_block * block_cols
-                        col_end = min(col_start + block_cols, columns)
-
-                        # Get block data
-                        block_data = observed[row_start:row_end, col_start:col_end]
-
-                        # Calculate min/max for this block
-                        block_min = block_data.min()
-                        block_max = block_data.max()
-
-                        # Calculate scale and zero point for this block
-                        if block_max == block_min:
-                            block_scale = torch.tensor(1.0, device=observed.device)
-                            block_zero_point = torch.tensor(
-                                0, dtype=torch.int32, device=observed.device
-                            )
-                        else:
-                            # Default quantization range for INT8
-                            qmin, qmax = -128, 127
-                            block_scale = (block_max - block_min) / (qmax - qmin)
-                            block_zero_point = torch.round(
-                                qmin - block_min / block_scale
-                            ).to(torch.int32)
-
-                        # Extract scalar values if needed
-                        if hasattr(block_scale, "item"):
-                            block_scale = block_scale.item()
-                        if hasattr(block_zero_point, "item"):
-                            block_zero_point = block_zero_point.item()
-
-                        # Fill the corresponding region in the tensors
-                        scale_tensor[row_start:row_end, col_start:col_end].fill_(
-                            block_scale
-                        )
-                        zero_point_tensor[row_start:row_end, col_start:col_end].fill_(
-                            block_zero_point
-                        )
-
-                # Store the full tensors
-                self._scale = scale_tensor
-                self._zero_point = zero_point_tensor.to(
-                    dtype=(
-                        FP8_E4M3_DATA.dtype
-                        if is_fp4(quantization_args=self.quantization_args)
-                        else self.quantization_args.pytorch_dtype()
-                    )
-                )
-
-                return self._scale, self._zero_point
-            else:
-                # Per-block call (tensor_id identifies the block)
-                min_val = observed.min()
-                max_val = observed.max()
-
-                # Default quantization range for INT8
-                qmin, qmax = -128, 127
-
-                if max_val == min_val:
-                    scale = torch.tensor(1.0, device=observed.device)
-                    zero_point = torch.tensor(
-                        0, dtype=torch.int32, device=observed.device
-                    )
-                else:
-                    scale = (max_val - min_val) / (qmax - qmin)
-                    zero_point = torch.round(qmin - min_val / scale).to(torch.int32)
-
-                return scale, zero_point
-        else:
-            # For non-block quantization, use global min/max
-            min_val = observed.min()
-            max_val = observed.max()
-
-            # Default quantization range for INT8
-            qmin, qmax = -128, 127
-
-            if max_val == min_val:
-                scale = torch.tensor(1.0, device=min_val.device)
-                zero_point = torch.tensor(0, dtype=torch.int32, device=min_val.device)
-            else:
-                scale = (max_val - min_val) / (qmax - qmin)
-                zero_point = torch.round(qmin - min_val / scale).to(torch.int32)
-
-            return scale, zero_point
+        raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
 
     def post_calculate_qparams(self) -> None:
         """
@@ -308,9 +172,58 @@ def get_qparams(
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-                self._scale, self._zero_point = self.calculate_qparams(
-                    observed, tensor_id=None, global_scale=global_scale
+                # Get block size from quantization arguments
+                block_size = self.quantization_args.block_size
+                if block_size is None:
+                    raise ValueError(
+                        "Block size must be specified for block-wise quantization"
+                    )
+
+                # Get tensor dimensions
+                rows, columns = observed.shape[0], observed.shape[1]
+
+                # Calculate number of blocks in each dimension (ceil division)
+                num_row_blocks = (rows + block_size - 1) // block_size
+                num_col_blocks = (columns + block_size - 1) // block_size
+
+                # Create tensors to store scales and zero points for each block
+                self._scale = torch.empty(
+                    (num_row_blocks, num_col_blocks),
+                    dtype=observed.dtype,
+                    device=observed.device,
                 )
+
+                if is_fp4(quantization_args=self.quantization_args):
+                    zp_dtype = FP8_E4M3_DATA.dtype
+                else:
+                    zp_dtype = self.quantization_args.pytorch_dtype()
+
+                self._zero_point = torch.empty(
+                    (num_row_blocks, num_col_blocks),
+                    dtype=zp_dtype,
+                    device=observed.device,
+                )
+
+                # Process each block
+                for i in range(num_row_blocks):
+                    row_start = i * block_size
+                    row_end = min(row_start + block_size, rows)
+
+                    for j in range(num_col_blocks):
+                        col_start = j * block_size
+                        col_end = min(col_start + block_size, columns)
+
+                        # Extract the block
+                        block = observed[row_start:row_end, col_start:col_end]
+
+                        # Calculate scale and zero point for this block
+                        scale, zero_point = self.calculate_qparams(
+                            block, tensor_id=(i, j), global_scale=global_scale
+                        )
+
+                        # Store the results
+                        self._scale[i, j] = scale.item()
+                        self._zero_point[i, j] = zero_point.item()
             else:
                 raise ValueError(
                     f"Unsupported quantization strategy: "
                     f"{self.quantization_args.strategy}"
                 )
 
         return self._scale, self._zero_point
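Note: unlike the first commit, which filled a full-size scale tensor block
by block, this version stores one scale and zero point per block in a
compact (num_row_blocks, num_col_blocks) grid. A consumer of these qparams
then has to broadcast the grid back over the weight to quantize or
dequantize. A minimal sketch of that expansion, assuming the square
block_size and ceil division used above (the helper name is illustrative,
not part of the patch):

    import torch

    def expand_block_qparams(
        grid: torch.Tensor, rows: int, cols: int, block_size: int
    ) -> torch.Tensor:
        """Repeat each per-block value over its block, trimming edge blocks."""
        full = grid.repeat_interleave(block_size, dim=0)
        full = full.repeat_interleave(block_size, dim=1)
        return full[:rows, :cols]

    rows, cols, block_size = 300, 500, 128
    num_row_blocks = (rows + block_size - 1) // block_size  # 3
    num_col_blocks = (cols + block_size - 1) // block_size  # 4
    scales = torch.rand(num_row_blocks, num_col_blocks) + 0.01
    per_element_scale = expand_block_qparams(scales, rows, cols, block_size)
    assert per_element_scale.shape == (rows, cols)

Storing the grid rather than a full-size tensor cuts qparam memory by
roughly a factor of block_size**2, which is the main practical difference
between the two commits.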