Skip to content

Commit 71d1219

Browse files
authored
[Kernel] correct cpu worker function parameter type (#19745)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
1 parent e384f2f commit 71d1219

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

vllm/attention/ops/ipex_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_kv_cache_shape(
2929
head_size: int,
3030
*args,
3131
) -> Tuple[int, ...]:
32-
return (2, num_blocks, block_size * num_kv_heads * head_size)
32+
return 2, num_blocks, block_size * num_kv_heads * head_size
3333

3434
@staticmethod
3535
def split_kv_cache(

vllm/worker/cpu_worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""A CPU worker class."""
44
import os
55
from importlib import util
6-
from typing import Dict, List, Optional, Set, Tuple, Type
6+
from typing import List, Optional, Set, Tuple, Type
77

88
import torch
99
import torch.distributed
@@ -88,13 +88,13 @@ def _allocate_kv_cache(
8888
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
8989
return kv_cache
9090

91-
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
91+
def swap_in(self, src_to_dst: torch.Tensor) -> None:
9292
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
9393

94-
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
94+
def swap_out(self, src_to_dst: torch.Tensor) -> None:
9595
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
9696

97-
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
97+
def copy(self, src_to_dsts: torch.Tensor) -> None:
9898
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
9999

100100
@staticmethod

0 commit comments

Comments
 (0)