[KV Cache] support kv cache int8 per channel quantization #1663

Open · wants to merge 1 commit into base: main
20 changes: 18 additions & 2 deletions src/llmcompressor/modifiers/quantization/cache.py
@@ -94,6 +94,14 @@ def update(
_pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
_pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

# reshape for per channel scenario
num_heads = key_states.shape[1]
head_dim = key_states.shape[-1]
# from [batch_size, num_heads, seq_len - residual_length, head_dim]
# to [batch_size, seq_len - residual_length, num_heads * head_dim]
key_states = key_states.transpose(1, 2).flatten(2)
value_states = value_states.transpose(1, 2).flatten(2)
Contributor commented on lines +97 to +103 (severity: high):

The reshape logic is specific to per-channel quantization but is currently applied unconditionally. This can cause incorrect behavior for other quantization strategies. Consider adding a conditional check to apply this logic only when self.quantization_args.strategy == QuantizationStrategy.CHANNEL.

Suggested change
- # reshape for per channel scenario
- num_heads = key_states.shape[1]
- head_dim = key_states.shape[-1]
- # from [batch_size, num_heads, seq_len - residual_length, head_dim]
- # to [batch_size, seq_len - residual_length, num_heads * head_dim]
- key_states = key_states.transpose(1, 2).flatten(2)
- value_states = value_states.transpose(1, 2).flatten(2)
+ # reshape for per channel scenario
+ num_heads = key_states.shape[1]
+ head_dim = key_states.shape[-1]
+ if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+     # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+     # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+     key_states = key_states.transpose(1, 2).flatten(2)
+     value_states = value_states.transpose(1, 2).flatten(2)

Author reply: per tensor is the same.
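
To illustrate the author's point, a minimal sketch (shapes are made up for illustration) of what the transpose + flatten does, why per-tensor statistics are unchanged by it, and what per-channel quantization gains from the flattened layout:

    import torch

    # Assumed toy shapes: [batch_size, num_heads, seq_len, head_dim]
    batch, num_heads, seq_len, head_dim = 2, 4, 16, 8
    key_states = torch.randn(batch, num_heads, seq_len, head_dim)

    # Reshape as in this PR: [B, H, S, D] -> [B, S, H * D]
    flat = key_states.transpose(1, 2).flatten(2)
    assert flat.shape == (batch, seq_len, num_heads * head_dim)

    # Per-tensor quantization reduces over every element, so the reshape
    # does not change its min/max (hence "per tensor is the same").
    assert torch.equal(flat.amin(), key_states.amin())
    assert torch.equal(flat.amax(), key_states.amax())

    # Per-channel quantization keeps one scale per flattened channel,
    # i.e. num_heads * head_dim scales instead of a single scalar.
    per_channel_absmax = flat.abs().amax(dim=(0, 1))
    assert per_channel_absmax.shape == (num_heads * head_dim,)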


q_key_states = self._quantize(
key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
)
@@ -106,6 +114,14 @@ def update(
q_value_states, KVCacheScaleType.VALUE, layer_idx
)

# reshape for per channel scenario
# from [batch_size, seq_len - residual_length, num_heads * head_dim]
# to [batch_size, num_heads, seq_len - residual_length, head_dim]
qdq_key_states = qdq_key_states.view(
qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
qdq_value_states = qdq_value_states.view(
qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
Contributor commented on lines +117 to +123 (severity: high):

This reverse reshape operation is specific to the per-channel scenario and should be executed conditionally. Applying this to tensors from other quantization strategies will result in shape mismatches. Wrap this in a condition that checks for the per-channel strategy: if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:.

Suggested change
- # reshape for per channel scenario
- # from [batch_size, seq_len - residual_length, num_heads * head_dim]
- # to [batch_size, num_heads, seq_len - residual_length, head_dim]
- qdq_key_states = qdq_key_states.view(
-     qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
- qdq_value_states = qdq_value_states.view(
-     qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+ if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+     # reshape for per channel scenario
+     # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+     # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+     qdq_key_states = qdq_key_states.view(
+         qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+     qdq_value_states = qdq_value_states.view(
+         qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)

Author reply: per tensor is the same.
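
For completeness, a small sanity sketch (toy shapes again) showing that the view + transpose above exactly undoes the earlier transpose + flatten, so the cache returns to the [batch_size, num_heads, seq_len, head_dim] layout the attention code expects:

    import torch

    batch, num_heads, seq_len, head_dim = 2, 4, 16, 8
    key_states = torch.randn(batch, num_heads, seq_len, head_dim)

    # Forward reshape used before quantization: [B, H, S, D] -> [B, S, H * D]
    flat = key_states.transpose(1, 2).flatten(2)

    # Reverse reshape used after dequantization: [B, S, H * D] -> [B, H, S, D]
    restored = flat.view(
        flat.shape[0], flat.shape[1], num_heads, head_dim
    ).transpose(1, 2)

    # Round trip is lossless (ignoring the quantize/dequantize error itself).
    assert restored.shape == key_states.shape
    assert torch.equal(restored, key_states)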


keys_to_return, values_to_return = qdq_key_states, qdq_value_states

return keys_to_return, values_to_return
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
zps = self.v_zps

scale, zp = observer(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale)
_pad_and_append_at_idx_(zps, layer_idx, zp)
_pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
_pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())

q_tensor = quantize(
x=tensor,
7 changes: 5 additions & 2 deletions src/llmcompressor/observers/base.py
@@ -181,8 +181,11 @@ def get_qparams(
self._zero_point[:, group_index] = zero_point.squeeze(1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed, because its the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
# 1. dim=2 scenario: in kv cache quant scenario which is
# [batch_size, seq_len - residual_length, num_heads * head_dim]
# 2. dim=0 scenario: assume observed is transposed, because its the output, hence use dim 0
dim = 2 if observed.dim() == 3 else 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
# use dim 1, assume the observed.shape = [batch, token, hidden]
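
A standalone sketch of the dim selection this diff introduces (the helper below is illustrative only, not the Observer API): 3-D observed tensors are treated as KV-cache activations with channels on the last axis, while other tensors keep the existing dim-0 behavior.

    import torch

    def channel_dim_for(observed: torch.Tensor) -> int:
        # Illustrative helper mirroring the diff: KV-cache activations arrive as
        # [batch_size, seq_len - residual_length, num_heads * head_dim], so the
        # channel axis is dim 2; otherwise the observed tensor is assumed to be
        # a transposed output and the channel axis is dim 0.
        return 2 if observed.dim() == 3 else 0

    # KV-cache case: one scale per flattened head channel
    kv = torch.randn(2, 16, 4 * 8)
    dim = channel_dim_for(kv)  # -> 2
    reduce_dims = tuple(d for d in range(kv.dim()) if d != dim)
    kv_absmax = kv.abs().amax(dim=reduce_dims)
    assert kv_absmax.shape == (32,)

    # Transposed-output case: one scale per row, as before
    out = torch.randn(32, 128)
    dim = channel_dim_for(out)  # -> 0
    out_absmax = out.abs().amax(dim=1)
    assert out_absmax.shape == (32,)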