Commit 6121c2d
Author: evian
[KV Cache] support kv cache int8 per channel quantization
1 parent: 2c70cb0

2 files changed: 23 additions, 4 deletions

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 18 additions & 2 deletions
@@ -94,6 +94,14 @@ def update(
         _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
         _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
 
+        # reshape for per channel scenario
+        num_heads = key_states.shape[1]
+        head_dim = key_states.shape[-1]
+        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+        key_states = key_states.transpose(1, 2).flatten(2)
+        value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
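
Editor's note: as a quick illustration (not part of the commit), the transpose + flatten above folds the head and head-dim axes into a single trailing channel axis, which is what lets the observer compute one scale per channel. Sizes below are hypothetical:

    import torch

    # hypothetical sizes, for illustration only
    batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
    key_states = torch.randn(batch, num_heads, seq_len, head_dim)

    # [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads * head_dim]
    flat = key_states.transpose(1, 2).flatten(2)
    print(flat.shape)  # torch.Size([2, 16, 512])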
@@ -106,6 +114,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
+        # reshape for per channel scenario
+        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+        qdq_key_states = qdq_key_states.view(
+            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        qdq_value_states = qdq_value_states.view(
+            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
 
         return keys_to_return, values_to_return
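
Editor's note: a minimal sketch (assumed shapes, not the library's code) showing that the view + transpose applied after dequantization exactly inverts the earlier flatten, so callers still receive the usual [batch, num_heads, seq_len, head_dim] layout:

    import torch

    batch, num_heads, seq_len, head_dim = 2, 8, 16, 64  # hypothetical sizes
    states = torch.randn(batch, num_heads, seq_len, head_dim)

    flat = states.transpose(1, 2).flatten(2)                        # [2, 16, 512]
    restored = flat.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

    assert torch.equal(restored, states)  # pure reshape: lossless round trip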
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps
 
         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())
 
         q_tensor = quantize(
             x=tensor,
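
Editor's note: the new .squeeze() calls matter because a per-channel reduction typically returns scale/zero-point tensors with singleton dimensions (one value per channel, but shaped like the reduced tensor); squeezing stores them as flat per-channel vectors. A rough sketch under assumed shapes, using a stand-in absmax reduction rather than the library's observer:

    import torch

    x = torch.randn(2, 16, 512)  # [batch, seq_len, num_heads * head_dim]

    # symmetric int8 per-channel scale along the last dim; keepdim leaves
    # singleton dims that squeeze() removes
    scale = x.abs().amax(dim=(0, 1), keepdim=True) / 127.0
    print(scale.shape)            # torch.Size([1, 1, 512])
    print(scale.squeeze().shape)  # torch.Size([512])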

src/llmcompressor/observers/base.py

Lines changed: 5 additions & 2 deletions
@@ -181,8 +181,11 @@ def get_qparams(
                     self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                # assume observed is transposed, because its the output, hence use dim 0
-                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+                # 1. dim=2 scenario: KV cache quantization, where observed is
+                #    [batch_size, seq_len - residual_length, num_heads * head_dim]
+                # 2. dim=0 scenario: assume observed is transposed, because it's the output, hence use dim 0
+                dim = 2 if observed.dim() == 3 else 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim)
 
             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
                 # use dim 1, assume observed.shape = [batch, token, hidden]
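
Editor's note: a self-contained sketch of the dim selection this hunk introduces (a hypothetical min/max observer, not the library's implementation): rank-3 inputs are treated as the flattened KV-cache layout and reduced to one qparam pair per channel, while rank-2 outputs keep the old dim-0 behavior:

    import torch

    def channel_minmax_qparams(observed: torch.Tensor):
        # dim 2 for [batch, seq_len, channels] KV-cache inputs, dim 0 otherwise
        dim = 2 if observed.dim() == 3 else 0
        reduce_dims = [d for d in range(observed.dim()) if d != dim]
        min_val = observed.amin(dim=reduce_dims, keepdim=True)
        max_val = observed.amax(dim=reduce_dims, keepdim=True)
        scale = torch.maximum(max_val.abs(), min_val.abs()) / 127.0  # symmetric int8
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
        return scale, zero_point

    kv = torch.randn(2, 16, 512)  # [batch, seq_len, num_heads * head_dim]
    scale, zp = channel_minmax_qparams(kv)
    print(scale.squeeze().shape)  # torch.Size([512]) -> one scale per channel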
