Block-wise quantization support #1497

Closed · wants to merge 2 commits
src/llmcompressor/observers/base.py: 78 additions & 20 deletions
@@ -58,13 +58,21 @@ def calculate_qparams(
         self,
         observed: Tensor,
         reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+        global_scale: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        :param observed: observed tensor to calculate quantization parameters for
-        :param reduce_dims: optional tuple of dimensions to reduce along,
-            returned scale and zero point will be shaped (1,) along the
-            reduced dimensions
-        :return: tuple of scale and zero point derived from the observed tensor
+        """Calculate quantization parameters for the observed tensor.
+
+        Args:
+            observed: Tensor to calculate quantization parameters for
+            reduce_dims: Optional tuple of dimensions to reduce along.
+                Returned scale and zero point will be shaped (1,) along the
+                reduced dimensions
+            tensor_id: Optional identifier for the tensor/block being quantized
+            global_scale: Optional scale to further scale local quantization scales
+
+        Returns:
+            Tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")

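For reference, the contract this signature describes can be seen in a minimal, self-contained min-max sketch. This is not the library's MinMaxObserver; `minmax_qparams` is a hypothetical helper, and the zero-inclusive range and the way `global_scale` is folded in are illustrative assumptions:

```python
import torch
from typing import Any, Optional, Tuple

def minmax_qparams(
    observed: torch.Tensor,
    reduce_dims: Optional[Tuple[int, ...]] = None,
    tensor_id: Optional[Any] = None,  # unused here; lets callers key per-block state
    global_scale: Optional[torch.Tensor] = None,
    quant_min: int = -128,
    quant_max: int = 127,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reduce over the requested dims (keepdim gives the (1,)-shaped axes the
    # docstring above promises), or over the whole tensor when dims are None.
    if reduce_dims is None:
        min_val, max_val = observed.min(), observed.max()
    else:
        min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
        max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)

    # Include zero in the range so zero stays exactly representable.
    min_val = torch.clamp(min_val, max=0.0)
    max_val = torch.clamp(max_val, min=0.0)

    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.clamp(scale, min=1e-8)  # guard against all-zero blocks
    zero_point = torch.clamp(torch.round(quant_min - min_val / scale), quant_min, quant_max)
    if global_scale is not None:
        # Illustrative assumption: fold an outer scale into the local one.
        scale = scale / global_scale
    return scale, zero_point.to(torch.int8)
```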
@@ -79,16 +87,15 @@ def get_qparams(
         g_idx: Optional[Tensor] = None,
         global_scale: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
-        """
-        Convenience function to wrap overwritten calculate_qparams
-        adds support to make observed tensor optional and support for tracking latest
-        calculated scale and zero point
+        """Get quantization parameters for the observed tensor.
 
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
-        :param g_idx: optional mapping from column index to group index
-        :param global_scale: optional scale to further scale local quantization scales
-        :return: tuple of scale and zero point based on last observed value
+        Args:
+            observed: Optional tensor to calculate quantization parameters from
+            g_idx: Optional mapping from column index to group index
+            global_scale: Optional scale to further scale local quantization scales
+
+        Returns:
+            Tuple of scale and zero point based on last observed value
         """
         if observed is not None:
             group_size = self.quantization_args.group_size
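The BLOCK branch added in the next hunk gives get_qparams a new output shape: one scale and zero point per block_size × block_size tile. A purely arithmetic sketch of that bookkeeping (the 512×384 weight and block_size=128 are example numbers, not from the PR):

```python
# Example sizes only: a 512x384 weight with 128x128 blocks.
rows, columns, block_size = 512, 384, 128

# Ceiling division, exactly as in the diff below, so ragged edges get a block too.
num_row_blocks = (rows + block_size - 1) // block_size    # -> 4
num_col_blocks = (columns + block_size - 1) // block_size  # -> 3

# get_qparams would then return scale and zero_point tensors of shape (4, 3).
assert (num_row_blocks, num_col_blocks) == (4, 3)
```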
@@ -165,11 +172,62 @@ get_qparams(
                 )
 
             elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-                # TODO (#1475) add support for block-wise quantization
-                raise NotImplementedError(
-                    "Block-wise quantization is not yet supported, "
-                    "consider group-wise quantization instead. More info at "
-                    "https://github.com/vllm-project/llm-compressor/issues/1475"
-                )
+                # Get block size from quantization arguments
+                block_size = self.quantization_args.block_size
+                if block_size is None:
+                    raise ValueError(
+                        "Block size must be specified for block-wise quantization"
+                    )
+
+                # Get tensor dimensions
+                rows, columns = observed.shape[0], observed.shape[1]
+
+                # Calculate number of blocks in each dimension
+                num_row_blocks = (rows + block_size - 1) // block_size
+                num_col_blocks = (columns + block_size - 1) // block_size
+
+                # Create tensors to store scales and zero points for each block
+                self._scale = torch.empty(
+                    (num_row_blocks, num_col_blocks),
+                    dtype=observed.dtype,
+                    device=observed.device,
+                )
+
+                if is_fp4(quantization_args=self.quantization_args):
+                    zp_dtype = FP8_E4M3_DATA.dtype
+                else:
+                    zp_dtype = self.quantization_args.pytorch_dtype()
+
+                self._zero_point = torch.empty(
+                    (num_row_blocks, num_col_blocks),
+                    dtype=zp_dtype,
+                    device=observed.device,
+                )
+
+                # Process each block
+                for i in range(num_row_blocks):
+                    row_start = i * block_size
+                    row_end = min(row_start + block_size, rows)
+
+                    for j in range(num_col_blocks):
+                        col_start = j * block_size
+                        col_end = min(col_start + block_size, columns)
+
+                        # Extract the block
+                        block = observed[row_start:row_end, col_start:col_end]
+
+                        # Calculate scale and zero point for this block
+                        scale, zero_point = self.calculate_qparams(
+                            block, tensor_id=(i, j), global_scale=global_scale
+                        )
+
+                        # Store the results
+                        self._scale[i, j] = scale.item()
+                        self._zero_point[i, j] = zero_point.item()
+            else:
+                raise ValueError(
+                    f"Unsupported quantization strategy: "
+                    f"{self.quantization_args.strategy}"
+                )
 
         return self._scale, self._zero_point
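For context on how a per-block qparams grid like the one produced above gets used, here is a standalone fake-quantization round trip. It is a sketch under stated assumptions (int8 range, ragged edges handled by trimming the broadcast grid); `fake_quantize_blockwise` is a hypothetical helper, not code from this PR or from llm-compressor:

```python
import torch

def fake_quantize_blockwise(
    weight: torch.Tensor,
    scale: torch.Tensor,       # (num_row_blocks, num_col_blocks), as produced above
    zero_point: torch.Tensor,  # same shape as scale
    block_size: int,
    quant_min: int = -128,
    quant_max: int = 127,
) -> torch.Tensor:
    rows, cols = weight.shape
    # Broadcast each block's qparams over its block_size x block_size tile;
    # repeat_interleave can overshoot on ragged edges, so trim to the weight shape.
    scale_full = scale.repeat_interleave(block_size, 0).repeat_interleave(block_size, 1)
    zp_full = zero_point.repeat_interleave(block_size, 0).repeat_interleave(block_size, 1)
    scale_full = scale_full[:rows, :cols]
    zp_full = zp_full[:rows, :cols].to(weight.dtype)

    # Quantize, clamp to the integer range, then dequantize for the round trip.
    q = torch.clamp(torch.round(weight / scale_full + zp_full), quant_min, quant_max)
    return (q - zp_full) * scale_full

# Usage with the example shapes from earlier: a 512x384 weight, 128x128 blocks.
w = torch.randn(512, 384)
s = torch.full((4, 3), 0.02)   # stand-in per-block scales
zp = torch.zeros(4, 3)         # stand-in per-block zero points
w_dq = fake_quantize_blockwise(w, s, zp, block_size=128)
assert w_dq.shape == w.shape
```

The round trip makes the design choice visible: each 128×128 tile is reconstructed with its own scale, so an outlier in one tile cannot inflate the quantization step of the rest of the matrix, which is the advantage block-wise quantization has over per-tensor scaling.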