@@ -115,13 +115,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )

-        # reshape for per channel scenario
-        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
-        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
-        qdq_key_states = qdq_key_states.view(
-            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
-        qdq_value_states = qdq_value_states.view(
-            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        if self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # reshape for per channel scenario
+            # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+            # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+            qdq_key_states = qdq_key_states.view(
+                qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+            qdq_value_states = qdq_value_states.view(
+                qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)

         keys_to_return, values_to_return = qdq_key_states, qdq_value_states

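As a standalone illustration (not part of the patch), the sketch below reproduces the reshape that the new QuantizationStrategy.CHANNEL branch applies, using plain PyTorch and arbitrary example dimensions:

# Minimal sketch of the per-channel reshape; dimension values are illustrative only.
import torch

batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64

# Flat layout coming out of quantize/dequantize:
# [batch_size, seq_len, num_heads * head_dim]
qdq_key_states = torch.randn(batch_size, seq_len, num_heads * head_dim)

# Per-channel layout expected downstream:
# [batch_size, num_heads, seq_len, head_dim]
qdq_key_states = qdq_key_states.view(
    qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim
).transpose(1, 2)

assert qdq_key_states.shape == (batch_size, num_heads, seq_len, head_dim)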