From 0d948ea8bcdbfcdac347c0a092540fd6f7d04e73 Mon Sep 17 00:00:00 2001
From: evian
Date: Sat, 19 Jul 2025 18:04:04 +0800
Subject: [PATCH] [KV Cache] support KV cache int8 per-channel quantization

Signed-off-by: evian
---
 .../modifiers/quantization/cache.py | 20 ++++++++++++++++++--
 src/llmcompressor/observers/base.py |  7 +++++--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/cache.py b/src/llmcompressor/modifiers/quantization/cache.py
index dd3640dda..4d99ae4e7 100644
--- a/src/llmcompressor/modifiers/quantization/cache.py
+++ b/src/llmcompressor/modifiers/quantization/cache.py
@@ -94,6 +94,14 @@ def update(
             _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
             _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
+        # reshape for the per-channel scenario
+        num_heads = key_states.shape[1]
+        head_dim = key_states.shape[-1]
+        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+        key_states = key_states.transpose(1, 2).flatten(2)
+        value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
@@ -106,6 +114,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
+        # reshape back for the per-channel scenario
+        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+        qdq_key_states = qdq_key_states.view(
+            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        qdq_value_states = qdq_value_states.view(
+            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
 
         return keys_to_return, values_to_return
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps
 
         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())
 
         q_tensor = quantize(
             x=tensor,
diff --git a/src/llmcompressor/observers/base.py b/src/llmcompressor/observers/base.py
index 3ee446cf3..e82ffc899 100644
--- a/src/llmcompressor/observers/base.py
+++ b/src/llmcompressor/observers/base.py
@@ -181,8 +181,11 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # 1. dim=2: KV cache quantization, where observed has shape
+                #    [batch_size, seq_len - residual_length, num_heads * head_dim]
+                # 2. dim=0: assume observed is transposed (it is an output), hence dim 0
+                dim = 2 if observed.dim() == 3 else 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim)
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume the obsersed.shape = [batch, token, hidden]
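
Note for reviewers: the following is a minimal, self-contained sketch in plain PyTorch
(not llmcompressor's observer / quantize API) of the reshape round-trip that the
update() hunks introduce: flatten [batch_size, num_heads, seq_len, head_dim] into
[batch_size, seq_len, num_heads * head_dim], compute one scale per channel along the
last dimension (the dim=2 case added in observers/base.py), fake-quantize, then restore
the original layout. The helper name and the symmetric int8 scheme are illustrative
assumptions, not the code paths this patch touches.

import torch


def fake_quant_kv_per_channel(states: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    """Hypothetical helper: per-channel int8 fake-quant of a KV tensor.

    Mirrors the reshape round-trip in QuantizedKVParameterCache.update(), with a
    simple symmetric scheme standing in for the observer and quantize/dequantize
    calls used by the real code.
    """
    batch_size, num_heads, seq_len, head_dim = states.shape

    # [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads * head_dim]
    flat = states.transpose(1, 2).flatten(2)

    # one scale per channel of the last dim, as in get_qparams_along_dim(observed, 2)
    qmax = 2 ** (n_bits - 1) - 1
    scale = flat.abs().amax(dim=(0, 1), keepdim=True).clamp_min(1e-8) / qmax
    dq = torch.clamp(torch.round(flat / scale), -qmax - 1, qmax) * scale

    # [batch, seq_len, num_heads * head_dim] -> [batch, num_heads, seq_len, head_dim]
    return dq.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)


# shape round-trip check: output layout matches the input KV layout
k = torch.randn(2, 8, 16, 64)
assert fake_quant_kv_per_channel(k).shape == k.shape

The per-channel scale above carries singleton batch and sequence dims, which is
analogous to why the patched _quantize() squeezes the observer's scale and zero point
before caching them.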