|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import random |
| 4 | + |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.platforms import current_platform |
| 9 | +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend |
| 10 | +from vllm.v1.offloading.mediums import CPULoadStoreSpec, GPULoadStoreSpec |
| 11 | +from vllm.v1.offloading.worker.cpu import (create_cpu_tensors, |
| 12 | + generate_tensors_transfer_function) |
| 13 | + |
| 14 | +NUM_GPU_BLOCKS = [64] |
| 15 | +NUM_CPU_BLOCKS = [256] |
| 16 | +GPU_BLOCK_SIZES = [16] |
| 17 | +GPU_BLOCKS_PER_CPU_BLOCK = [1, 3] |
| 18 | +HEAD_SIZES = [64] |
| 19 | +NUM_HEADS = [8] |
| 20 | +NUM_LAYERS = [4] |
| 21 | +DTYPES = [torch.bfloat16] |
| 22 | +SEEDS = [0] |
| 23 | +CUDA_DEVICES = ['cuda:0'] |
| 24 | +NUM_MAPPINGS = [3] |
| 25 | + |
| 26 | + |
| 27 | +@pytest.mark.parametrize("gpu_to_cpu", [True, False]) |
| 28 | +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) |
| 29 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 30 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 31 | +@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES) |
| 32 | +@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK) |
| 33 | +@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS) |
| 34 | +@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS) |
| 35 | +@pytest.mark.parametrize("num_layers", NUM_LAYERS) |
| 36 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 37 | +@pytest.mark.parametrize("seed", SEEDS) |
| 38 | +@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 39 | +@torch.inference_mode() |
| 40 | +def test_transfer( |
| 41 | + gpu_to_cpu: bool, |
| 42 | + num_mappings: int, |
| 43 | + head_size: int, |
| 44 | + num_heads: int, |
| 45 | + gpu_block_size: int, |
| 46 | + gpu_blocks_per_cpu_block: int, |
| 47 | + num_gpu_blocks: int, |
| 48 | + num_cpu_blocks: int, |
| 49 | + num_layers: int, |
| 50 | + dtype: torch.dtype, |
| 51 | + seed: int, |
| 52 | + device: str, |
| 53 | +) -> None: |
| 54 | + current_platform.seed_everything(seed) |
| 55 | + |
| 56 | + # create per-layer GPU KV caches |
| 57 | + attn_backend = FlashAttentionBackend |
| 58 | + gpu_cache_shape = attn_backend.get_kv_cache_shape(num_gpu_blocks, |
| 59 | + gpu_block_size, |
| 60 | + num_heads, head_size) |
| 61 | + gpu_caches = {} |
| 62 | + for i in range(num_layers): |
| 63 | + gpu_caches[f'layer {i}'] = torch.rand(gpu_cache_shape, |
| 64 | + dtype=dtype, |
| 65 | + device=device) |
| 66 | + |
| 67 | + # create CPU KV caches |
| 68 | + cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size |
| 69 | + gpu_tensors, cpu_tensors = create_cpu_tensors(gpu_caches, gpu_block_size, |
| 70 | + cpu_block_size, |
| 71 | + num_cpu_blocks) |
| 72 | + |
| 73 | + # select block mappings |
| 74 | + gpu_blocks = random.sample(range(num_gpu_blocks), |
| 75 | + num_mappings * gpu_blocks_per_cpu_block) |
| 76 | + cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings) |
| 77 | + |
| 78 | + # convert cpu blocks to gpu block size |
| 79 | + cpu_blocks_in_gpu_block_size = [] |
| 80 | + for cpu_block in cpu_blocks: |
| 81 | + base_block_id = cpu_block * gpu_blocks_per_cpu_block |
| 82 | + for i in range(gpu_blocks_per_cpu_block): |
| 83 | + cpu_blocks_in_gpu_block_size.append(i + base_block_id) |
| 84 | + |
| 85 | + # set transfer direction |
| 86 | + if gpu_to_cpu: |
| 87 | + src_kv_caches = gpu_tensors |
| 88 | + dst_kv_caches = cpu_tensors |
| 89 | + src_block_size = gpu_block_size |
| 90 | + dst_block_size = cpu_block_size |
| 91 | + src_spec_class = GPULoadStoreSpec |
| 92 | + dst_spec_class = CPULoadStoreSpec |
| 93 | + src_blocks = gpu_blocks |
| 94 | + dst_blocks = cpu_blocks |
| 95 | + src_blocks_in_gpu_block_size = gpu_blocks |
| 96 | + dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size |
| 97 | + dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block |
| 98 | + else: |
| 99 | + src_kv_caches = cpu_tensors |
| 100 | + dst_kv_caches = gpu_tensors |
| 101 | + src_block_size = cpu_block_size |
| 102 | + dst_block_size = gpu_block_size |
| 103 | + src_spec_class = CPULoadStoreSpec |
| 104 | + dst_spec_class = GPULoadStoreSpec |
| 105 | + src_blocks = cpu_blocks |
| 106 | + dst_blocks = gpu_blocks |
| 107 | + src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size |
| 108 | + dst_blocks_in_gpu_block_size = gpu_blocks |
| 109 | + dst_size_in_gpu_blocks = num_gpu_blocks |
| 110 | + |
| 111 | + # build dst -> src mapping |
| 112 | + dst_to_src = {} |
| 113 | + for src_block, dst_block in zip(src_blocks_in_gpu_block_size, |
| 114 | + dst_blocks_in_gpu_block_size): |
| 115 | + dst_to_src[dst_block] = src_block |
| 116 | + |
| 117 | + # build transfer specs |
| 118 | + src_specs = [src_spec_class(block_id) for block_id in src_blocks] |
| 119 | + dst_specs = [dst_spec_class(block_id) for block_id in dst_blocks] |
| 120 | + |
| 121 | + # create transfer function |
| 122 | + transfer_func = generate_tensors_transfer_function(src_kv_caches, |
| 123 | + dst_kv_caches, |
| 124 | + attn_backend, |
| 125 | + src_block_size, |
| 126 | + dst_block_size) |
| 127 | + |
| 128 | + # clone src and dst tensors before transfer |
| 129 | + orig_src_caches = [x.clone() for x in src_kv_caches] |
| 130 | + orig_dst_caches = [x.clone() for x in dst_kv_caches] |
| 131 | + |
| 132 | + # call transfer function |
| 133 | + assert transfer_func((src_specs, dst_specs)) is True |
| 134 | + |
| 135 | + # verify src tensors did not change |
| 136 | + for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches): |
| 137 | + assert torch.equal(orig_tensor, tensor) |
| 138 | + |
| 139 | + # verify dst tensors |
| 140 | + for dst_block in range(dst_size_in_gpu_blocks): |
| 141 | + src_block_candidate = dst_to_src.get(dst_block) |
| 142 | + for src_cache, dst_cache, orig_dst_cache in zip( |
| 143 | + src_kv_caches, dst_kv_caches, orig_dst_caches): |
| 144 | + # iterate over key, value |
| 145 | + for i in range(2): |
| 146 | + if src_block_candidate is not None: |
| 147 | + expected_value = src_cache[i][src_block_candidate] |
| 148 | + else: |
| 149 | + expected_value = orig_dst_cache[i][dst_block] |
| 150 | + torch.testing.assert_close(dst_cache[i][dst_block].cpu(), |
| 151 | + expected_value.cpu()) |
0 commit comments