 logger = init_logger(__name__)
 
+# TPU requires the head size to be a multiple of 128.
+TPU_HEAD_SIZE_ALIGNMENT = 128
+
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -43,6 +46,14 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> tuple[int, ...]:
+        padded_head_size = cdiv(
+            head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+        num_blocks = num_blocks * head_size // padded_head_size
+        if padded_head_size != head_size:
+            logger.warning_once(
+                "head size is padded to %d, and num_blocks is adjusted to %d"
+                " accordingly", padded_head_size, num_blocks)
+        head_size = padded_head_size
         return (num_blocks, block_size, num_kv_heads * 2, head_size)
 
     @staticmethod
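Note: the padding above trades head-dimension width for block count, so the overall KV-cache allocation stays roughly the same size. A minimal sketch of that arithmetic, with an illustrative head_size of 80 and a local stand-in for the cdiv helper imported in this file (the example values and the helper body are assumptions, not taken from this diff):

    TPU_HEAD_SIZE_ALIGNMENT = 128

    def cdiv(a: int, b: int) -> int:
        # Ceiling division; stand-in for vLLM's cdiv utility.
        return -(-a // b)

    head_size, num_blocks = 80, 1000          # illustrative values
    padded_head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    num_blocks = num_blocks * head_size // padded_head_size
    print(padded_head_size, num_blocks)       # 128 625  (625 * 128 == 1000 * 80)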
@@ -132,8 +143,6 @@ def __init__(
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if kv_cache_dtype != "auto":
@@ -187,6 +196,18 @@ def forward(
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            padded_head_size = cdiv(
+                self.head_size,
+                TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+            query = torch.nn.functional.pad(
+                query, (0, padded_head_size - self.head_size), value=0.0)
+            key = torch.nn.functional.pad(
+                key, (0, padded_head_size - self.head_size), value=0.0)
+            value = torch.nn.functional.pad(
+                value, (0, padded_head_size - self.head_size), value=0.0)
 
         if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
             # Write input keys and values to the KV cache.
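Note: torch.nn.functional.pad with a two-element pad tuple only widens the last dimension, so the calls above grow the head dimension to padded_head_size with zeros while leaving the token and head axes untouched. A small self-contained sketch with illustrative shapes (4 tokens, 2 KV heads, head size 80):

    import torch
    import torch.nn.functional as F

    key = torch.randn(4, 2, 80)                      # [tokens, kv_heads, head_size]
    padded = F.pad(key, (0, 128 - 80), value=0.0)    # pads the last dim on the right
    print(padded.shape)                              # torch.Size([4, 2, 128])
    print(padded[..., 80:].abs().sum().item())       # 0.0 -- the new columns are zeros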
@@ -213,6 +234,9 @@ def forward(
             soft_cap=self.logits_soft_cap,
         )
 
+        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
+            output = output[:, :, :self.head_size]
+
         return output.reshape(num_tokens, hidden_size)
 
 
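Note: slicing the extra columns back off is exact, not approximate: the zero-padded query/key columns contribute nothing to the attention scores, and the zero-padded value columns only ever produce zeros, assuming (as this backend appears to do with its precomputed scale) that the softmax scale is supplied explicitly rather than derived from the padded width. A sketch of that argument, using plain scaled_dot_product_attention as a stand-in for the Pallas kernel (requires PyTorch >= 2.1 for the scale argument; shapes are illustrative):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 1, 4, 80)     # [batch, heads, tokens, head_size]
    k = torch.randn(1, 1, 4, 80)
    v = torch.randn(1, 1, 4, 80)
    scale = 80 ** -0.5               # scale taken from the *original* head size
    pad = (0, 128 - 80)

    ref = F.scaled_dot_product_attention(q, k, v, scale=scale)
    padded = F.scaled_dot_product_attention(
        F.pad(q, pad), F.pad(k, pad), F.pad(v, pad), scale=scale)
    print(torch.allclose(ref, padded[..., :80], atol=1e-5))  # True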
@@ -231,11 +255,8 @@ def write_to_kv_cache(
 
     """
     _, _, num_combined_kv_heads, head_size = kv_cache.shape
-    num_kv_heads = num_combined_kv_heads // 2
-
-    key = key.view(-1, num_kv_heads, head_size)
-    value = value.view(-1, num_kv_heads, head_size)
-
+    head_size = cdiv(head_size,
+                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
     kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                   head_size)
 
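Note: key and value now arrive already reshaped and zero-padded from forward(), which is why the view() calls were dropped here, and head_size is re-derived from the cache and rounded up to the alignment so the reshape below always uses the padded width. The cat/reshape pair then interleaves K and V per head along the combined-heads axis. A small sketch with illustrative shapes (3 tokens, 2 KV heads, head size 4) showing that interleaving:

    import torch

    T, H, D = 3, 2, 4                               # tokens, kv_heads, (padded) head_size
    key = torch.arange(T * H * D, dtype=torch.float32).reshape(T, H, D)
    value = -key
    kv = torch.cat([key, value], axis=-1).reshape(-1, 2 * H, D)
    print(torch.equal(kv[:, 0::2, :], key))         # True -- even rows hold K per head
    print(torch.equal(kv[:, 1::2, :], value))       # True -- odd rows hold V per head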