Commit 4a19c04

v1/offloading: Add worker-side CPU support
This commit adds worker-side support for CPU offloading. It uses the swap_blocks function to perform the actual copying between CPU and GPU, and supports any CPU block size that is a multiple of the GPU block size.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
1 parent b1cfec2 commit 4a19c04
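For orientation, here is a minimal sketch (distilled from the new test below) of how the added pieces compose for a GPU-to-CPU offload. The layer name, sizes, and block ids are illustrative only, and a CUDA device plus the new offloading modules are assumed:

import torch

from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.offloading.worker.cpu import (create_cpu_tensors,
                                           generate_tensors_transfer_function)

# illustrative sizes: each CPU block holds 3 GPU blocks (48 = 3 * 16 tokens)
gpu_block_size, cpu_block_size, num_cpu_blocks = 16, 48, 256

# per-layer GPU KV caches, shaped by the attention backend
shape = FlashAttentionBackend.get_kv_cache_shape(64, gpu_block_size, 8, 64)
gpu_caches = {
    "layer 0": torch.rand(shape, dtype=torch.bfloat16, device="cuda:0"),
}

# allocate matching CPU tensors and build a GPU -> CPU transfer function
gpu_tensors, cpu_tensors = create_cpu_tensors(gpu_caches, gpu_block_size,
                                              cpu_block_size, num_cpu_blocks)
offload = generate_tensors_transfer_function(gpu_tensors, cpu_tensors,
                                             FlashAttentionBackend,
                                             gpu_block_size, cpu_block_size)

# offload GPU blocks 4, 7, 9 into CPU block 2 (block ids are illustrative)
src_specs = [GPULoadStoreSpec(b) for b in (4, 7, 9)]
dst_specs = [CPULoadStoreSpec(2)]
assert offload((src_specs, dst_specs)) is True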

File tree

3 files changed (+289, -0 lines)

tests/v1/offloading/test_cpu.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.offloading.worker.cpu import (create_cpu_tensors,
                                           generate_tensors_transfer_function)

NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16]
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
DTYPES = [torch.bfloat16]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']
NUM_MAPPINGS = [3]


@pytest.mark.parametrize("gpu_to_cpu", [True, False])
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
    gpu_to_cpu: bool,
    num_mappings: int,
    head_size: int,
    num_heads: int,
    gpu_block_size: int,
    gpu_blocks_per_cpu_block: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)

    # create per-layer GPU KV caches
    attn_backend = FlashAttentionBackend
    gpu_cache_shape = attn_backend.get_kv_cache_shape(num_gpu_blocks,
                                                      gpu_block_size,
                                                      num_heads, head_size)
    gpu_caches = {}
    for i in range(num_layers):
        gpu_caches[f'layer {i}'] = torch.rand(gpu_cache_shape,
                                              dtype=dtype,
                                              device=device)

    # create CPU KV caches
    cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
    gpu_tensors, cpu_tensors = create_cpu_tensors(gpu_caches, gpu_block_size,
                                                  cpu_block_size,
                                                  num_cpu_blocks)

    # select block mappings
    gpu_blocks = random.sample(range(num_gpu_blocks),
                               num_mappings * gpu_blocks_per_cpu_block)
    cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)

    # convert cpu blocks to gpu block size
    cpu_blocks_in_gpu_block_size = []
    for cpu_block in cpu_blocks:
        base_block_id = cpu_block * gpu_blocks_per_cpu_block
        for i in range(gpu_blocks_per_cpu_block):
            cpu_blocks_in_gpu_block_size.append(i + base_block_id)

    # set transfer direction
    if gpu_to_cpu:
        src_kv_caches = gpu_tensors
        dst_kv_caches = cpu_tensors
        src_block_size = gpu_block_size
        dst_block_size = cpu_block_size
        src_spec_class = GPULoadStoreSpec
        dst_spec_class = CPULoadStoreSpec
        src_blocks = gpu_blocks
        dst_blocks = cpu_blocks
        src_blocks_in_gpu_block_size = gpu_blocks
        dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
    else:
        src_kv_caches = cpu_tensors
        dst_kv_caches = gpu_tensors
        src_block_size = cpu_block_size
        dst_block_size = gpu_block_size
        src_spec_class = CPULoadStoreSpec
        dst_spec_class = GPULoadStoreSpec
        src_blocks = cpu_blocks
        dst_blocks = gpu_blocks
        src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_blocks_in_gpu_block_size = gpu_blocks
        dst_size_in_gpu_blocks = num_gpu_blocks

    # build dst -> src mapping
    dst_to_src = {}
    for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
                                    dst_blocks_in_gpu_block_size):
        dst_to_src[dst_block] = src_block

    # build transfer specs
    src_specs = [src_spec_class(block_id) for block_id in src_blocks]
    dst_specs = [dst_spec_class(block_id) for block_id in dst_blocks]

    # create transfer function
    transfer_func = generate_tensors_transfer_function(src_kv_caches,
                                                       dst_kv_caches,
                                                       attn_backend,
                                                       src_block_size,
                                                       dst_block_size)

    # clone src and dst tensors before transfer
    orig_src_caches = [x.clone() for x in src_kv_caches]
    orig_dst_caches = [x.clone() for x in dst_kv_caches]

    # call transfer function
    assert transfer_func((src_specs, dst_specs)) is True

    # verify src tensors did not change
    for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
        assert torch.equal(orig_tensor, tensor)

    # verify dst tensors
    for dst_block in range(dst_size_in_gpu_blocks):
        src_block_candidate = dst_to_src.get(dst_block)
        for src_cache, dst_cache, orig_dst_cache in zip(
                src_kv_caches, dst_kv_caches, orig_dst_caches):
            # iterate over key, value
            for i in range(2):
                if src_block_candidate is not None:
                    expected_value = src_cache[i][src_block_candidate]
                else:
                    expected_value = orig_dst_cache[i][dst_block]
                torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
                                           expected_value.cpu())

vllm/v1/attention/backends/flash_attn.py

Lines changed: 13 additions & 0 deletions
@@ -86,6 +86,19 @@ def get_kv_cache_shape(
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
+
     @staticmethod
     def get_kv_cache_stride_order() -> tuple[int, ...]:
         # `stride_order` indicates the permutation that gets
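The src_to_dst argument consumed by this new hook is the block-mapping tensor built by the worker code below: an int64 tensor of (source_block, destination_block) pairs, applied once to the key plane (kv_cache[0]) and once to the value plane (kv_cache[1]). As a hedged illustration of the intended semantics only (not the actual ops.swap_blocks kernel; the name swap_blocks_reference is hypothetical):

import torch

# illustrative mapping: copy source block 4 -> destination block 6, etc.
src_to_dst = torch.tensor([[4, 6], [7, 7], [9, 8]], dtype=torch.int64)


def swap_blocks_reference(src_kv_cache: torch.Tensor,
                          dst_kv_cache: torch.Tensor,
                          src_to_dst: torch.Tensor) -> None:
    # rough equivalent of the hook above: copy the listed blocks for both
    # the key plane ([0]) and the value plane ([1]), across devices
    src_ids = src_to_dst[:, 0].to(src_kv_cache.device)
    dst_ids = src_to_dst[:, 1].to(dst_kv_cache.device)
    for plane in range(2):  # 0 = keys, 1 = values
        dst_kv_cache[plane][dst_ids] = (
            src_kv_cache[plane][src_ids].to(dst_kv_cache.device))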

vllm/v1/offloading/worker/cpu.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator

import torch

from vllm.attention import AttentionBackend
from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.offloading.mediums import BlockIDLoadStoreSpec
from vllm.v1.offloading.worker.worker import TransferFunction, TransferSpec


def create_cpu_tensors(
    gpu_kv_caches: dict[str, torch.Tensor],
    gpu_block_size: int,
    cpu_block_size: int,
    num_cpu_blocks: int,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Create tensors for the CPU KV cache.

    Args:
        gpu_kv_caches: The per-layer GPU KV cache tensors
        gpu_block_size: Number of tokens per GPU block
        cpu_block_size: Number of tokens per CPU block
        num_cpu_blocks: The number of CPU blocks to allocate

    Note:
        - The GPU block size must divide the CPU block size.
        - The shape of the GPU KV cache must be (2, num_blocks, ...)

    Returns:
        Matching per-layer lists of (gpu_tensors, cpu_tensors).
    """
    assert cpu_block_size % gpu_block_size == 0

    gpu_tensors = []
    cpu_tensors = []
    for gpu_tensor in gpu_kv_caches.values():
        gpu_shape = gpu_tensor.shape
        assert len(gpu_shape) >= 4  # (2, num_blocks, ..., ...)
        assert gpu_shape[0] == 2

        cpu_shape = list(gpu_shape)
        cpu_shape[1] = num_cpu_blocks * (cpu_block_size // gpu_block_size)

        gpu_tensors.append(gpu_tensor)
        cpu_tensors.append(
            torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu"))

    return gpu_tensors, cpu_tensors


def block_ids(specs_list: list[LoadStoreSpec],
              block_size_factor: int) -> Iterator[int]:
    """
    Convert a list of BlockIDLoadStoreSpec to matching block ids,
    assuming each spec is composed of block_size_factor actual blocks.

    For example, if spec_list = [0, 1, 3] and block_size_factor = 4,
    then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15],
    since 0 maps to [0, 1, 2, 3],
    1 maps to [4, 5, 6, 7],
    and 3 maps to [12, 13, 14, 15].
    """
    for spec in specs_list:
        assert isinstance(spec, BlockIDLoadStoreSpec)
        base_block_id = spec.block_id * block_size_factor
        for i in range(block_size_factor):
            yield base_block_id + i


def generate_tensors_transfer_function(
    src_tensors: list[torch.Tensor],
    dst_tensors: list[torch.Tensor],
    attn_backend: type[AttentionBackend],
    src_block_size: int,
    dst_block_size: int,
) -> TransferFunction:
    """
    Generate a function for transferring from one KV cache to another.

    Args:
        src_tensors: the per-layer tensors of the source KV cache.
        dst_tensors: the per-layer tensors of the destination KV cache.
        attn_backend: the attention backend for both caches.
        src_block_size: the block size of the source KV cache.
        dst_block_size: the block size of the destination KV cache.

    Returns:
        A function for executing transfers between the caches.

    Note: one of src_block_size, dst_block_size must divide the other.
    """
    assert len(src_tensors) == len(dst_tensors)

    min_block_size = min(src_block_size, dst_block_size)
    max_block_size = max(src_block_size, dst_block_size)
    assert max_block_size % min_block_size == 0

    src_block_size_factor = src_block_size // min_block_size
    dst_block_size_factor = dst_block_size // min_block_size

    def transfer_function(spec: TransferSpec) -> bool:
        src_blocks_specs_list, dst_blocks_specs_list = spec

        assert (len(src_blocks_specs_list) *
                src_block_size_factor == len(dst_blocks_specs_list) *
                dst_block_size_factor)

        src_to_dst_list: list[tuple[int, int]] = list(
            zip(block_ids(src_blocks_specs_list, src_block_size_factor),
                block_ids(dst_blocks_specs_list, dst_block_size_factor)))
        src_to_dst = torch.tensor(src_to_dst_list,
                                  device="cpu",
                                  dtype=torch.int64).view(-1, 2)

        # iterate over layers
        for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
            attn_backend.swap_blocks(src_tensor, dst_tensor, src_to_dst)

        # always successful
        return True

    return transfer_function
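To make the mixed block-size bookkeeping concrete, here is a small sketch of the expansion performed by block_ids() for the case the test exercises, where one CPU block spans three GPU blocks; the block ids are illustrative:

from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.offloading.worker.cpu import block_ids

gpu_block_size, cpu_block_size = 16, 48

# both sides are expressed in units of the smaller (here: GPU) block size
min_block_size = min(gpu_block_size, cpu_block_size)   # 16
gpu_factor = gpu_block_size // min_block_size          # 1
cpu_factor = cpu_block_size // min_block_size          # 3

# offloading GPU blocks 4, 7, 9 into CPU block 2: each spec expands by its
# factor, and zipping the two streams yields the per-block swap mapping
gpu_ids = list(block_ids([GPULoadStoreSpec(b) for b in (4, 7, 9)],
                         gpu_factor))
cpu_ids = list(block_ids([CPULoadStoreSpec(2)], cpu_factor))
assert gpu_ids == [4, 7, 9]
assert cpu_ids == [6, 7, 8]   # CPU block 2 -> GPU-sized sub-blocks 6, 7, 8
assert list(zip(gpu_ids, cpu_ids)) == [(4, 6), (7, 7), (9, 8)]

In the same GPU-sized units, create_cpu_tensors sizes dim 1 of each CPU tensor as num_cpu_blocks * cpu_factor blocks.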
