-
Notifications
You must be signed in to change notification settings - Fork 182
[KV Cache] support kv cache int8 per channel quant #1662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -181,8 +181,11 @@ def get_qparams( | |
self._zero_point[:, group_index] = zero_point.squeeze(1) | ||
|
||
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL: | ||
# assume observed is transposed, because its the output, hence use dim 0 | ||
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0) | ||
# 1. dim=2 scenario: in kv cache quant scenario which is | ||
# [batch_size, seq_len - residual_length, num_heads * head_dim] | ||
# 2. dim=0 scenario: assume observed is transposed, because its the output, hence use dim 0 | ||
dim = 2 if observed.dim == 3 else 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There appears to be a typo in the condition here. Comparing the method object to an integer ( dim = 2 if observed.dim() == 3 else 0 |
||
self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim) | ||
|
||
elif self.quantization_args.strategy == QuantizationStrategy.TOKEN: | ||
# use dim 1, assume the obsersed.shape = [batch, token, hidden] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reshape logic for
key_states
andvalue_states
is specific to the per-channel quantization strategy, as noted in the comments. However, it's applied unconditionally, which will likely break other KV cache quantization strategies (e.g., per-tensor).This reshape block should be wrapped in a conditional check, for example:
if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
This change will introduce a variable scope issue for
num_heads
andhead_dim
, which are defined in this block but also needed for the reverse reshape. You'll need to refactor theupdate
method to handle this, for instance by initializingnum_heads
andhead_dim
toNone
at the beginning of the function.You will also need to add the following import at the top of the file:
from compressed_tensors.quantization.quant_args import QuantizationStrategy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pertensor and perchannel are same.