# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128

+# Note: TPU can use fp8 as the storage dtype but doesn't support converting
+# from uint8 to fp32 directly. That's why it has a dtype mapping different
+# from the GPU one.
+TPU_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.half,
+    "bfloat16": torch.bfloat16,
+    "float": torch.float,
+    "fp8": torch.float8_e4m3fn,
+    "fp8_e4m3": torch.float8_e4m3fn,
+    "fp8_e5m2": torch.float8_e5m2,
+    "int8": torch.int8,
+    "uint8": torch.uint8,
+}
+

class PallasAttentionBackend(AttentionBackend):

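As background, here is a standalone sketch (not part of the diff) of how the mapping above resolves the user-facing kv_cache_dtype string; the resolve_kv_cache_dtype helper is hypothetical and only mirrors the lookup that __init__ performs further down:

import torch

TPU_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

def resolve_kv_cache_dtype(kv_cache_dtype: str):
    # "auto" means the KV cache keeps the model dtype, i.e. no quantization.
    if kv_cache_dtype == "auto":
        return None
    # Normalization mirrors the .lower().strip() call used in __init__.
    return TPU_STR_DTYPE_TO_TORCH_DTYPE.get(kv_cache_dtype.lower().strip())

print(resolve_kv_cache_dtype("FP8_E5M2 "))  # torch.float8_e5m2
print(resolve_kv_cache_dtype("auto"))       # None
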
@@ -156,8 +169,6 @@ def __init__(
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
-        if kv_cache_dtype != "auto":
-            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

@@ -170,6 +181,14 @@ def __init__(
        tpu_version = torch_xla.tpu.version()
        if tpu_version < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            if tpu_version < 5:
+                raise NotImplementedError(
+                    "FP8 KV cache dtype is only supported when TPU version"
+                    " is 5 or higher.")
+            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
+                kv_cache_dtype.lower().strip())

    def forward(
        self,
@@ -204,7 +223,6 @@ def forward(
            output = torch.ones_like(query)
            return output

-        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        num_tokens, hidden_size = query.shape
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -225,10 +243,21 @@ def forward(
            # Skip this if sharing KV cache with an earlier attention layer.
            slot_mapping = attn_metadata.slot_mapping
            write_to_kv_cache(
-                key, value, kv_cache, slot_mapping,
+                key,
+                value,
+                kv_cache,
+                slot_mapping,
                attn_metadata.num_slices_per_kv_cache_update_block,
-                attn_metadata.num_kv_update_slices)
-
+                attn_metadata.num_kv_update_slices,
+                self.kv_cache_quantized_dtype,
+                layer._k_scale_float,
+                layer._v_scale_float,
+            )
+
+        if self.kv_cache_quantized_dtype is not None and (
+                layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0):
+            raise ValueError(
+                "k_scale_float and v_scale_float must be non-zero")
        output = torch.ops.xla.ragged_paged_attention(
            query,
            kv_cache,
@@ -246,6 +275,8 @@ def forward(
            sm_scale=self.scale,
            sliding_window=self.sliding_window,
            soft_cap=self.logits_soft_cap,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
        )

        if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
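For context on the k_scale/v_scale values passed to the kernel above: a common per-tensor convention is to pick the scale so that the largest magnitude maps to the dtype's representable maximum. The sketch below is an illustration of that convention only, not necessarily how layer._k_scale_float is produced in this codebase:

import torch

def per_tensor_fp8_scale(x: torch.Tensor,
                         dtype: torch.dtype = torch.float8_e4m3fn) -> float:
    # Scale so the largest magnitude in x maps to the dtype's max value.
    # write_to_kv_cache then stores x / scale; readers multiply back by scale.
    amax = x.abs().max().item()
    return max(amax, 1e-12) / torch.finfo(dtype).max

key = torch.randn(16, 8, 128)
k_scale = per_tensor_fp8_scale(key)
info = torch.finfo(torch.float8_e4m3fn)
quantized = (key / k_scale).clamp(info.min, info.max).to(torch.float8_e4m3fn)
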
@@ -261,18 +292,32 @@ def write_to_kv_cache(
    slot_mapping: torch.Tensor,
    num_slices_per_kv_cache_update_block: int,
    num_kv_update_slices: torch.Tensor,
+    kv_cache_quantized_dtype: Optional[torch.dtype] = None,
+    k_scale: float = 1.0,
+    v_scale: float = 1.0,
) -> None:
    """Write the key and values to the KV cache.

    Args:
-        key: shape = [num_tokens, num_kv_heads * head_size]
-        value: shape = [num_tokens, num_kv_heads * head_size]
+        key: shape = [num_tokens, num_kv_heads, head_size]
+        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
        num_slices_per_kv_cache_update_block: int
    """
    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
    head_size = cdiv(head_size,
                     TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
+
+    if kv_cache_quantized_dtype is not None:
+        dtype_info = torch.finfo(kv_cache_quantized_dtype)
+        key = key.to(torch.float32) / k_scale
+        # NOTE: clamp is added here to avoid values out of the range of the
+        # quantized dtype.
+        key = torch.clamp(key, dtype_info.min, dtype_info.max)
+        key = key.to(kv_cache_quantized_dtype)
+        value = value.to(torch.float32) / v_scale
+        value = torch.clamp(value, dtype_info.min, dtype_info.max)
+        value = value.to(kv_cache_quantized_dtype)
+
    kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
                                                  head_size)
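To make the new storage path concrete, here is a minimal round-trip sketch that mirrors the quantization step above (scale, clamp to the dtype's range, cast to fp8) together with the inverse read-back step. The dequantize_kv helper is an illustration of what a consumer of the cache would do with the scale, not the Pallas kernel itself:

import torch

def quantize_kv(x: torch.Tensor, scale: float,
                dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Same recipe as write_to_kv_cache: scale, clamp to the dtype range, cast.
    info = torch.finfo(dtype)
    x = x.to(torch.float32) / scale
    return torch.clamp(x, info.min, info.max).to(dtype)

def dequantize_kv(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Inverse of the storage step: cast up, then undo the scaling.
    return x.to(torch.float32) * scale

key = torch.randn(4, 2, 128)
k_scale = 0.05
restored = dequantize_kv(quantize_kv(key, k_scale), k_scale)
print((restored - key).abs().max())  # small quantization error, not exact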