@@ -326,7 +326,54 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
326
326
return kv_cache
327
327
328
328
329
+ # We can move this function to a common utils file if it's also useful for other
330
+ # hardware.
331
def dtype_bits(dtype: torch.dtype) -> int:
    """Return the number of bits used by one element of ``dtype``.

    Floating-point and integer dtypes are resolved through ``torch.finfo``
    and ``torch.iinfo``. Complex dtypes are resolved by identity. Dtypes
    that those helpers reject (e.g. sub-byte ``torch.int4`` / ``torch.uint4``)
    fall back to parsing the numeric suffix of the dtype name.

    Args:
        dtype: The dtype whose element bit width is wanted.

    Returns:
        The bit width of a single element.

    Raises:
        TypeError: If the bit width cannot be determined for ``dtype``.
    """
    if dtype.is_floating_point:
        try:
            return torch.finfo(dtype).bits
        except TypeError:
            # Fall through to the name-based lookup below.
            pass
    elif dtype.is_complex:
        if dtype is torch.complex32:
            return 32
        elif dtype is torch.complex64:
            return 64
        elif dtype is torch.complex128:
            return 128
    else:
        try:
            return torch.iinfo(dtype).bits
        # torch.iinfo cannot support int4, int2, bits8...
        except TypeError:
            pass

    str_dtype = str(dtype)
    # Support torch.int4, torch.int5, torch.uint5... Parse the whole numeric
    # suffix rather than only the last character so that a multi-digit width
    # is not truncated, and so that a non-numeric suffix ends up in the
    # TypeError below instead of a ValueError from int().
    for prefix in ("torch.uint", "torch.int"):
        if str_dtype.startswith(prefix):
            suffix = str_dtype[len(prefix):]
            if suffix.isdigit():
                return int(suffix)
    raise TypeError(f"Getting the bit width of {dtype} is not supported")
355
+
356
+
357
def get_dtype_packing(dtype: torch.dtype) -> int:
    """Return how many elements of ``dtype`` fit into 32 bits.

    Args:
        dtype: The element dtype; its bit width is obtained from
            ``dtype_bits``.

    Returns:
        ``32 // bits``, where ``bits`` is the bit width of ``dtype``.

    Raises:
        ValueError: If the bit width of ``dtype`` does not evenly divide 32.
    """
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        # Both fragments must be f-strings; otherwise the dtype placeholder
        # is emitted literally instead of being interpolated.
        raise ValueError(
            f"The bit width must evenly divide 32, but got bits={bits}, "
            f"dtype={dtype}")
    return 32 // bits
364
+
365
+
329
366
def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Compute the byte size of a single KV-cache page.

    The head size is rounded up to a multiple of ``TPU_HEAD_SIZE_ALIGNMENT``
    and the combined K+V head count (``2 * num_kv_heads``) is rounded up to a
    multiple of the dtype packing factor before the size is computed.
    """
    aligned_head_size = (cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) *
                         TPU_HEAD_SIZE_ALIGNMENT)

    # NOTE: for the implicit padding in XLA
    pack_factor = get_dtype_packing(kv_cache_dtype)
    combined_heads = cdiv(num_kv_heads * 2, pack_factor) * pack_factor

    element_bits = dtype_bits(kv_cache_dtype)
    # Multiply out the total bit count first, then floor-divide into bytes.
    page_bits = block_size * combined_heads * aligned_head_size * element_bits
    return page_bits // 8
0 commit comments