block wise quantization support #1497

Closed · wants to merge 2 commits

Changes from 1 commit
187 changes: 166 additions & 21 deletions src/llmcompressor/observers/base.py
@@ -58,15 +58,159 @@ def calculate_qparams(
self,
observed: Tensor,
reduce_dims: Optional[Tuple[int]] = None,
tensor_id: Optional[Any] = None,
global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""Calculate quantization parameters for the observed tensor.

Args:
observed: Tensor to calculate quantization parameters for
reduce_dims: Optional tuple of dimensions to reduce along.
Returned scale and zero point will be shaped (1,) along the
reduced dimensions
tensor_id: Optional identifier for the tensor or tensor part being
quantized
global_scale: Optional scale to further scale local quantization scales

Returns:
Tuple of scale and zero point derived from the observed tensor
"""
if self.quantization_args.strategy == QuantizationStrategy.BLOCK:
# Parse block structure - but ONLY if this is a top-level call
if tensor_id is None:
# This is the top-level call - handle the block structure
block_structure = self.quantization_args.block_structure
if block_structure is None:
raise ValueError(
"block_structure must be specified for block-wise quantization"
)

                try:
                    block_rows, block_cols = map(int, block_structure.split("x"))
                except (ValueError, AttributeError):
                    raise ValueError(
                        f"Invalid block_structure: {block_structure}, "
                        "expected format 'AxB' (e.g. '128x128')"
                    )
                # validate outside the try block so this error is not
                # swallowed by the except clause above
                if block_rows <= 0 or block_cols <= 0:
                    raise ValueError(
                        f"Block dimensions must be positive integers, "
                        f"got {block_rows}x{block_cols}"
                    )

                # check dimensionality before unpacking, otherwise a non-2D
                # tensor fails with an unhelpful unpacking error
                if observed.ndim != 2:
                    raise ValueError(
                        f"Block-wise quantization expects 2D tensors, "
                        f"got tensor with {observed.ndim} dimensions"
                    )
                rows, columns = observed.shape

num_row_blocks = ceil(rows / block_rows)
num_col_blocks = ceil(columns / block_cols)
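                # e.g. a 512x768 weight with 128x128 blocks -> 4 x 6 blocks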

# Check if dimensions are multiples of block size
if (
num_row_blocks * block_rows != rows
or num_col_blocks * block_cols != columns
):
                    logger.bind(log_once=True).warning(
                        f"Tensor dimensions ({rows}x{columns}) are not divisible by "
                        f"block_structure ({block_structure}); edge blocks will be "
                        "smaller than the requested block size."
                    )

# Create tensors to hold scales and zero points
scale_tensor = torch.zeros_like(observed)
zero_point_tensor = torch.zeros_like(observed, dtype=torch.int32)
Collaborator comment:
these should have shape (rows, num_blocks), similar to how group-wise is set up here
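
For illustration, a minimal sketch of the suggested layout: one (scale, zero_point) pair per block rather than one per element. The exact shape convention ((rows, num_blocks) like group-wise, or (num_row_blocks, num_col_blocks)) is an integration choice; the sketch below uses the latter. The function name, the use of torch.amin/amax, and the clamp epsilon are assumptions, not code from this PR:

import torch
from math import ceil

def compact_block_qparams(
    observed: torch.Tensor, block_rows: int, block_cols: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Asymmetric int8 qparams, one (scale, zero_point) pair per block."""
    rows, cols = observed.shape
    num_row_blocks = ceil(rows / block_rows)
    num_col_blocks = ceil(cols / block_cols)

    scales = torch.empty(num_row_blocks, num_col_blocks, device=observed.device)
    zero_points = torch.empty(
        num_row_blocks, num_col_blocks, dtype=torch.int32, device=observed.device
    )
    qmin, qmax = -128, 127

    for i in range(num_row_blocks):
        for j in range(num_col_blocks):
            block = observed[
                i * block_rows : (i + 1) * block_rows,
                j * block_cols : (j + 1) * block_cols,
            ]
            b_min, b_max = block.amin(), block.amax()
            # clamp avoids a zero scale for constant blocks
            scale = torch.clamp((b_max - b_min) / (qmax - qmin), min=1e-8)
            scales[i, j] = scale
            zero_points[i, j] = torch.round(qmin - b_min / scale).to(torch.int32)

    return scales, zero_points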


# Process each block
for row_block in range(num_row_blocks):
row_start = row_block * block_rows
row_end = min(row_start + block_rows, rows)

for col_block in range(num_col_blocks):
col_start = col_block * block_cols
col_end = min(col_start + block_cols, columns)

# Get block data
block_data = observed[row_start:row_end, col_start:col_end]

# Calculate min/max for this block
block_min = block_data.min()
block_max = block_data.max()

# Calculate scale and zero point for this block
if block_max == block_min:
block_scale = torch.tensor(1.0, device=observed.device)
block_zero_point = torch.tensor(
0, dtype=torch.int32, device=observed.device
)
else:
                            # hardcoded int8 range; does not consult quantization_args
                            qmin, qmax = -128, 127
block_scale = (block_max - block_min) / (qmax - qmin)
block_zero_point = torch.round(
qmin - block_min / block_scale
).to(torch.int32)
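                            # e.g. min=-0.5, max=0.5 gives scale = 1/255 ≈ 0.00392
                            # and zero_point = round(-128 + 0.5/0.00392) ≈ 0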

                        # both branches produce 0-d tensors; convert to
                        # Python scalars for fill_
                        block_scale = block_scale.item()
                        block_zero_point = block_zero_point.item()

# Fill the corresponding region in the tensors
scale_tensor[row_start:row_end, col_start:col_end].fill_(
block_scale
)
zero_point_tensor[row_start:row_end, col_start:col_end].fill_(
block_zero_point
)

# Store the full tensors
self._scale = scale_tensor
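            # for fp4, zero points are cast to the fp8 (E4M3) dtype;
            # otherwise to the dtype configured on quantization_args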
self._zero_point = zero_point_tensor.to(
dtype=(
FP8_E4M3_DATA.dtype
if is_fp4(quantization_args=self.quantization_args)
else self.quantization_args.pytorch_dtype()
)
)

return self._scale, self._zero_point
else:
# This is a recursive call for a specific block
min_val = observed.min()
max_val = observed.max()

                # hardcoded int8 range; does not consult quantization_args
                qmin, qmax = -128, 127

if max_val == min_val:
scale = torch.tensor(1.0, device=observed.device)
zero_point = torch.tensor(
0, dtype=torch.int32, device=observed.device
)
else:
scale = (max_val - min_val) / (qmax - qmin)
zero_point = torch.round(qmin - min_val / scale).to(torch.int32)

return scale, zero_point
else:
# For non-block quantization, use global min/max
min_val = observed.min()
max_val = observed.max()

            # hardcoded int8 range; does not consult quantization_args
            qmin, qmax = -128, 127

if max_val == min_val:
scale = torch.tensor(1.0, device=min_val.device)
zero_point = torch.tensor(0, dtype=torch.int32, device=min_val.device)
else:
scale = (max_val - min_val) / (qmax - qmin)
zero_point = torch.round(qmin - min_val / scale).to(torch.int32)

return scale, zero_point

def post_calculate_qparams(self) -> None:
"""
@@ -79,16 +223,15 @@ def get_qparams(
g_idx: Optional[Tensor] = None,
global_scale: Optional[Tensor] = None,
) -> Tuple[FloatTensor, IntTensor]:
"""
Convenience function to wrap overwritten calculate_qparams
adds support to make observed tensor optional and support for tracking latest
calculated scale and zero point
"""Get quantization parameters for the observed tensor.

Args:
observed: Optional tensor to calculate quantization parameters from
g_idx: Optional mapping from column index to group index
global_scale: Optional scale to further scale local quantization scales

Returns:
Tuple of scale and zero point based on last observed value
"""
if observed is not None:
group_size = self.quantization_args.group_size
@@ -165,11 +308,13 @@ def get_qparams(
)

elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
self._scale, self._zero_point = self.calculate_qparams(
observed, tensor_id=None, global_scale=global_scale
)
Collaborator comment:
I think the majority of your logic you'll want to have in here, or in a helper method that this calls to help with readability.
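
A rough sketch of that structure — a thin dispatcher with the block math isolated in a private helper. The class and method names here are hypothetical, not from this PR:

import torch


class ObserverSketch:
    """get_qparams stays a thin switch over strategies; each strategy's
    math lives in its own helper, keeping every branch short."""

    def get_qparams(self, observed: torch.Tensor, strategy: str):
        if strategy == "block":
            # block-structure parsing and per-block math live in one place
            return self._calculate_block_qparams(observed)
        raise ValueError(f"Unsupported quantization strategy: {strategy}")

    def _calculate_block_qparams(self, observed: torch.Tensor):
        # e.g. the compact per-block sketch shown in the earlier comment
        ...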

else:
raise ValueError(
f"Unsupported quantization strategy: "
f"{self.quantization_args.strategy}"
)

        return self._scale, self._zero_point
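
For context, a small self-contained round trip showing the effect block-wise scales are meant to have — symmetric int8 here for brevity (the PR computes asymmetric min/max qparams), with compact scales of shape (num_row_blocks, num_col_blocks) expanded back over each tile. Illustrative only; it does not use llm-compressor's API:

import torch

torch.manual_seed(0)
W = torch.randn(256, 512)
block_rows, block_cols = 128, 128

# one scale per 128x128 tile -> scales has shape (2, 4)
blocks = W.unfold(0, block_rows, block_rows).unfold(1, block_cols, block_cols)
scales = blocks.abs().amax(dim=(-2, -1)) / 127.0

# broadcast each tile's scale back over the weight and fake-quantize
expanded = scales.repeat_interleave(block_rows, 0).repeat_interleave(block_cols, 1)
W_q = torch.clamp(torch.round(W / expanded), -128, 127)
W_dq = W_q * expanded

print(scales.shape)                   # torch.Size([2, 4])
print((W - W_dq).abs().max().item())  # small per-tile quantization error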
Expand Down