Commit 692d1a5

Update src/llmcompressor/modifiers/quantization/cache.py

Reshape key/value states only when the quantization strategy is per-channel; per-tensor scales are the same with or without the reshape.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

Parent: 0d948ea
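
The rationale in the commit message ("per tensor are same") can be checked directly: a per-tensor scale is derived from one statistic over the whole tensor, so the transpose/flatten reshape cannot change it. A minimal sketch, assuming a max-based observer and illustrative shapes (this is not code from the repo):

    import torch

    # [batch_size, num_heads, seq_len, head_dim] -- illustrative sizes
    key_states = torch.randn(2, 8, 16, 64)

    # Per-tensor observer: a single max over every element.
    scale_before = key_states.abs().amax()

    # The reshape that the diff now applies only for the CHANNEL strategy.
    reshaped = key_states.transpose(1, 2).flatten(2)

    # Reshaping permutes elements but removes none, so the max is unchanged.
    assert torch.equal(scale_before, reshaped.abs().amax())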

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 5 additions & 4 deletions
@@ -97,10 +97,11 @@ def update(
         # reshape for per channel scenario
         num_heads = key_states.shape[1]
         head_dim = key_states.shape[-1]
-        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
-        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
-        key_states = key_states.transpose(1, 2).flatten(2)
-        value_states = value_states.transpose(1, 2).flatten(2)
+        if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+            # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+            key_states = key_states.transpose(1, 2).flatten(2)
+            value_states = value_states.transpose(1, 2).flatten(2)
 
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
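
For the per-channel strategy the reshape does matter: flattening heads into the last dimension gives a channel-wise observer one scale per (head, head_dim) channel rather than one per head_dim position shared across heads. A minimal sketch with hypothetical shapes, not the observer code from cache.py:

    import torch

    batch_size, num_heads, seq_len, head_dim = 2, 8, 16, 64
    key_states = torch.randn(batch_size, num_heads, seq_len, head_dim)

    # from [batch_size, num_heads, seq_len, head_dim]
    # to   [batch_size, seq_len, num_heads * head_dim]
    flat = key_states.transpose(1, 2).flatten(2)

    # Channel-wise statistics over the flattened layout: one value per
    # num_heads * head_dim channel, shared across batch and sequence.
    per_channel_max = flat.abs().amax(dim=(0, 1))
    print(per_channel_max.shape)  # torch.Size([512])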
