block wise quantization support #1497

Changes from 1 commit
@@ -58,15 +58,159 @@ def calculate_qparams(
        self,
        observed: Tensor,
        reduce_dims: Optional[Tuple[int]] = None,
        tensor_id: Optional[Any] = None,
        global_scale: Optional[Tensor] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """Calculate quantization parameters for the observed tensor.

        Args:
            observed: Tensor to calculate quantization parameters for
            reduce_dims: Optional tuple of dimensions to reduce along.
                Returned scale and zero point will be shaped (1,) along the
                reduced dimensions
            tensor_id: Optional identifier for the tensor or tensor part being
                quantized
            global_scale: Optional scale to further scale local quantization scales

        Returns:
            Tuple of scale and zero point derived from the observed tensor
        """
        if self.quantization_args.strategy == QuantizationStrategy.BLOCK:
            # Parse block structure - but ONLY if this is a top-level call
            if tensor_id is None:
                # This is the top-level call - handle the block structure
                block_structure = self.quantization_args.block_structure
                if block_structure is None:
                    raise ValueError(
                        "block_structure must be specified for block-wise quantization"
                    )

                try:
                    block_rows, block_cols = map(int, block_structure.split("x"))
                except (ValueError, AttributeError):
                    raise ValueError(
                        f"Invalid block_structure: {block_structure}, "
                        "expected format 'AxB' (e.g. '128x128')"
                    )
                # Validated outside the try block so this message is not
                # masked by the generic parse error above
                if block_rows <= 0 or block_cols <= 0:
                    raise ValueError(
                        f"Block dimensions must be positive integers, "
                        f"got {block_rows}x{block_cols}"
                    )

                # Check dimensionality before unpacking the shape
                if observed.ndim != 2:
                    raise ValueError(
                        f"Block-wise quantization expects 2D tensors, "
                        f"got tensor with {observed.ndim} dimensions"
                    )
                rows, columns = observed.shape

                num_row_blocks = ceil(rows / block_rows)
                num_col_blocks = ceil(columns / block_cols)

                # Warn once if dimensions are not multiples of the block size
                if (
                    num_row_blocks * block_rows != rows
                    or num_col_blocks * block_cols != columns
                ):
                    logger.bind(log_once=True).warning(
                        f"Tensor dimensions ({rows}x{columns}) are not divisible by "
                        f"block_structure ({block_structure}). Edge blocks will be "
                        "smaller than the configured block size."
                    )

                # Create tensors to hold scales and zero points
                scale_tensor = torch.zeros_like(observed)
                zero_point_tensor = torch.zeros_like(observed, dtype=torch.int32)
Reviewer comment: these should have shape …
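A minimal sketch of the compact layout the reviewer seems to be pointing at: one scale and zero point per block, shaped (num_row_blocks, num_col_blocks), instead of full-size tensors that repeat each block's value for every element. Everything below is illustrative, not code from this PR:

```python
import torch
from math import ceil

rows, columns = 256, 512
block_rows, block_cols = 128, 128
num_row_blocks = ceil(rows / block_rows)     # 2
num_col_blocks = ceil(columns / block_cols)  # 4

# One entry per block, not per element
scale_tensor = torch.empty(num_row_blocks, num_col_blocks)
zero_point_tensor = torch.empty(num_row_blocks, num_col_blocks, dtype=torch.int32)

observed = torch.randn(rows, columns)
qmin, qmax = -128, 127
for rb in range(num_row_blocks):
    for cb in range(num_col_blocks):
        block = observed[
            rb * block_rows : min((rb + 1) * block_rows, rows),
            cb * block_cols : min((cb + 1) * block_cols, columns),
        ]
        # handling of degenerate (constant) blocks omitted for brevity
        scale = (block.max() - block.min()) / (qmax - qmin)
        scale_tensor[rb, cb] = scale
        zero_point_tensor[rb, cb] = torch.round(qmin - block.min() / scale).to(
            torch.int32
        )

# Expand back to per-element values only where a full map is needed
expanded = scale_tensor.repeat_interleave(block_rows, dim=0).repeat_interleave(
    block_cols, dim=1
)[:rows, :columns]
```

This keeps storage at one value per block; the expansion step is only needed if something downstream requires a per-element map.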
                # Process each block
                for row_block in range(num_row_blocks):
                    row_start = row_block * block_rows
                    row_end = min(row_start + block_rows, rows)

                    for col_block in range(num_col_blocks):
                        col_start = col_block * block_cols
                        col_end = min(col_start + block_cols, columns)

                        # Get block data
                        block_data = observed[row_start:row_end, col_start:col_end]

                        # Calculate min/max for this block
                        block_min = block_data.min()
                        block_max = block_data.max()

                        # Calculate scale and zero point for this block
                        if block_max == block_min:
                            block_scale = torch.tensor(1.0, device=observed.device)
                            block_zero_point = torch.tensor(
                                0, dtype=torch.int32, device=observed.device
                            )
                        else:
                            qmin, qmax = -128, 127  # Default for INT8
                            block_scale = (block_max - block_min) / (qmax - qmin)
                            block_zero_point = torch.round(
                                qmin - block_min / block_scale
                            ).to(torch.int32)

                        # Extract scalar values if needed
                        if hasattr(block_scale, "item"):
                            block_scale = block_scale.item()
                        if hasattr(block_zero_point, "item"):
                            block_zero_point = block_zero_point.item()

                        # Fill the corresponding region in the tensors
                        scale_tensor[row_start:row_end, col_start:col_end].fill_(
                            block_scale
                        )
                        zero_point_tensor[row_start:row_end, col_start:col_end].fill_(
                            block_zero_point
                        )

                # Store the full tensors
                self._scale = scale_tensor
                self._zero_point = zero_point_tensor.to(
                    dtype=(
                        FP8_E4M3_DATA.dtype
                        if is_fp4(quantization_args=self.quantization_args)
                        else self.quantization_args.pytorch_dtype()
                    )
                )

                return self._scale, self._zero_point
            else:
                # This is a recursive call for a specific block
                min_val = observed.min()
                max_val = observed.max()

                qmin, qmax = -128, 127  # Default for INT8

                if max_val == min_val:
                    scale = torch.tensor(1.0, device=observed.device)
                    zero_point = torch.tensor(
                        0, dtype=torch.int32, device=observed.device
                    )
                else:
                    scale = (max_val - min_val) / (qmax - qmin)
                    zero_point = torch.round(qmin - min_val / scale).to(torch.int32)

                return scale, zero_point
        else:
            # For non-block quantization, use global min/max
            min_val = observed.min()
            max_val = observed.max()

            qmin, qmax = -128, 127  # Default for INT8

            if max_val == min_val:
                scale = torch.tensor(1.0, device=min_val.device)
                zero_point = torch.tensor(0, dtype=torch.int32, device=min_val.device)
            else:
                scale = (max_val - min_val) / (qmax - qmin)
                zero_point = torch.round(qmin - min_val / scale).to(torch.int32)

            return scale, zero_point
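For intuition, here is the min/max-to-qparams arithmetic used by all three branches above, run on a concrete block as a standalone sketch (it hardcodes the same INT8 range the diff does):

```python
import torch

block = torch.tensor([[-0.5, 0.25], [0.75, 1.5]])
qmin, qmax = -128, 127  # same hardcoded INT8 range as the diff

min_val, max_val = block.min(), block.max()
scale = (max_val - min_val) / (qmax - qmin)  # 2.0 / 255 ≈ 0.00784
zero_point = torch.round(qmin - min_val / scale).to(torch.int32)  # -64

# Round-trip to see the quantization error this introduces
q = torch.clamp(torch.round(block / scale) + zero_point, qmin, qmax)
dequantized = (q - zero_point) * scale
print(scale.item(), zero_point.item())
print((block - dequantized).abs().max().item())  # ≤ scale / 2 here
```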
    def post_calculate_qparams(self) -> None:
        """

@@ -79,16 +223,15 @@ def get_qparams(
        g_idx: Optional[Tensor] = None,
        global_scale: Optional[Tensor] = None,
    ) -> Tuple[FloatTensor, IntTensor]:
        """Get quantization parameters for the observed tensor.

        Convenience wrapper around the overridden calculate_qparams; makes the
        observed tensor optional and tracks the latest calculated scale and
        zero point.

        Args:
            observed: Optional tensor to calculate quantization parameters from
            g_idx: Optional mapping from column index to group index
            global_scale: Optional scale to further scale local quantization scales

        Returns:
            Tuple of scale and zero point based on last observed value
        """
        if observed is not None:
            group_size = self.quantization_args.group_size
@@ -165,11 +308,13 @@ def get_qparams(
                )
            elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
                self._scale, self._zero_point = self.calculate_qparams(
                    observed, tensor_id=None, global_scale=global_scale
                )
Reviewer comment: I think the majority of your logic you'll want to have in here, or in a helper method that this calls to help with readability.
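One way to read that suggestion, sketched below; the helper name `_calculate_block_qparams` is invented here for illustration and does not exist in the PR:

```python
class Observer:
    # hypothetical skeleton; QuantizationStrategy and quantization_args
    # come from the surrounding module in the real code
    def get_qparams(self, observed=None, g_idx=None, global_scale=None):
        # keep the dispatch thin, push block iteration into a helper
        if self.quantization_args.strategy == QuantizationStrategy.BLOCK:
            self._scale, self._zero_point = self._calculate_block_qparams(
                observed, global_scale=global_scale
            )
        return self._scale, self._zero_point

    def _calculate_block_qparams(self, observed, global_scale=None):
        """Slice `observed` into blocks; delegate each block to calculate_qparams."""
        ...
```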
            else:
                raise ValueError(
                    f"Unsupported quantization strategy: "
                    f"{self.quantization_args.strategy}"
                )

        return self._scale, self._zero_point
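One behavioral detail worth knowing when reading the divisibility warning in calculate_qparams: because block ends are clamped with min(), non-divisible shapes are never actually padded — the trailing blocks are simply smaller. A quick standalone check of the boundary arithmetic:

```python
from math import ceil

rows, block_rows = 300, 128
num_row_blocks = ceil(rows / block_rows)  # 3 blocks

for row_block in range(num_row_blocks):
    row_start = row_block * block_rows
    row_end = min(row_start + block_rows, rows)
    print(row_block, row_start, row_end, row_end - row_start)
# 0   0 128 128
# 1 128 256 128
# 2 256 300  44   <- ragged edge block, 44 rows instead of 128
```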