v1/offloading: Add worker-side CPU support #21448

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -251,6 +251,7 @@ steps:
- pytest -v -s v1/core
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/offloading
- pytest -v -s v1/sample
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
151 changes: 151 additions & 0 deletions tests/v1/offloading/test_cpu.py
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.offloading.worker.cpu import (create_cpu_tensors,
generate_tensors_transfer_function)

NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16]
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
DTYPES = [torch.bfloat16]
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']
NUM_MAPPINGS = [3]


@pytest.mark.parametrize("gpu_to_cpu", [True, False])
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
gpu_to_cpu: bool,
num_mappings: int,
head_size: int,
num_heads: int,
gpu_block_size: int,
gpu_blocks_per_cpu_block: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
num_layers: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)

# create per-layer GPU KV caches
attn_backend = FlashAttentionBackend
gpu_cache_shape = attn_backend.get_kv_cache_shape(num_gpu_blocks,
gpu_block_size,
num_heads, head_size)
gpu_caches = {}
for i in range(num_layers):
gpu_caches[f'layer {i}'] = torch.rand(gpu_cache_shape,
dtype=dtype,
device=device)

# create CPU KV caches
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
gpu_tensors, cpu_tensors = create_cpu_tensors(gpu_caches, gpu_block_size,
cpu_block_size,
num_cpu_blocks)

# select block mappings
gpu_blocks = random.sample(range(num_gpu_blocks),
num_mappings * gpu_blocks_per_cpu_block)
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)

# convert cpu blocks to gpu block size
cpu_blocks_in_gpu_block_size = []
for cpu_block in cpu_blocks:
base_block_id = cpu_block * gpu_blocks_per_cpu_block
for i in range(gpu_blocks_per_cpu_block):
cpu_blocks_in_gpu_block_size.append(i + base_block_id)

# set transfer direction
if gpu_to_cpu:
src_kv_caches = gpu_tensors
dst_kv_caches = cpu_tensors
src_block_size = gpu_block_size
dst_block_size = cpu_block_size
src_spec_class = GPULoadStoreSpec
dst_spec_class = CPULoadStoreSpec
src_blocks = gpu_blocks
dst_blocks = cpu_blocks
src_blocks_in_gpu_block_size = gpu_blocks
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
else:
src_kv_caches = cpu_tensors
dst_kv_caches = gpu_tensors
src_block_size = cpu_block_size
dst_block_size = gpu_block_size
src_spec_class = CPULoadStoreSpec
dst_spec_class = GPULoadStoreSpec
src_blocks = cpu_blocks
dst_blocks = gpu_blocks
src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
dst_blocks_in_gpu_block_size = gpu_blocks
dst_size_in_gpu_blocks = num_gpu_blocks

# build dst -> src mapping
dst_to_src = {}
for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
dst_blocks_in_gpu_block_size):
dst_to_src[dst_block] = src_block

# build transfer specs
src_specs = [src_spec_class(block_id) for block_id in src_blocks]
dst_specs = [dst_spec_class(block_id) for block_id in dst_blocks]

# create transfer function
transfer_func = generate_tensors_transfer_function(src_kv_caches,
dst_kv_caches,
attn_backend,
src_block_size,
dst_block_size)

# clone src and dst tensors before transfer
orig_src_caches = [x.clone() for x in src_kv_caches]
orig_dst_caches = [x.clone() for x in dst_kv_caches]

# call transfer function
assert transfer_func((src_specs, dst_specs)) is True

# verify src tensors did not change
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
assert torch.equal(orig_tensor, tensor)

# verify dst tensors
for dst_block in range(dst_size_in_gpu_blocks):
src_block_candidate = dst_to_src.get(dst_block)
for src_cache, dst_cache, orig_dst_cache in zip(
src_kv_caches, dst_kv_caches, orig_dst_caches):
# iterate over key, value
for i in range(2):
if src_block_candidate is not None:
expected_value = src_cache[i][src_block_candidate]
else:
expected_value = orig_dst_cache[i][dst_block]
torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
expected_value.cpu())
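
Taken together, the helpers exercised above compose into a GPU-to-CPU offload in a few calls. The following is a minimal sketch (not part of the diff), assuming a CUDA device and the same FlashAttentionBackend layout as the test; the single-layer cache dict, block counts, and block sizes are arbitrary illustrative values:

import torch
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.offloading.worker.cpu import (create_cpu_tensors,
                                           generate_tensors_transfer_function)

# per-layer GPU KV cache in the FlashAttention layout:
# (2, num_blocks, block_size, num_kv_heads, head_size)
gpu_block_size = 16
cpu_block_size = 48  # one CPU block spans 3 GPU blocks
shape = FlashAttentionBackend.get_kv_cache_shape(64, gpu_block_size, 8, 64)
gpu_caches = {"layer 0": torch.rand(shape, dtype=torch.bfloat16, device="cuda:0")}

# allocate the CPU-side tensors and build a GPU -> CPU transfer function
gpu_tensors, cpu_tensors = create_cpu_tensors(gpu_caches, gpu_block_size,
                                              cpu_block_size, 256)
offload = generate_tensors_transfer_function(gpu_tensors, cpu_tensors,
                                             FlashAttentionBackend,
                                             gpu_block_size, cpu_block_size)

# copy GPU blocks 0..2 into CPU block 7; the function returns True on success
assert offload(([GPULoadStoreSpec(b) for b in (0, 1, 2)], [CPULoadStoreSpec(7)]))

Since one CPU block here holds three GPU blocks, three GPU source specs pair with a single CPU destination spec, mirroring the block-size conversion done in the test.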
140 changes: 140 additions & 0 deletions tests/v1/offloading/test_worker.py
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading

from vllm.v1.offloading.abstract import LoadStoreSpec
from vllm.v1.offloading.worker.worker import (OffloadingQueueManager,
TransferSpec)


class LoadStoreSpec1(LoadStoreSpec):

def __init__(self, success: bool = True, exception: bool = False):
self.called_event = threading.Event()
self.finished_event = threading.Event()
self.success = success
self.exception = exception

@staticmethod
def medium() -> str:
return "1"


class LoadStoreSpec2(LoadStoreSpec):

@staticmethod
def medium() -> str:
return "2"


def transfer_function_1_to_2(transfer_spec: TransferSpec) -> bool:
srcs, dsts = transfer_spec
assert len(srcs) == 1
assert len(dsts) == 1

src, dst = srcs[0], dsts[0]
assert isinstance(src, LoadStoreSpec1)
assert isinstance(dst, LoadStoreSpec2)

src.called_event.set()
src.finished_event.wait()
if src.exception:
raise Exception("An expected exception. Don't worry!")
return src.success


def transfer_function_2_to_1(transfer_spec: TransferSpec) -> bool:
srcs, dsts = transfer_spec
assert len(srcs) == 1
assert len(dsts) == 1

src, dst = srcs[0], dsts[0]
assert isinstance(src, LoadStoreSpec2)
assert isinstance(dst, LoadStoreSpec1)

dst.called_event.set()
dst.finished_event.wait()
if dst.exception:
raise Exception()
return dst.success


def test_offloading_queue_manager():
"""
Tests OffloadingQueueManager with 2 workers.
One worker performs 1->2 transfers, and the other handles 2->1.
"""
offloading_queue_manager = OffloadingQueueManager()
offloading_queue_manager.register_worker(LoadStoreSpec1, LoadStoreSpec2,
transfer_function_1_to_2)
offloading_queue_manager.register_worker(LoadStoreSpec2, LoadStoreSpec1,
transfer_function_2_to_1)

# 1st transfer 1->2 (exception)
src1 = LoadStoreSpec1(exception=True)
dst1 = LoadStoreSpec2()
offloading_queue_manager.transfer_async(1, ([src1], [dst1]))

# 2nd transfer 1->2 (failure)
src2 = LoadStoreSpec1(success=False)
dst2 = LoadStoreSpec2()
offloading_queue_manager.transfer_async(2, ([src2], [dst2]))

# 3rd transfer 1->2 (success)
src3 = LoadStoreSpec1()
dst3 = LoadStoreSpec2()
offloading_queue_manager.transfer_async(3, ([src3], [dst3]))

# 4th transfer 2->1
src4 = LoadStoreSpec2()
dst4 = LoadStoreSpec1()
offloading_queue_manager.transfer_async(4, ([src4], [dst4]))

# 1st transfer started
assert src1.called_event.wait(timeout=1)

# 4th transfer started
assert dst4.called_event.wait(timeout=1)

# 2nd transfer has not started (blocked by 1st)
assert not src2.called_event.is_set()

# no transfer completed yet
assert offloading_queue_manager.get_finished() == []

# complete 1st transfer
src1.finished_event.set()

# 2nd transfer started
assert src2.called_event.wait(timeout=1)

# 1st transfer finished with failure (exception)
assert offloading_queue_manager.get_finished() == [(1, False)]

# complete 2nd, 3rd and 4th transfers
src2.finished_event.set()
src3.finished_event.set()
dst4.finished_event.set()

# 5th transfer 1->2
src5 = LoadStoreSpec1()
dst5 = LoadStoreSpec2()
offloading_queue_manager.transfer_async(5, ([src5], [dst5]))

# 6th transfer 2->1
src6 = LoadStoreSpec2()
dst6 = LoadStoreSpec1()
offloading_queue_manager.transfer_async(6, ([src6], [dst6]))

# 5th and 6th transfers started
assert src5.called_event.wait(timeout=1)
assert dst6.called_event.wait(timeout=1)

# verify results of the 2nd, 3rd and 4th transfers
assert (sorted(offloading_queue_manager.get_finished()) == [(2, False),
(3, True),
(4, True)])

# complete 5th and 6th transfers
src5.finished_event.set()
dst6.finished_event.set()
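
Wiring the two new files together, the transfer function produced by generate_tensors_transfer_function can be handed to the OffloadingQueueManager so that copies run on a dedicated worker, off the caller's path. A sketch using only the manager API exercised by the test above (register_worker, transfer_async, get_finished); the offload function and the job id 1 are carried over from the earlier sketch and are illustrative, not part of this diff:

from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.offloading.worker.worker import OffloadingQueueManager

manager = OffloadingQueueManager()

# route GPU -> CPU transfers to the CPU copy function from the earlier sketch
manager.register_worker(GPULoadStoreSpec, CPULoadStoreSpec, offload)

# submit an asynchronous job: GPU blocks 0..2 -> CPU block 7
manager.transfer_async(1, ([GPULoadStoreSpec(b) for b in (0, 1, 2)],
                           [CPULoadStoreSpec(7)]))

# poll later for completed jobs; returns a list of (job_id, success) tuples
finished = manager.get_finished()  # e.g. [(1, True)]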
13 changes: 13 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
@@ -86,6 +86,19 @@ def get_kv_cache_shape(
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

@staticmethod
def get_kv_cache_stride_order() -> tuple[int, ...]:
# `stride_order` indicates the permutation that gets