
Commit efc5f9c

Update src/llmcompressor/modifiers/quantization/cache.py

Per-tensor results are the same, so the reshape only applies to the per-channel strategy.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 692d1a5 commit efc5f9c

File tree

  • src/llmcompressor/modifiers/quantization/cache.py

1 file changed: +8 −7 lines
src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 8 additions & 7 deletions
@@ -115,13 +115,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )
 
-        # reshape for per channel scenario
-        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
-        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
-        qdq_key_states = qdq_key_states.view(
-            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
-        qdq_value_states = qdq_value_states.view(
-            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # reshape for per channel scenario
+            # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+            # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+            qdq_key_states = qdq_key_states.view(
+                qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+            qdq_value_states = qdq_value_states.view(
+                qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
 
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states

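Why the new strategy gate is safe: per-tensor quantization applies one global scale, so rearranging elements before or after quantization yields identical values, while per-channel scales depend on which axis is treated as the channel axis, so the [batch_size, num_heads, seq_len, head_dim] layout must be established first. Below is a minimal standalone sketch of that invariance; the fake_quant helper and the shape constants are illustrative assumptions, not llmcompressor's actual implementation.

import torch

def fake_quant(x: torch.Tensor, dim=None) -> torch.Tensor:
    # Stand-in symmetric int8 fake-quantize: one scale for the whole
    # tensor when dim is None, else one scale per slice along `dim`.
    if dim is None:
        scale = x.abs().amax() / 127
    else:
        dim = dim % x.dim()
        reduce_dims = tuple(d for d in range(x.dim()) if d != dim)
        scale = x.abs().amax(dim=reduce_dims, keepdim=True) / 127
    return torch.round(x / scale).clamp(-128, 127) * scale

batch_size, seq_len, num_heads, head_dim = 2, 5, 4, 8
flat = torch.randn(batch_size, seq_len, num_heads * head_dim)
# the reshape performed in the diff above
shaped = flat.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Per-tensor: a single global scale, so quantizing before or after the
# reshape gives identical results -- the reshape can be skipped.
assert torch.equal(
    fake_quant(flat),
    fake_quant(shaped).transpose(1, 2).reshape(flat.shape),
)

# Per-channel: the scale set depends on which axis counts as "channels"
# (num_heads * head_dim scales vs. head_dim scales), so layout matters
# and the reshape must happen before quantizing.
print(torch.equal(
    fake_quant(flat, dim=-1),
    fake_quant(shaped, dim=-1).transpose(1, 2).reshape(flat.shape),
))  # generally False

This is the observation behind the commit message: for the per-tensor strategy the results are the same either way, so the view/transpose is now performed only under QuantizationStrategy.CHANNEL.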