[KV Cache] support kv cache int8 per channel quantization #1663
base: main
Changes from all commits
```diff
@@ -94,6 +94,14 @@ def update(
         _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
         _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

+        # reshape for per channel scenario
+        num_heads = key_states.shape[1]
+        head_dim = key_states.shape[-1]
+        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+        key_states = key_states.transpose(1, 2).flatten(2)
+        value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
```
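As an aside (not part of the PR), a small self-contained sketch of the layout change these added lines perform: `transpose(1, 2).flatten(2)` folds the head and head-dim axes into one channel axis, and the `view(...).transpose(1, 2)` used later in the diff restores the original layout. The tensor sizes below are arbitrary placeholders.

```python
import torch

# placeholder sizes: batch=2, num_heads=4, seq_len=3, head_dim=8
key_states = torch.randn(2, 4, 3, 8)

# [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, num_heads * head_dim]
flat = key_states.transpose(1, 2).flatten(2)
assert flat.shape == (2, 3, 4 * 8)

# inverse used after quantize/dequantize:
# [batch_size, seq_len, num_heads * head_dim] -> [batch_size, num_heads, seq_len, head_dim]
restored = flat.view(flat.shape[0], flat.shape[1], 4, 8).transpose(1, 2)
assert torch.equal(restored, key_states)
```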
```diff
@@ -106,6 +114,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )

+        # reshape for per channel scenario
+        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+        qdq_key_states = qdq_key_states.view(
+            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        qdq_value_states = qdq_value_states.view(
+            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+
```
Comment on lines +117 to +123
This reverse reshape operation is specific to the per-channel scenario and should be executed conditionally. Applying this to tensors from other quantization strategies will result in shape mismatches. Wrap this in a condition that checks for the per-channel strategy.
Reply: per tensor is the same.
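For illustration only, a self-contained sketch of the kind of guard this comment asks for; the helper name `unflatten_if_per_channel` is hypothetical, and the `QuantizationStrategy` import path is an assumption based on the strategy check quoted in the later review comment.

```python
import torch
from compressed_tensors.quantization import QuantizationStrategy  # assumed import path

def unflatten_if_per_channel(
    states: torch.Tensor, strategy: QuantizationStrategy, num_heads: int, head_dim: int
) -> torch.Tensor:
    """Hypothetical helper: undo the per-channel flattening only when that strategy applied it."""
    if strategy != QuantizationStrategy.CHANNEL:
        return states
    # [batch_size, seq_len, num_heads * head_dim]
    # -> [batch_size, num_heads, seq_len, head_dim]
    return states.view(states.shape[0], states.shape[1], num_heads, head_dim).transpose(1, 2)
```

For per-tensor quantization the flatten/unflatten round trip does not change the quantized values, which appears to be the point of the reply above; a guard like this mainly makes the assumption explicit for other strategies.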
```diff
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states

         return keys_to_return, values_to_return
```
```diff
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps

         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())

         q_tensor = quantize(
             x=tensor,
```
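For context, a small illustration (shapes assumed, not taken from the PR) of what the new `.squeeze()` calls do before the scales and zero points are appended to the per-layer lists:

```python
import torch

# assumed example: a per-channel observer returning one scale/zero point per channel,
# with a trailing singleton dimension, e.g. shape [num_channels, 1]
scale = torch.rand(4096, 1)
zp = torch.zeros(4096, 1, dtype=torch.int8)

# .squeeze() drops singleton dimensions so the cached per-layer entries are 1-D tensors
print(scale.squeeze().shape)  # torch.Size([4096])
print(zp.squeeze().shape)     # torch.Size([4096])
```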
The reshape logic is specific to per-channel quantization but is currently applied unconditionally. This can cause incorrect behavior for other quantization strategies. Consider adding a conditional check to apply this logic only when `self.quantization_args.strategy == QuantizationStrategy.CHANNEL`.
Reply: per tensor is the same.
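Along the same lines, an illustrative sketch of how the forward reshape (before `self._quantize`) could be gated on the strategy; the helper name `flatten_if_per_channel` is hypothetical and the import path is an assumption.

```python
import torch
from compressed_tensors.quantization import QuantizationStrategy  # assumed import path

def flatten_if_per_channel(
    key_states: torch.Tensor, value_states: torch.Tensor, strategy: QuantizationStrategy
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: fold heads into one channel axis only for per-channel quantization."""
    if strategy != QuantizationStrategy.CHANNEL:
        return key_states, value_states
    # [batch_size, num_heads, seq_len, head_dim]
    # -> [batch_size, seq_len, num_heads * head_dim]
    return (
        key_states.transpose(1, 2).flatten(2),
        value_states.transpose(1, 2).flatten(2),
    )
```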