@@ -94,6 +94,14 @@ def update(
         _pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
         _pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

+        # reshape for per channel scenario
+        num_heads = key_states.shape[1]
+        head_dim = key_states.shape[-1]
+        # from [batch_size, num_heads, seq_len - residual_length, head_dim]
+        # to [batch_size, seq_len - residual_length, num_heads * head_dim]
+        key_states = key_states.transpose(1, 2).flatten(2)
+        value_states = value_states.transpose(1, 2).flatten(2)
+
         q_key_states = self._quantize(
             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
         )
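Context for this hunk: folding the head dimension into the channel axis lets a per-channel observer treat each of the `num_heads * head_dim` slots as its own channel. A minimal sketch of the shape change, assuming PyTorch tensors and purely illustrative sizes:

```python
import torch

# illustrative shapes only: batch=2, num_heads=4, seq_len=8, head_dim=16
key_states = torch.randn(2, 4, 8, 16)

# [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads * head_dim]
# so a per-channel observer sees num_heads * head_dim channels in the last dim
flat = key_states.transpose(1, 2).flatten(2)
print(flat.shape)  # torch.Size([2, 8, 64])
```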
@@ -106,6 +114,14 @@ def update(
             q_value_states, KVCacheScaleType.VALUE, layer_idx
         )

+        # reshape for per channel scenario
+        # from [batch_size, seq_len - residual_length, num_heads * head_dim]
+        # to [batch_size, num_heads, seq_len - residual_length, head_dim]
+        qdq_key_states = qdq_key_states.view(
+            qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
+        qdq_value_states = qdq_value_states.view(
+            qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
+
         keys_to_return, values_to_return = qdq_key_states, qdq_value_states

         return keys_to_return, values_to_return
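This hunk undoes the earlier flattening after quantize/dequantize so the attention module gets back the layout it expects. Under the same assumed shapes, the `view` + `transpose` pair is an exact inverse of the `transpose` + `flatten` above:

```python
import torch

num_heads, head_dim = 4, 16
x = torch.randn(2, num_heads, 8, head_dim)

# forward reshape (as in update) ...
flat = x.transpose(1, 2).flatten(2)
# ... and the inverse view + transpose restores the original layout
restored = flat.view(flat.shape[0], flat.shape[1], num_heads, head_dim).transpose(1, 2)

assert torch.equal(restored, x)  # lossless round trip
```

Note that `flatten(2)` on the transposed tensor materializes a contiguous copy, which is why the later `.view` is safe here.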
@@ -155,8 +171,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
             zps = self.v_zps

         scale, zp = observer(tensor)
-        _pad_and_append_at_idx_(scales, layer_idx, scale)
-        _pad_and_append_at_idx_(zps, layer_idx, zp)
+        _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
+        _pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())

         q_tensor = quantize(
             x=tensor,
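The `.squeeze()` change matters because a per-channel observer typically returns scale/zero-point tensors carrying singleton dimensions, while the cache stores flat per-layer vectors. A hedged illustration (the `[num_channels, 1]` observer output shape is an assumption, not something this diff confirms):

```python
import torch

# assumed observer output: per-channel scale/zp with a trailing singleton dim
scale = torch.rand(64, 1)
zp = torch.zeros(64, 1, dtype=torch.int32)

print(scale.squeeze().shape)  # torch.Size([64])
print(zp.squeeze().shape)     # torch.Size([64])
```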