From 1daa1c3541c62f64dd14d05f373f1a2f6a615fdf Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Fri, 4 Apr 2025 17:26:00 +0000 Subject: [PATCH 1/7] apply patch for dynamo --- vllm/config.py | 39 +- vllm/core/block/cpu_gpu_block_allocator.py | 10 +- vllm/core/block/naive_block.py | 9 +- vllm/core/block/prefix_caching_block.py | 24 +- vllm/core/block_manager.py | 29 +- vllm/core/event_manager.py | 108 +++++ vllm/core/scheduler.py | 149 ++++++- .../device_communicators/kv_rearrange.py | 108 +++++ vllm/distributed/device_communicators/nixl.py | 414 ++++++++++++++++++ .../kv_connector/dynamo_connector.py | 350 +++++++++++++++ .../kv_transfer/kv_connector/factory.py | 11 +- .../kv_connector/simple_connector.py | 159 +++++-- .../kv_lookup_buffer/simple_buffer.py | 154 ++++--- vllm/distributed/kv_transfer/kv_pipe/base.py | 4 +- .../kv_transfer/kv_pipe/dynamo_nccl_pipe.py | 124 ++++++ .../kv_transfer/kv_pipe/pynccl_pipe.py | 75 ++-- .../kv_transfer/kv_transfer_agent.py | 3 +- vllm/distributed/parallel_state.py | 3 +- vllm/engine/llm_engine.py | 183 +++++++- vllm/engine/multiprocessing/__init__.py | 21 +- vllm/engine/multiprocessing/client.py | 130 +++++- vllm/engine/multiprocessing/engine.py | 146 +++++- vllm/entrypoints/openai/serving_chat.py | 3 + vllm/envs.py | 20 + vllm/model_executor/models/deepseek_v2.py | 2 + vllm/outputs.py | 4 +- vllm/remote_prefill.py | 66 +++ vllm/sampling_params.py | 2 +- vllm/sequence.py | 34 +- vllm/worker/model_runner.py | 6 + vllm/worker/worker.py | 53 ++- vllm/worker/worker_base.py | 152 +++++-- 32 files changed, 2343 insertions(+), 252 deletions(-) create mode 100644 vllm/core/event_manager.py create mode 100644 vllm/distributed/device_communicators/kv_rearrange.py create mode 100644 vllm/distributed/device_communicators/nixl.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py create mode 100644 vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py create mode 100644 vllm/remote_prefill.py diff --git a/vllm/config.py b/vllm/config.py index 2912361ee35e..99261d13798d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3091,6 +3091,9 @@ class KVTransferConfig(BaseModel): # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + # Whether to use NIXL prepped xfer for KV cache transfer. + use_prepped_xfer: bool = False + # The device used by kv connector to buffer the KV cache. # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" @@ -3100,7 +3103,7 @@ class KVTransferConfig(BaseModel): kv_buffer_size: float = 1e9 # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. + # are 'kv_producer', 'kv_consumer', and 'kv_both'. kv_role: Optional[str] = None # The rank of this vLLM instance in the KV cache transfer. Typical value: @@ -3121,6 +3124,13 @@ class KVTransferConfig(BaseModel): # any extra config that the connector may need kv_connector_extra_config: dict[str, Any] = {} + # This does not need to be set by the user. It is set by the connector. 
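+    # Worked example (hypothetical topology, for illustration only): two prefill
+    # (producer) workers with TP=2 feeding decode (consumer) workers with TP=4
+    # and kv_parallel_size=6 give kv_consumers_parallel_size=4,
+    # tensor_parallel_multiplier=2 and kv_world_size = 2 + 4 * 2 = 10
+    # (see the properties defined further below).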
+ kv_producers_parallel_size: Optional[int] = None + kv_producers_tensor_parallel_size: Optional[int] = None + kv_producers_pipeline_parallel_size: Optional[int] = None + kv_consumers_tensor_parallel_size: Optional[int] = None + kv_consumers_pipeline_parallel_size: Optional[int] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -3155,16 +3165,29 @@ def model_post_init(self, __context: Any) -> None: f"Supported roles are `kv_producer`, `kv_consumer`, " f"and `kv_both`") - if self.kv_connector is not None and self.kv_role is None: + if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " "is set, supported roles are `kv_producer`, " "`kv_consumer`, and `kv_both`") + if self.use_prepped_xfer is False: + logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.") + self.use_prepped_xfer = True + + @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + @property + def need_kv_parallel_group(self) -> bool: + # for those database-based connector, vLLM does not need to create + # parallel group, and in that case the kv parallel size will be 1. + if self.kv_connector == "DynamoNixlConnector": + return False + return self.kv_connector is not None and self.kv_parallel_size > 1 + @property def is_kv_producer(self) -> bool: return self.kv_connector is not None and \ @@ -3178,6 +3201,18 @@ def is_kv_consumer(self) -> bool: def get_from_extra_config(self, key, default) -> Any: return self.kv_connector_extra_config.get(key, default) + @property + def tensor_parallel_multiplier(self) -> int: + return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size + + @property + def kv_consumers_parallel_size(self) -> int: + return self.kv_parallel_size - self.kv_producers_parallel_size + + @property + def kv_world_size(self) -> int: + return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier + class CompilationLevel: # constants for the levels of the compilation process diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index d64142e77f37..9403b13b4f8b 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -6,6 +6,7 @@ DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.core.event_manager import KVCacheEventManager from vllm.platforms import current_platform from vllm.utils import Device @@ -28,6 +29,7 @@ def create( num_gpu_blocks: int, num_cpu_blocks: int, block_size: int, + event_manager: Optional[KVCacheEventManager] = None, ) -> DeviceAwareBlockAllocator: """Creates a CpuGpuBlockAllocator instance with the specified configuration. @@ -64,6 +66,7 @@ def create( cpu_block_ids = block_ids[num_gpu_blocks:] if allocator_type == "naive": + assert event_manager is None, "Event API not supported with naive allocator." 
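+            # KV cache events are keyed by block content hashes, which only the
+            # prefix-caching allocator tracks, so event publishing cannot be
+            # combined with the naive allocator.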
gpu_allocator: BlockAllocator = NaiveBlockAllocator( create_block=NaiveBlock, # type: ignore num_blocks=num_gpu_blocks, @@ -82,12 +85,14 @@ def create( num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, + event_manager=event_manager, ) cpu_allocator = PrefixCachingBlockAllocator( num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, + event_manager=event_manager, ) else: raise ValueError(f"Unknown allocator type {allocator_type=}") @@ -95,10 +100,12 @@ def create( return CpuGpuBlockAllocator( cpu_block_allocator=cpu_allocator, gpu_block_allocator=gpu_allocator, + event_manager=event_manager, ) def __init__(self, cpu_block_allocator: BlockAllocator, - gpu_block_allocator: BlockAllocator): + gpu_block_allocator: BlockAllocator, + event_manager: Optional[KVCacheEventManager] = None,): assert not ( cpu_block_allocator.all_block_ids & gpu_block_allocator.all_block_ids @@ -108,6 +115,7 @@ def __init__(self, cpu_block_allocator: BlockAllocator, Device.CPU: cpu_block_allocator, Device.GPU: gpu_block_allocator, } + self.event_manager = event_manager self._swap_mapping: Dict[int, int] = {} self._null_block: Optional[Block] = None diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index c388366b825f..31ed7aa44ada 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -2,7 +2,7 @@ from collections import deque from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union - +import heapq from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device @@ -38,7 +38,7 @@ def __init__( if block_ids is None: block_ids = range(num_blocks) - self._free_block_indices: Deque[BlockId] = deque(block_ids) + self._free_block_indices: List[BlockId] = list(block_ids) self._all_block_indices = frozenset(block_ids) assert len(self._all_block_indices) == num_blocks @@ -134,7 +134,8 @@ def _allocate_block_id(self) -> BlockId: if not self._free_block_indices: raise BlockAllocator.NoFreeBlocksError() - block_id = self._free_block_indices.popleft() + block_id = heapq.heappop(self._free_block_indices) + # TODO: figure out why sometime block_id is None self._refcounter.incr(block_id) return block_id @@ -148,7 +149,7 @@ def _free_block_id(self, block: Union[Block, BlockId]) -> None: refcount = self._refcounter.decr(block_id) if refcount == 0: - self._free_block_indices.appendleft(block_id) + heapq.heappush(self._free_block_indices, block_id) def free(self, block: Block, keep_block_object: bool = False) -> None: # Release the physical block id diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 1ca9e49dac37..cd780f698859 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -4,7 +4,7 @@ from bisect import bisect_left from os.path import commonprefix from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, - Tuple) + Tuple, TYPE_CHECKING) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -23,6 +23,9 @@ # then we know this block hasn't been accessed yet. 
_DEFAULT_LAST_ACCESSED_TIME = -1 +if TYPE_CHECKING: + from vllm.core.event_manager import KVCacheEventManager + logger = init_logger(__name__) @@ -80,6 +83,7 @@ def __init__( block_size: int, block_ids: Optional[Iterable[int]] = None, eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + event_manager: Optional["KVCacheEventManager"] = None, ): if block_ids is None: block_ids = range(num_blocks) @@ -131,6 +135,9 @@ def __init__( self.metric_data = CacheMetricData() + self.event_manager = event_manager + + # Implements Block.Factory. def _create_block( self, prev_block: Optional[Block], @@ -337,6 +344,9 @@ def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: assert self._refcounter.get(_block_id) == 0 assert _block_id == block_id + if self.event_manager: + self.event_manager.enqueue_removed_event(content_hash_to_evict) + self._cached_blocks.pop(content_hash_to_evict) self._refcounter.incr(block_id) @@ -513,6 +523,10 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: # Mark this block as touched so that it can be marked as # computed after the entire batch of sequences are scheduled. self._touched_blocks.add(block.block_id) + + if self.event_manager: + self.event_manager.enqueue_stored_event(block.prev_block, block) + return block.block_id # Reuse the cached content hash @@ -579,9 +593,11 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: # Mark all touched blocks as computed. - for block_id in self._touched_blocks: - self._block_tracker[block_id].computed = True - self._touched_blocks.clear() + for block_id in block_ids: + if block_id in self._touched_blocks: + logger.debug("Mark block as computed: %s", block_id) + self._block_tracker[block_id].computed = True + self._touched_blocks.remove(block_id) def _track_block_id(self, block_id: Optional[BlockId], computed: bool) -> None: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index c6bf6d163132..b922adf87db1 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -10,7 +10,10 @@ from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, LastAccessBlocksTracker) from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.event_manager import KVCacheEventManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE, + VLLM_WORKER_ID) from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): def __init__( self, + model_name: str, block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, @@ -91,11 +95,29 @@ def __init__( self.watermark_blocks = int(watermark * num_gpu_blocks) + kv_event_manager_params = [ + VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE, + VLLM_KV_COMPONENT + ] + set_kv_event_manager_params = len( + [param for param in kv_event_manager_params if param is not None]) + + if set_kv_event_manager_params == len(kv_event_manager_params): + self.event_manager = KVCacheEventManager( + namespace=VLLM_KV_NAMESPACE, + component=VLLM_KV_COMPONENT, + worker_id=VLLM_WORKER_ID, + lib_path=VLLM_KV_CAPI_PATH, + kv_block_size=block_size) + else: + self.event_manager = None + self.block_allocator = CpuGpuBlockAllocator.create( allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, 
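+            # The optional event manager is forwarded so the prefix-caching
+            # allocators can publish stored/removed KV block events.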
block_size=block_size, + event_manager=self.event_manager, ) self.block_tables: Dict[SeqId, BlockTable] = {} @@ -108,7 +130,8 @@ def __init__( def can_allocate(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: + num_lookahead_slots: int = 0, + is_remote_decode: bool = False) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. @@ -121,6 +144,10 @@ def can_allocate(self, num_lookahead_slots=num_lookahead_slots, ) + # if remote decode, we need to allocate twice as many blocks for staging + if is_remote_decode: + num_required_blocks *= 2 + if seq_group.is_encoder_decoder(): encoder_seq = seq_group.get_encoder_seq() assert encoder_seq is not None diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py new file mode 100644 index 000000000000..a27af5808a42 --- /dev/null +++ b/vllm/core/event_manager.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +import ctypes +import logging +import uuid +from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64 +from typing import Optional + +from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash + +logger = logging.getLogger(__name__) + + +class DynamoResult: + OK = 0 + ERR = 1 + + +class KVCacheEventManager: + + def __init__(self, namespace: str, component: str, worker_id: int, + lib_path: str, kv_block_size: int): + self.lib = None + + try: + self.lib = ctypes.CDLL(lib_path) + self.lib.dynamo_llm_init.argtypes = [ + c_char_p, + c_char_p, + c_int64, + c_uint32, + ] + self.lib.dynamo_llm_init.restype = c_uint32 + + result = self.lib.dynamo_llm_init( + namespace.encode(), component.encode(), worker_id, kv_block_size + ) + if result == DynamoResult.OK: + logger.info( + "KVCacheEventManager initialized successfully. 
Ready to publish KV Cache Events" + ) + else: + logger.info("KVCacheEventManager initialization failed!") + + except Exception as e: + print(f"Failed to load {lib_path}") + raise e + + self.lib.dynamo_kv_event_publish_stored.argtypes = [ + ctypes.c_uint64, # event_id + ctypes.POINTER(ctypes.c_uint32), # token_ids + ctypes.POINTER(ctypes.c_size_t), # num_block_tokens + ctypes.POINTER(ctypes.c_uint64), # block_ids + ctypes.c_size_t, # num_blocks + ctypes.POINTER(ctypes.c_uint64), # parent_hash + ctypes.c_uint64, # lora_id + ] + self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t + + self.lib.dynamo_kv_event_publish_removed.argtypes = [ + ctypes.c_uint64, # event_id + ctypes.POINTER(ctypes.c_uint64), # block_ids + ctypes.c_size_t, # num_blocks + ] + self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t + + self.event_id_counter = 0 + + def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock], + block: PrefixCachingBlock): + token_ids_arr = (ctypes.c_uint32 * + len(block.token_ids))(*block.token_ids) + num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids)) + block_hash = (ctypes.c_uint64 * 1)(block.content_hash) + parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash) + if parent is not None else None) + + # Publish the event + result = self.lib.dynamo_kv_event_publish_stored( + self.event_id_counter, # uint64_t event_id + token_ids_arr, # const uint32_t *token_ids + num_block_tokens, # const uintptr_t *num_block_tokens + block_hash, # const uint64_t *block_ids + 1, # uintptr_t num_blocks + parent_hash, # const uint64_t *parent_hash + 0, # uint64_t lora_id + ) + + if result == DynamoResult.OK: + logger.debug(f"Store - Published KV Event: {block.content_hash}") + else: + logger.debug( + f"Store - Failed to Publish KV Event: {block.content_hash}") + + self.event_id_counter += 1 + + def enqueue_removed_event(self, block_hash: PrefixHash): + result = self.lib.dynamo_kv_event_publish_removed( + self.event_id_counter, + (ctypes.c_uint64 * 1)(block_hash), + 1, + ) + + if result == DynamoResult.OK: + logger.debug(f"Remove - Published KV Event: {block_hash}") + else: + logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}") + + self.event_id_counter += 1 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index cf85a2135c81..fb8ce03b574e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,13 +4,14 @@ import os import random import time +import copy from collections import deque from dataclasses import dataclass, field from typing import Callable, Deque, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence -from typing import Set, Tuple, Union +from typing import Set, Tuple, Union, Any -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -292,6 +293,7 @@ class SchedulerPrefillOutputs: # Ignored sequence groups. 
ignored_seq_groups: List[SequenceGroup] num_lookahead_slots: int + num_remote_prefill_groups: int @classmethod def create_empty(cls) -> "SchedulerPrefillOutputs": @@ -299,6 +301,7 @@ def create_empty(cls) -> "SchedulerPrefillOutputs": seq_groups=[], ignored_seq_groups=[], num_lookahead_slots=0, + num_remote_prefill_groups=0, ) @@ -426,12 +429,14 @@ class Scheduler: def __init__( self, + model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, output_proc_callback: Optional[Callable] = None, ) -> None: + self.model_config = model_config self.scheduler_config = scheduler_config self.cache_config = cache_config # Note for LoRA scheduling: the current policy is extremely @@ -457,6 +462,7 @@ def __init__( # Create the block space manager. self.block_manager = BlockSpaceManagerImpl( + model_name=self.model_config.served_model_name, block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, @@ -473,6 +479,16 @@ def __init__( # Sequence groups in the SWAPPED state. # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() + + # Sequence groups in the REMOTE_PREFILLING state. + # Contain requests that are being prefilled by a remote worker. + self.remote_prefilling: Deque[SequenceGroup] = deque() + # Contain requests that are being prefilled by a local worker. + self.prefill_sending: Deque[SequenceGroup] = deque() + + self._remote_prefill_outputs: Dict[str, int] = {} + + # Sequence groups finished requests ids since last step iteration. # It lets the model know that any state associated with these requests # can and must be released after the current step. @@ -628,8 +644,8 @@ def _free_seq_group_cross_attn_blocks( self.block_manager.free_cross(seq_group) def has_unfinished_seqs(self) -> bool: - return (len(self.waiting) != 0 or len(self.running) != 0 - or len(self.swapped) != 0) + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0 def get_prefix_cache_hit_rate(self, device: Device) -> float: return self.block_manager.get_prefix_cache_hit_rate(device) @@ -652,6 +668,8 @@ def _schedule_running( curr_loras: Optional[Set[int]], enable_chunking: bool = False, partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, + finished_prefills: Optional[Set[str]] = None, + finished_transfers: Optional[Set[str]] = None ) -> SchedulerRunningOutputs: """Schedule sequence groups that are running. @@ -668,6 +686,8 @@ def _schedule_running( all tokens. partial_prefill_metadata: information about the partial prefills that are currently running + finished_remote_prefill_request_ids: Set of request ids of remote + prefills that have finished. Returns: SchedulerRunningOutputs. 
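+
+        Example (illustrative sketch only; simplified names, not the vLLM API):
+        the remote-prefill queues below are drained by finished request id
+        using this pattern:
+
+            from collections import deque
+
+            def drain_finished(queue: deque, finished_ids: set) -> list:
+                done, pending = [], deque()
+                while queue:
+                    req_id = queue.popleft()
+                    if req_id in finished_ids:
+                        finished_ids.discard(req_id)
+                        done.append(req_id)     # promote to RUNNING / free, as below
+                    else:
+                        pending.append(req_id)  # still waiting on the remote worker
+                queue.extend(pending)
+                return done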
@@ -697,6 +717,38 @@ def _schedule_running( preempted: List[SequenceGroup] = ret.preempted swapped_out: List[SequenceGroup] = ret.swapped_out + remote_prefilling_queue = self.remote_prefilling + leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque() + while remote_prefilling_queue: + seq_group = remote_prefilling_queue.popleft() + if seq_group.request_id not in finished_prefills: + leftover_remote_prefilling_sequences.append(seq_group) + continue + + else: + finished_prefills.remove(seq_group.request_id) + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + # we computed all but the last token in prefill, we need to decode the first token on decode + seq_group.update_num_computed_tokens(seq.get_len() - 1) + seq.status = SequenceStatus.RUNNING + seq.data._stage = SequenceStage.DECODE + self.running.appendleft(seq_group) + remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences) + + remote_transfers_queue = self.prefill_sending + leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque() + while remote_transfers_queue: + seq_group = remote_transfers_queue.popleft() + if seq_group.request_id not in finished_transfers: + leftover_remote_transfers_sequences.append(seq_group) + else: + finished_transfers.remove(seq_group.request_id) + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + self.free_seq(seq) + remote_transfers_queue.extendleft(leftover_remote_transfers_sequences) + running_queue = self.running assert len(self._async_stopped) == 0 while running_queue: @@ -1073,6 +1125,7 @@ def _schedule_prefills( seq_groups: List[ScheduledSequenceGroup] = [] waiting_queue = self.waiting + num_remote_prefill_groups = 0 leftover_waiting_sequences: Deque[SequenceGroup] = deque() while self._passed_delay(time.time()) and waiting_queue: @@ -1121,8 +1174,10 @@ def _schedule_prefills( True, enable_chunking) # If the sequence group cannot be allocated, stop. 
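+                # Remote-decode requests reserve twice the usual number of blocks so a
+                # staging copy of the prefilled KV cache can live next to the local one
+                # (see SelfAttnBlockSpaceManager.can_allocate earlier in this patch).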
+ is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode can_allocate = self.block_manager.can_allocate( - seq_group, num_lookahead_slots=num_lookahead_slots) + seq_group, num_lookahead_slots=num_lookahead_slots, + is_remote_decode=is_remote_decode) if can_allocate == AllocStatus.LATER: break elif can_allocate == AllocStatus.NEVER: @@ -1170,7 +1225,18 @@ def _schedule_prefills( if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) waiting_queue.popleft() - self._allocate_and_set_running(seq_group) + + seq_group_copy = copy.deepcopy(seq_group) + seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1 + + logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id) + logger.debug("Seq id: %s", seq_group.seqs[0].seq_id) + is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group) + num_remote_prefill_groups += is_remote_prefill + if is_remote_decode: + logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id) + self._allocate_and_set_running_or_remote_prefill(seq_group_copy) + self.prefill_sending.append(seq_group_copy) if partial_prefill_metadata is not None: partial_prefill_metadata.maybe_increment_partial_prefills( @@ -1214,9 +1280,10 @@ def _schedule_prefills( ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=self._get_num_lookahead_slots( is_prefill=True, enable_chunking=enable_chunking), + num_remote_prefill_groups=num_remote_prefill_groups ) - def _schedule_default(self) -> SchedulerOutputs: + def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: """Schedule queued requests. The current policy is designed to optimize the throughput. First, @@ -1234,6 +1301,9 @@ def _schedule_default(self) -> SchedulerOutputs: for seq_group in self.running: budget.add_num_seqs(seq_group.request_id, seq_group.get_max_num_running_seqs()) + for seq_group in self.remote_prefilling: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) curr_loras = (set( seq_group.lora_int_id for seq_group in self.running if seq_group.lora_int_id > 0) if self.lora_enabled else None) @@ -1255,10 +1325,12 @@ def _schedule_default(self) -> SchedulerOutputs: # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. - if len(prefills.seq_groups) == 0: + if len(prefills.seq_groups) == prefills.num_remote_prefill_groups: running_scheduled = self._schedule_running(budget, curr_loras, - enable_chunking=False) + enable_chunking=False, + finished_prefills=finished_prefills, + finished_transfers=finished_transfers) # If any sequence group is preempted, do not swap in any sequence # group. because it means there's no slot for new running requests. @@ -1275,7 +1347,12 @@ def _schedule_default(self) -> SchedulerOutputs: self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. 
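+        # Newly scheduled prefill groups flagged as remote prefills are parked in
+        # self.remote_prefilling until the remote worker reports them finished;
+        # everything else goes to the regular running queue.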
if len(prefills.seq_groups) > 0: - self.running.extend([s.seq_group for s in prefills.seq_groups]) + for s in prefills.seq_groups: + seq_group = s.seq_group + if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + self.remote_prefilling.append(seq_group) + else: + self.running.append(seq_group) self.running.extend(running_scheduled.decode_seq_groups_list) @@ -1452,12 +1529,14 @@ def _order_finishing_prefills_first( ] return finishing + not_finishing - def _schedule(self) -> SchedulerOutputs: + def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: """Schedule queued requests.""" if self.scheduler_config.chunked_prefill_enabled: + if finished_prefills or finished_transfers: + raise ValueError("Chunked prefill does not support remote prefills") return self._schedule_chunked_prefill() else: - return self._schedule_default() + return self._schedule_default(finished_prefills, finished_transfers) def _can_append_slots(self, seq_group: SequenceGroup, enable_chunking: bool) -> bool: @@ -1491,14 +1570,16 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: return no_single_seq def schedule( - self + self, + finished_prefills: Optional[Set[str]] = None, + finished_transfers: Optional[Set[str]] = None ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. - scheduler_start_time = time.perf_counter() - scheduler_outputs: SchedulerOutputs = self._schedule() + scheduler_start_time = time.perf_counter() + scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers) now = time.time() if not self.cache_config.enable_prefix_caching: @@ -1537,7 +1618,8 @@ def schedule( encoder_seq_data = None cross_block_table = None - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING) + for seq in running_or_remote_prefilling_seqs: seq_id = seq.seq_id seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) @@ -1546,7 +1628,9 @@ def schedule( if self.cache_config.enable_prefix_caching: common_computed_block_nums = ( self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) + running_or_remote_prefilling_seqs + ) + ) do_sample = True is_prompt = seq_group.is_prefill() @@ -1568,9 +1652,29 @@ def schedule( < seqs[0].data.get_len()): do_sample = False + is_remote_prefill = False + if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + is_remote_prefill = True + logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums) + if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: + block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids + + # Since we know that prefill is scheduled we can + # assume that the blocks computed on decode + # will be fetched by the time we run prefill + logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids) + if seq_group.remote_prefill_params.decode_computed_block_ids: + 
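+                    # Map the decode side's already-computed blocks onto this prefill's
+                    # block table by position, so those blocks count as computed and are
+                    # skipped when the prefill KV cache is produced.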
computed_block_ids = set(seq_group.remote_prefill_params.decode_computed_block_ids) + prefill_block_ids = block_tables[seq_group.seqs[0].seq_id] + prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)] + + assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't suport prefix caching" + common_computed_block_nums = prefill_fetched_block_ids + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. if is_first_prefill or not self.scheduler_config.send_delta_data: + logger.debug("Assinged blocks: %s", block_tables) seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=is_prompt, @@ -1598,6 +1702,7 @@ def schedule( if scheduler_outputs.num_prefill_groups > 0 else None), mm_processor_kwargs=seq_group.mm_processor_kwargs, prompt_adapter_request=seq_group.prompt_adapter_request, + do_remote_prefill=is_remote_prefill, ) else: # When SPMD mode is enabled, we only send delta data except for @@ -1696,10 +1801,16 @@ def free_finished_seq_groups(self) -> None: self._async_stopped.clear() - def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool: self.block_manager.allocate(seq_group) + is_remote_prefill = False for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - seq.status = SequenceStatus.RUNNING + if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: + seq.status = SequenceStatus.REMOTE_PREFILLING + is_remote_prefill = True + else: + seq.status = SequenceStatus.RUNNING + return is_remote_prefill def _append_slots( self, diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py new file mode 100644 index 000000000000..a8a88ab1f497 --- /dev/null +++ b/vllm/distributed/device_communicators/kv_rearrange.py @@ -0,0 +1,108 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def rearrange_kernel_read( + t1_ptr, + t2_ptr, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + curr_n = offsets // block_size + curr_b = offsets // token_size % B + curr_h = offsets // C % H + curr_c = offsets % C + + src_pos = offsets + + tp_group = curr_h * d // H + dst_h = curr_h % (H // d) + tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c + + dst_pos = tensor_subset_size * tp_group + tp_group_offset + + tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos)) + +@triton.jit +def rearrange_kernel_write( + t1_ptr, + t2_ptr, + N, + B, + H, + C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + curr_n = offsets // block_size + curr_b = offsets // token_size % B + curr_h = offsets // C % H + curr_c = offsets % C + + src_pos = offsets + + tp_group = curr_h * d // H + dst_h = curr_h % (H // d) + tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c + + dst_pos = tensor_subset_size * tp_group + tp_group_offset + + tl.store(t2_ptr + dst_pos, 
tl.load(t1_ptr + src_pos)) + +def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str): + N, B, H, C = t1.shape + + assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source" + assert H % d == 0, "H must be divisible by d" + + block_size = B * H * C + token_size = H * C + tensor_size = N * block_size + tensor_subset_size = tensor_size // d + + BLOCK_SIZE = 1024 + grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,) + + if direction == "read": + rearrange_kernel_read[grid]( + t1, t2, + N, B, H, C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE=BLOCK_SIZE + ) + elif direction == "write": + rearrange_kernel_write[grid]( + t1, t2, + N, B, H, C, + d, + tensor_subset_size, + block_size, + token_size, + BLOCK_SIZE=BLOCK_SIZE + ) + else: + raise ValueError(f"Invalid direction: {direction}") diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py new file mode 100644 index 000000000000..30f529824f08 --- /dev/null +++ b/vllm/distributed/device_communicators/nixl.py @@ -0,0 +1,414 @@ +import torch +from typing import List +from vllm.config import VllmConfig +from vllm.logger import init_logger +import msgspec +import time +import uuid +from collections import defaultdict +from .kv_rearrange import rearrange_tensors + +logger = init_logger(__name__) + +# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used +try: + from nixl._api import nixl_agent as NixlWrapper + logger.info("NIXL is available") +except ImportError: + logger.warning("NIXL is not available") + NixlWrapper = None + +class NixlMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + engine_id: str + agent_metadata: List[bytes] + kv_caches_base_addr: List[List[List[int]]] # base address for each rank for each layer for keys and values + num_blocks: int + + +class DynamoNixlConnector: + def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): + self.vllm_config = vllm_config + if NixlWrapper is None: + logger.error("NIXL is not available") + raise RuntimeError("NIXL is not available") + logger.info("Initializing NIXL wrapper") + self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) + + self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer + + self.num_layers = None + self.num_blocks = None + self.num_heads = None + self.block_len = None + self.kv_caches = None + self.kv_caches_base_addr = {} + self.kv_cache_shape = {} + + self._registered_descs = [] + self._remote_agents = {} + self.engine_id = engine_id + self.rank = rank + self._tp_size = {} + self.src_xfer_side_handles = {} + self.dst_xfer_side_handles = defaultdict(dict) + self.dst_num_blocks = {} + self.use_mla = vllm_config.model_config.use_mla + + self._transfers = defaultdict(list) + + + self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size + + + @property + def agent_name(self): + return self.nixl_wrapper.name + + def register_kv_caches(self, kv_caches: List[torch.Tensor]): + if self.use_mla: + num_blocks, block_size, head_dim = kv_caches[0].shape + self.block_len = block_size * head_dim * kv_caches[0].element_size() + else: + _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape + self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() + self.num_heads = num_heads + logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) + self.num_layers = len(kv_caches) + 
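+        # Non-MLA caches are expected per layer as
+        # [2, num_blocks, block_size, num_heads, head_dim] (K and V stacked);
+        # MLA caches come as [num_blocks, block_size, head_dim].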
self.num_blocks = num_blocks + self.kv_caches = kv_caches + kv_caches_base_addr = [] + caches_data = [] + if self.use_mla: + for key_cache in kv_caches: + base_addr = key_cache.data_ptr() + region_len = num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append([key_cache.data_ptr()]) + else: + for key_cache, value_cache in kv_caches: + base_addr = key_cache.data_ptr() + region_len = 2 * num_blocks * self.block_len + caches_data.append((base_addr, region_len, self.rank, "")) + kv_caches_base_addr.append([key_cache.data_ptr(), value_cache.data_ptr()]) + + self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr + + descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + logger.debug("Registering descs: %s", caches_data) + self.nixl_wrapper.register_memory(descs) + self._registered_descs.append(descs) + + def get_agent_metadata(self): + return self.nixl_wrapper.get_agent_metadata() + + def shutdown(self): + for descs_list in self._registered_descs: + self.nixl_wrapper.deregister_memory(descs_list) + for agent_names in self._remote_agents.values(): + for agent_name in agent_names: + self.nixl_wrapper.remove_remote_agent(agent_name) + for src_xfer_side_handle in self.src_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values(): + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) + + def _get_ranges(self, block_ids): + # This function should return a list of ranges of block ids that are contiguous + # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] + # The ranges are sorted by the starting block id + # The function should also make sure that the block ids are contiguous + # If the block ids are not contiguous, the function should raise an error + ranges = [] + for i in range(len(block_ids)): + if i == 0 or block_ids[i] != block_ids[i-1] + 1: + ranges.append([block_ids[i], block_ids[i]]) + else: + ranges[-1][1] = block_ids[i] + return ranges + + def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None): + + if layer_ids == "all": + layer_ids = list(range(self.num_layers)) + if block_ids == "all": + block_ids = list(range(self.num_blocks)) + + descs_ids = [] + + + if i is not None: + num_blocks = self.num_blocks + for layer_id in layer_ids: + if self.use_mla: + staging_range_idx = 0 + for block_id in block_ids: + if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]: + staging_range_idx += 1 + start_offset = staging_ranges[staging_range_idx][0] + i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1) + descs_ids.append(layer_id * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) + else: + for is_value in [0, 1]: + staging_range_idx = 0 + for block_id in block_ids: + if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]: + staging_range_idx += 1 + start_offset = staging_ranges[staging_range_idx][0] + i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1) + descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) + else: + num_blocks = self.dst_num_blocks[engine_id] + for 
layer_id in layer_ids: + if self.use_mla: + for block_id in block_ids: + descs_ids.append(layer_id * num_blocks + block_id) + else: + for is_value in [0, 1]: + for block_id in block_ids: + descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) + return descs_ids + + def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False): + # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length + # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]] + # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]) + src_overlapping_ranges, dst_overlapping_ranges = [], [] + + original_src_ranges = [] + org_src_range = tuple(src_ranges[0]) + + src_idx, dst_idx = 0, 0 + while src_idx < len(src_ranges) and dst_idx < len(dst_ranges): + src_range = src_ranges[src_idx] + dst_range = dst_ranges[dst_idx] + + # Calculate the length of each range + src_len = src_range[-1] - src_range[0] + 1 + dst_len = dst_range[-1] - dst_range[0] + 1 + + # If ranges have the same length, add them directly + if src_len == dst_len: + src_overlapping_ranges.append([src_range[0], src_range[-1]]) + dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) + original_src_ranges.append(org_src_range) + src_idx += 1 + dst_idx += 1 + if src_idx < len(src_ranges): + org_src_range = tuple(src_ranges[src_idx]) + # If source range is longer, split it + elif src_len > dst_len: + src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1]) + dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) + original_src_ranges.append(org_src_range) + # Update source range for next iteration + src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]] + dst_idx += 1 + # If destination range is longer, split it + else: # src_len < dst_len + src_overlapping_ranges.append([src_range[0], src_range[-1]]) + dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1]) + original_src_ranges.append(org_src_range) + # Update destination range for next iteration + dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]] + src_idx += 1 + if src_idx < len(src_ranges): + org_src_range = tuple(src_ranges[src_idx]) + if return_original_src_ranges: + return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges + return src_overlapping_ranges, dst_overlapping_ranges + + def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id): + logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id) + + assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids) + + if len(local_block_ids) == 0: + logger.debug("No blocks to read") + return + + start_time = time.perf_counter() + + local_ranges = self._get_ranges(local_block_ids) + staging_ranges = self._get_ranges(staging_block_ids) + + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) + + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + handles = [] + + logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000) + create_xfer_start_time = time.perf_counter() + + for i in range(tp_multiplier): + staging_block_descs_ids = 
self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] + handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids, + remote_xfer_side_handle, remote_block_descs_ids, + "") + handles.append(handle) + status = self.nixl_wrapper.transfer(handle) + + logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) + + transfer_start_time = time.perf_counter() + + for handle in handles: + while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE": + if status == "PROC": + time.sleep(0.001) + else: + raise RuntimeError("Read transfer failed with state %s", status) + # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors? + + logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000) + + rearrange_start_time = time.perf_counter() + + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) + for kv_cache in self.kv_caches: + if self.use_mla: + cache_src = kv_cache[local_range[0]:local_range[1] + 1].unsqueeze(0) + cache_staging = kv_cache[staging_range[0]:staging_range[1] + 1].unsqueeze(0) + rearrange_tensors(cache_src, cache_staging, 1, "read") # for mla, the head size is not the same as flash attention, split not required here + else: + for cache in kv_cache: + rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read") + + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000) + logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000) + + def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg): + logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg) + + # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last + # isl[-1] token is calculated in the first iteration in decode. + # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \ + # one less block due to the missing last token. 
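+        # Truncating remote_block_ids below keeps both descriptor lists the same
+        # length; the dropped block is recomputed during the first decode step.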
+ remote_block_ids = remote_block_ids[:len(local_block_ids)] + + assert len(staging_block_ids) == len(local_block_ids) + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + + if len(local_block_ids) == 0: + logger.debug("No blocks to write") + for i in range(tp_multiplier): + self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg) + return + + start_time = time.perf_counter() + + local_ranges = self._get_ranges(local_block_ids) + staging_ranges = self._get_ranges(staging_block_ids) + + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) + + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) + for kv_cache in self.kv_caches: + if self.use_mla: + cache_src = kv_cache[local_range[0]:local_range[1] + 1].unsqueeze(0) + cache_staging = kv_cache[staging_range[0]:staging_range[1] + 1].unsqueeze(0) + rearrange_tensors(cache_src, cache_staging, 1, "write") # for mla, the head size is not the same as flash attention, split not required here + else: + for cache in kv_cache: + rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write") + + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) + + create_xfer_start_time = time.perf_counter() + + # getting block descs ids + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + + for i in range(tp_multiplier): + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] + handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids, + remote_xfer_side_handle, remote_block_descs_ids, + notify_msg) + self._transfers[notify_msg].append(handle) + status = self.nixl_wrapper.transfer(handle) + + logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) + + # transfer_start_time = time.perf_counter() + logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000) + + def get_notifs(self): + return self.nixl_wrapper.update_notifs() + + def get_new_notifs(self): + return self.nixl_wrapper.get_new_notifs() + + def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks): + self._tp_size[engine_id] = agent_tp + agent_names = [] + for agent_meta in agent_metadata: + agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) + agent_names.append(agent_name) + self._remote_agents[engine_id] = agent_names + self.kv_caches_base_addr[engine_id] = kv_caches_base_addr + + tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id] + assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}" + + logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier) + dst_block_len = self.block_len // tp_multiplier + if 
tp_multiplier not in self.src_xfer_side_handles: + # create descs and xfer side handles + blocks_data = [] + for layer_id in range(self.num_layers): + for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]: + for block_id in range(self.num_blocks): + block_offset = block_id * self.block_len + for i in range(tp_multiplier): + tp_multiplier_offset = i * dst_block_len + blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) + logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs) + + # create dst xfer side handles + self.dst_num_blocks[engine_id] = num_blocks + for i in range(tp_multiplier): + blocks_data = [] + for layer_id in range(self.num_layers): + for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]: + for block_id in range(num_blocks): + block_offset = block_id * dst_block_len + blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i)) + logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i) + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs) + + return agent_names + + def get_done_tranfers(self) -> List[str]: + done_req_ids = [] + for req_id, handles in self._transfers.items(): + running_reqs = [] + for handle in handles: + xfer_state = self.nixl_wrapper.check_xfer_state(handle) + if xfer_state == "DONE": + # self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors? + continue + if xfer_state == "PROC": + running_reqs.append(handle) + else: + raise RuntimeError("Transfer failed with state %s", xfer_state) + if len(running_reqs) == 0: + done_req_ids.append(req_id) + else: + self._transfers[req_id] = running_reqs + return done_req_ids diff --git a/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py new file mode 100644 index 000000000000..7b3344f8a0a0 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or +MooncakePipe. + +But the logic can be extended to support other pipe and lookup buffer. 
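+
+This Dynamo variant keeps the same producer/consumer split, but pushes the
+per-request key/value/hidden tensors through a DynamoNcclDataPlane, with the
+target hostname and KV rank parsed out of the request id.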
+""" +import re +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, KVTransferConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class DynamoConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + world_group, + ): + + self.config = config.kv_transfer_config + self.tp_size = config.parallel_config.tensor_parallel_size + self.rank = rank + + if self.config.kv_connector != "DynamoNcclConnector": + raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class") + + from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( + PyNcclPipe) + from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import ( + DynamoNcclDataPlane) + + logger.info( + "Initializing DynamoNcclConnector under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_data_pipe: PyNcclPipe + self.consumer_data_pipe: PyNcclPipe + self.producer_signal_pipe: PyNcclPipe + self.consumer_signal_pipe: PyNcclPipe + + self._broadcast_and_enhance_kv_config(rank, config, world_group) + + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + + # 2 pipes for every rank in the world + if self.config.is_kv_producer: + port_offset_base = rank + 1 + else: + port_offset_base = rank // self.config.tensor_parallel_multiplier + 1 + + + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier + self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config) + + self.data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + + self.data_plane = DynamoNcclDataPlane( + data_pipe=self.data_pipe, + port=self._get_data_plane_port(self.global_kv_rank), + ) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + if not is_deepseek: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(hidden_size / num_attention_heads) + else: + num_heads = int(model_config.num_key_value_heads / self.tp_size) + hidden_size = model_config.hidden_size + num_attention_heads = model_config.num_attention_heads + head_size = int(4.5 * hidden_size / 
num_attention_heads) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id) + decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config) + + for target_rank in range(self.config.tensor_parallel_multiplier): + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank + + if not is_deepseek: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + else: + key_cache = kv_cache + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(torch.empty(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + decode_global_rank = decode_first_global_rank + target_rank + decode_port = self._get_data_plane_port(decode_global_rank) + partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos] + self._send(decode_hostname, decode_port, current_request_id, keys, values, + partial_hidden_or_intermediate_states) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + start_pos_list = [] + + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. 
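+        # For each request, the keys/values/hidden tensors are pulled from the
+        # data plane under "<request_id>_keys" / "_values" / "_hidden", matching
+        # the names the producer side used in _send().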
+ for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self._recv(current_request_id) + keys: torch.Tensor = ret[0] + values: torch.Tensor = ret[1] + hidden: torch.Tensor = ret[2] + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + if not is_deepseek: + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + else: + key_cache = kv_cache + copy_from =keys[i - model_executable.model.start_layer].to( + key_cache.device) + kv_cache[slot_mapping[start_pos:end_pos]] = copy_from + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.data_pipe.close() + # self.data_plane.close() + + @staticmethod + def parse_request_id(request_id: str) -> Tuple[str, int]: + # Regular expression to match the string hostname and integer decode_kv_rank + pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + if match: + # Extract the ranks + decode_hostname = match.group(1) + decode_rank = int(match.group(2)) + + return decode_hostname, decode_rank + raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank") + + def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor): + remote_address = f"{hostname}:{port}" + self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address) + self.data_plane.send_tensor(values, f"{request_id}_values", remote_address) + self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address) + + def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + keys = self.data_plane.recv_tensor(f"{request_id}_keys") + values = self.data_plane.recv_tensor(f"{request_id}_values") + hidden = self.data_plane.recv_tensor(f"{request_id}_hidden") + return keys, values, hidden + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if 
kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + + def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank <= config.kv_producers_parallel_size: + return kv_rank * config.kv_producers_tensor_parallel_size + rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank + + + def _get_data_plane_port(self, global_kv_rank: int) -> int: + return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index e37ce6dc75b0..33fed5e81e20 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -27,13 +27,13 @@ def loader() -> Type[KVConnectorBase]: @classmethod def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + config: "VllmConfig", world_group) -> KVConnectorBase: connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() - return connector_cls(rank, local_rank, config) + return connector_cls(rank, local_rank, config, world_group) # Register various connectors here. @@ -57,4 +57,9 @@ def create_connector(cls, rank: int, local_rank: int, KVConnectorFactory.register_connector( "MooncakeStoreConnector", "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") \ No newline at end of file + "MooncakeStoreConnector") + +KVConnectorFactory.register_connector( + "DynamoNcclConnector", + "vllm.distributed.kv_transfer.kv_connector.dynamo_connector", + "DynamoConnector") diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 49b97d7b5889..8f22ea9e4ff3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -8,14 +8,16 @@ But the logic can be extended to support other pipe and lookup buffer. """ +import re from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( SimpleBuffer) from vllm.logger import init_logger @@ -34,6 +36,7 @@ def __init__( rank: int, local_rank: int, config: VllmConfig, + world_group, ): self.config = config.kv_transfer_config @@ -74,20 +77,31 @@ def __init__( self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + self._broadcast_and_enhance_kv_config(rank, config, world_group) + + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + # 2 pipes for every rank in the world - port_offset_base = 2 * rank + if self.config.is_kv_producer: + port_offset_base = 2 * rank + 1 + else: + port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1 + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe if self.config.is_kv_producer: if self.config.kv_connector == "PyNcclConnector": self.producer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.producer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -111,11 +125,13 @@ def __init__( # its recv pipe to the send pipe of KV producder if self.config.kv_connector == "PyNcclConnector": self.consumer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, 
port_offset=port_offset_base, ) self.consumer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -134,21 +150,25 @@ def __init__( self.config.kv_buffer_size, ) - def select(self, input_tokens: Optional[torch.Tensor], + def select(self, source_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + logger.info("Selecting KV caches and hidden states for source rank %d", source_rank) + assert self.consumer_buffer is not None, "Please initialize the "\ "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) + return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank) + assert self.producer_buffer is not None, "Please initialize the "\ "producer buffer before calling insert." - self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden) def send_kv_caches_and_hidden_states( self, @@ -165,6 +185,7 @@ def send_kv_caches_and_hidden_states( num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) model_config = model_executable.model.config num_heads = int(model_config.num_key_value_heads / self.tp_size) @@ -207,31 +228,39 @@ def send_kv_caches_and_hidden_states( break current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + _, decode_kv_rank = self.parse_request_id(current_request_id) + starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config) - keys, values = [], [] + for target_rank in range(self.config.tensor_parallel_multiplier): + + keys, values = [], [] - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) - value_cache = kv_cache.reshape(-1, num_heads, head_size) - else: - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + if self.is_deepseek_mla and self.use_mla_opt: + key_cache = kv_cache.reshape(-1, num_heads, head_size) + value_cache = kv_cache.reshape(-1, num_heads, head_size) + else: + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) + keys.append(key_cache[current_slot_mapping, 
head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) + self.insert(starting_kv_group_rank, target_rank, current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) @@ -254,6 +283,7 @@ def recv_kv_caches_and_hidden_states( seq_lens = model_input.attn_metadata.seq_lens num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) hidden_or_intermediate_states_for_one_req = [] @@ -261,6 +291,9 @@ def recv_kv_caches_and_hidden_states( num_computed_tokens_list = [] start_pos_list = [] + model_config = model_executable.model.config + is_deepseek = "deepseek" in model_config.architectures[0].lower() + # enumerate different requests # FIXME(Kuntai): This impl assumes that all requests are prefill. for idx, slen in enumerate(seq_lens): @@ -280,13 +313,15 @@ def recv_kv_caches_and_hidden_states( break current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + prefill_rank, _ = self.parse_request_id(current_request_id) num_tokens = slen # collecting data for rebuilding the input input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - ret = self.select(current_tokens, + ret = self.select(prefill_rank, current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. @@ -379,3 +414,77 @@ def close(self): # MooncakePipe reuses data_pipe for signal_pipe, so we only have to # close the data_pipe. 
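
Both connectors in this patch recover routing information from the request id itself: the SimpleConnector variant (`parse_request_id`, added just below) expects a `___prefill_kv_rank_<n>___decode_kv_rank_<m>` suffix, while the Dynamo NCCL connector earlier in the patch expects `___decode_hostname_<host>___decode_kv_rank_<m>`. A small sketch with a hypothetical request id:

    import re

    # Hypothetical id; only the trailing suffix matters to the parsers in this patch.
    request_id = "cmpl-1234___prefill_kv_rank_0___decode_kv_rank_1"

    match = re.search(r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)", request_id)
    assert match is not None
    prefill_kv_rank, decode_kv_rank = int(match.group(1)), int(match.group(2))
    print(prefill_kv_rank, decode_kv_rank)   # 0 1
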
pass + + @staticmethod + def parse_request_id(request_id): + # Regular expression to match the ranks + pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + + if match: + # Extract the ranks + prefill_rank = int(match.group(1)) + decode_rank = int(match.group(2)) + + return prefill_rank, decode_rank + else: + return None, None + + + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + if self.config.kv_connector == "PyNcclConnector": + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + + else: + raise NotImplementedError("MooncakeConnector is not supported in Dynamo patch") + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 
10bbfe1ddd8a..b4bd40259420 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -10,8 +10,10 @@ stop the prefill instance when the decode instance is slow. """ import threading +import time from collections import deque -from typing import Deque, List, Optional, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union, Dict import torch @@ -45,7 +47,7 @@ def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, self.buffer_cv = threading.Condition() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None + self.request_handling_thread: Optional[ThreadPoolExecutor] = None self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None @@ -56,10 +58,16 @@ def _matches(self, tokens_roi_sender: List[torch.Tensor], # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] + target_rank_sender = tokens_roi_sender[0] + target_rank_recver = tokens_roi_recver[0] + + if target_rank_sender.item() != target_rank_recver.item(): + return 0 + + tokens_sender = tokens_roi_sender[1] + tokens_recver = tokens_roi_recver[1] + roi_sender = tokens_roi_sender[2] + roi_recver = tokens_roi_recver[2] if tokens_recver is None: # consumer sends an empty request @@ -79,14 +87,14 @@ def _matches(self, tokens_roi_sender: List[torch.Tensor], return 0 - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor], + target_rank: int) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: tensor = tensor.float() - self.data_pipe.send_tensor(tensor) + self.data_pipe.send_tensor(tensor, target_rank) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): @@ -99,7 +107,7 @@ def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor): @@ -114,7 +122,7 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - buffer_item = [input_tokens, roi, key, value, hidden] + buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden] data_size = sum([self._get_element_size(data) for data in buffer_item]) with self.buffer_cv: @@ -124,7 +132,6 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, logger.debug("KV transfer buffer is full. 
Handling...") while self.buffer_size + data_size > self.buffer_size_threshold: self.buffer_cv.wait() - self.buffer_size += data_size self.buffer.append(buffer_item) self.buffer_cv.notify() @@ -132,49 +139,43 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, def _is_end_signal(self, signal): return signal is None - def drop_select_handler(self): + def drop_select_handler(self, rank: int): try: - - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - assert roi is not None, "Please provide the roi when sending "\ - "drop-select request" - roi = (roi > 0.5) - tokens_roi_recver = [input_tokens, roi] - - def is_buffer_available( - tokens_roi_recver: List[torch.Tensor], ) -> bool: - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - for _ in range(len(self.buffer)): - if self._matches(self.buffer[0], - tokens_roi_recver) > 0: - return True - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - return False - - with self.buffer_cv: - while not is_buffer_available(tokens_roi_recver): - logger.debug( - "KV transfer buffer is not available. Waiting...") - self.buffer_cv.wait() - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - self.buffer_cv.notify() + signal = self.signal_pipe.recv_tensor(rank) + if self._is_end_signal(signal): + logger.info("Received end signal!") + return + + target_kv_rank = self.data_pipe.recv_tensor(rank) + # assert target_rank.item() == rank, "Target rank does not match"\ + # "the rank of the drop-select handler" + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [target_kv_rank, input_tokens, roi] + + def is_buffer_available( + tokens_roi_recver: List[torch.Tensor], ) -> bool: + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + for _ in range(len(self.buffer)): + if self._matches(self.buffer[0], + tokens_roi_recver) > 0: + return True + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + return False + + with self.buffer_cv: + while not is_buffer_available(tokens_roi_recver): + logger.debug( + "KV transfer buffer is not available. Waiting...") + self.buffer_cv.wait() except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -183,10 +184,10 @@ def is_buffer_available( logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], + self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - assert self.request_handling_thread is None, \ + assert not self.request_handling_thread, \ "drop_select should be called by the KV cache consumer "\ "(e.g. 
the decode vLLM instance)" @@ -195,40 +196,51 @@ def drop_select( if isinstance(roi, torch.Tensor): roi = roi.clone().float() - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) + self.signal_pipe.send_tensor(self.normal_signal, rank) + + self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) + self.data_pipe.send_tensor(input_tokens, rank) + self.data_pipe.send_tensor(roi, rank) - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) if roi is not None: # convert from float tensor to bool tensor # as PyNccl does not support sending bool tensor roi = (roi > 0.5) - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor(rank) + value = self.data_pipe.recv_tensor(rank) + hidden = self.data_pipe.recv_tensor(rank) return [input_tokens, roi, key, value, hidden] - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def full_handler(self): + time.sleep(0.001) + + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: - self._add_to_buffer(input_tokens, roi, key, value, hidden) + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. + target_rank_global = target_rank + kv_group_rank if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) - self.request_handling_thread.start() + self.request_handling_thread = ThreadPoolExecutor(max_workers=1) + self.request_handling_thread.submit(self.drop_select_handler, target_rank_global) def close(self): - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: - self.request_handling_thread.join() + if hasattr(self, "request_handling_thread") and self.request_handling_thread: + self.request_handling_thread.shutdown() else: # TODO: have a explicit close signal and have a explicit way to diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 40589fb3ef87..da2829cfcc56 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -23,7 +23,7 @@ class KVPipeBase(ABC): """ @abstractmethod - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: """Send a tensor, or None, via the pipe. Need to support sending None -- important for error handling. @@ -41,7 +41,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: raise NotImplementedError @abstractmethod - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """Receive a tensor (can be None) from the pipeline. 
Returns: diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py new file mode 100644 index 000000000000..3ee0fa78f4aa --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py @@ -0,0 +1,124 @@ +import logging +import threading +import typing +import zmq +import socket +import time +import torch + +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + + +logger = logging.getLogger(__name__) + + +class DynamoNcclDataPlane: + def __init__( + self, + data_pipe: PyNcclPipe, + hostname: str = "", + port: int = 0, + ) -> None: + + self.data_pipe = data_pipe + if not hostname: + hostname = socket.gethostname() + if port == 0: + raise ValueError("Port cannot be 0") + self._hostname = hostname + self._port = port + self.store = {} + self.context = zmq.Context() + self.rep_socket = self.context.socket(zmq.REP) + logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}") + self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}") + self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True) + self._listener_thread.start() + self.req_sockets = {} + logger.info(f"Rank {self.rank} connected to the server") + + @property + def rank(self): + return self.data_pipe.kv_group_rank + + def send_tensor( + self, + tensor: torch.Tensor, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ): + logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}") + return self._send_tensor(tensor, tensor_id, remote_address) + + def recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + ret = self._recv_tensor(tensor_id, remote_address) + return ret + + def _send_tensor( + self, + tensor: torch.Tensor, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ): + logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}") + if remote_address is None: + self.store[tensor_id] = tensor + else: + # tensor_shape = "_".join(str(dim) for dim in tensor.shape) + # tensor_dtype = str(tensor.dtype) + if remote_address not in self.req_sockets: + self.req_sockets[remote_address] = self.context.socket(zmq.REQ) + self.req_sockets[remote_address].connect(f"tcp://{remote_address}") + + req_socket = self.req_sockets[remote_address] + # req_socket.connect(f"tcp://{remote_address}") + req_socket.send_string(f"PUT {self.rank} {tensor_id}") + dst_rank = req_socket.recv_string() + logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}") + self.data_pipe.send_tensor(tensor, int(dst_rank)) + + def _recv_tensor( + self, + tensor_id: str, + remote_address: typing.Optional[str] = None, + ) -> torch.Tensor: + logger.debug(f"Rank {self.rank} receiving tensor") + if remote_address is not None: + raise NotImplementedError("Getting tensor from remote rank not implemented") + if tensor_id in self.store: + logger.debug(f"Popping tensor {tensor_id} from store") + future = self.store.pop(tensor_id) + tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait + logger.debug(f"Rank {self.rank} received tensor") + return tensor + + logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}") + time.sleep(0.001) + return self._recv_tensor(tensor_id, remote_address) + # raise NotImplementedError("Tensor not found in store") + + def _receive_tensor( + self, + tensor_id: str, + rank: int, + ): + future 
= self.data_pipe.recv_tensor(rank) + logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store") + self.store[tensor_id] = future + + def listen_for_requests(self): + while True: + cmd, rank, tensor_id = self.rep_socket.recv_string().split() + logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}") + self.rep_socket.send_string(f"{self.rank}") + if cmd == "GET": + raise NotImplementedError("Getting tensor from remote rank not implemented") + elif cmd == "PUT": + rank = int(rank) + # shape = [int(dim) for dim in shape.split("_")] + # dtype = getattr(torch, dtype) + self._receive_tensor(tensor_id, rank) diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index e8bf607eb899..a1f02da269f5 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -45,14 +45,16 @@ class PyNcclPipe(KVPipeBase): METADATA_DTYPE = torch.int64 def __init__(self, + kv_group_rank: int, local_rank: int, config: KVTransferConfig, device: Optional[str] = None, port_offset: int = 0): self.config = config self.local_rank = local_rank - self.kv_rank = self.config.kv_rank + self.kv_group_rank = kv_group_rank self.kv_parallel_size = self.config.kv_parallel_size + self.kv_world_size = self.config.kv_world_size if device is None: self.device = self._select_device(self.config.kv_buffer_device) else: @@ -60,20 +62,18 @@ def __init__(self, # build distributed connection and send/recv implementation store_timeout = self.config.get_from_extra_config("store_timeout", 300) + logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset) self.group = StatelessProcessGroup.create( host=self.config.kv_ip, port=self.config.kv_port + port_offset, - rank=self.kv_rank, - world_size=self.kv_parallel_size, + rank=self.kv_group_rank, + world_size=self.kv_world_size, store_timeout=store_timeout, ) # add a barrier to make sure the connection is initiated properly self.group.barrier() impl = self._get_device_send_recv_impl(self.group) self.device_send_func, self.device_recv_func = impl - # set target rank - self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size - self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size # transportation-related variables self.transport_thread: Optional[ThreadPoolExecutor] = None @@ -147,16 +147,16 @@ def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: dtype=metadata["dtype"], device=self.device) - def _send_metadata(self, metadata: Metadata): + def _send_metadata(self, metadata: Metadata, target_rank: int): """ Send the metadata dictionary to the target rank. Parameters: - metadata: A dictionary with keys "dtype" and "shape". """ - self.group.send_obj(metadata, self.target_rank_for_send) + self.group.send_obj(metadata, target_rank) - def _recv_metadata(self) -> Metadata: + def _recv_metadata(self, src_rank: int) -> Metadata: """ Receive the metadata dictionary from the target rank. @@ -164,9 +164,9 @@ def _recv_metadata(self) -> Metadata: - metadata: A dictionary with keys "dtype" and "shape" describing the tensor. 
""" - return self.group.recv_obj(self.target_rank_for_recv) + return self.group.recv_obj(src_rank) - def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ The actual implementation of sending the tensor and its metadata to the target rank. @@ -176,12 +176,12 @@ def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: being sent. """ metadata = self._make_metadata(tensor) - self._send_metadata(metadata) + self._send_metadata(metadata, target_rank) if tensor is not None: self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + target_rank) - def _recv_impl(self) -> Optional[torch.Tensor]: + def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]: """ The actual implementation of receiving a tensor and its metadata from the target rank. @@ -189,21 +189,22 @@ def _recv_impl(self) -> Optional[torch.Tensor]: Returns: - buffer: The received tensor, or None if no tensor is received. """ - metadata = self._recv_metadata() + metadata = self._recv_metadata(src_rank) if metadata["dtype"] is None: return None buffer = self._prepare_recv_buffer(metadata) - self.device_recv_func(buffer, self.target_rank_for_recv) + self.device_recv_func(buffer, src_rank) return buffer def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + tensor_size: int, + target_rank: int) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. """ try: - self._send_impl(tensor) + self._send_impl(tensor, target_rank) with self.buffer_size_lock: self.buffer_size -= tensor_size @@ -222,7 +223,7 @@ def block_if_full(self): logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ Sends a tensor and its metadata to the destination rank in a non-blocking way. @@ -230,6 +231,7 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: Parameters: - tensor: The tensor to send, or None if no tensor is being sent. """ + logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank) if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -243,32 +245,39 @@ def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: with self.buffer_size_lock: self.buffer_size += tensor_size - self.transport_thread.submit(self.send_tensor_wrapper, tensor, - tensor_size) + future = self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size, + target_rank) + return future - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """ Receives a tensor and its metadata from the source rank. Blocking call. Returns: - tensor: The received tensor, or None if no tensor is received. 
""" + + logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank) + if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) - future = self.transport_thread.submit(self._recv_impl) + future = self.transport_thread.submit(self._recv_impl, src_rank) - try: - tensor = future.result() - except Exception as e: - logger.error("Encountering exception in KV receiving thread") - logger.error("%s", e) - logger.error("My device: %s", self.device) - import traceback - traceback.print_exc() - raise e + return future + + # try: + # tensor = future.result() + # except Exception as e: + # logger.error("Encountering exception in KV receiving thread") + # logger.error("%s", e) + # logger.error("My device: %s", self.device) + # import traceback + # traceback.print_exc() + # raise e - return tensor + # return tensor def close(self): """ diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 1e80e0bd7de8..cd90206f89b3 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -35,6 +35,7 @@ def __init__( rank: int, local_rank: int, config: "VllmConfig", + world_group, ): self.config = config @@ -47,7 +48,7 @@ def __init__( "TransferAgent should only be used when kv_connector is set." self.connector = KVConnectorFactory.create_connector( - rank, local_rank, config) + rank, local_rank, config, world_group) def send_kv_caches_and_hidden_states( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e0eeeffb88a7..353431b1f1e4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -979,7 +979,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, - config=vllm_config) + config=vllm_config, + world_group=get_world_group()) def ensure_model_parallel_initialized( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 54f7b8fb69b5..05c8a2d634d2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,10 +2,14 @@ import copy import time +import pickle +import uuid from collections import Counter as collectionsCounter from collections import deque +from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Literal, Mapping, NamedTuple, Optional) @@ -62,6 +66,9 @@ resolve_obj_by_qualname, weak_bind) from vllm.version import __version__ as VLLM_VERSION from vllm.worker.model_runner_base import InputProcessingError +from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType +from vllm.distributed.device_communicators.nixl import NixlMetadata + logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 @@ -93,7 +100,7 @@ class OutputData(NamedTuple): # outputs from multiple steps. 
is_first_step_output: Optional[bool] skip: List[int] - + remote_prefill_requests: Optional[List[RemotePrefillRequest]] class SchedulerContext: @@ -107,11 +114,14 @@ def __init__(self, multi_step_stream_outputs: bool = False): self.multi_step_stream_outputs: bool = multi_step_stream_outputs + self.remote_prefill_requests: List[RemotePrefillRequest] = [] + def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, is_last_step: bool, - is_first_step_output: Optional[bool]): + is_first_step_output: Optional[bool], + remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None): self.output_queue.append( OutputData(outputs=outputs, seq_group_metadata_list=seq_group_metadata_list, @@ -119,7 +129,9 @@ def append_output(self, outputs: List[SamplerOutput], is_async=is_async, is_last_step=is_last_step, is_first_step_output=is_first_step_output, - skip=[])) + skip=[], + remote_prefill_requests=remote_prefill_requests)) + class LLMEngine: @@ -362,7 +374,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: Scheduler = self.vllm_config.scheduler_config.scheduler_cls self.scheduler = [ Scheduler( - self.scheduler_config, self.cache_config, self.lora_config, + self.model_config, self.scheduler_config, self.cache_config, self.lora_config, self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] if self.model_config.use_async_output_proc else None) @@ -422,6 +434,39 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # Flag to set when an input fails to process and the engine should run # the next step without re-scheduling. self._skip_scheduling_next_step = False + self.engine_id = str(uuid.uuid4()) + self._nixl_agents_names: Optional[List[str]] = None + if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": + self._nixl_agents_names = self._initialize_nixl() + + self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) + self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) + self._finished_prefills = set() + self._finished_transfers = set() + + @property + def is_nixl_initialized(self) -> bool: + return getattr(self, "_nixl_agents_names", None) is not None + + def get_nixl_metadata(self) -> NixlMetadata: + if not self.is_nixl_initialized: + raise RuntimeError("Nixl is not initialized") + agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata") + kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr") + return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks) + + def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]: + if not self.is_nixl_initialized: + raise RuntimeError("Nixl is not initialized") + engine_id = nixl_metadata.engine_id + agents_metadata = nixl_metadata.agent_metadata + kv_caches_base_addr = nixl_metadata.kv_caches_base_addr + num_blocks = nixl_metadata.num_blocks + return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks)) + + def _initialize_nixl(self) -> List[bytes]: + agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,)) + return agents_names def _initialize_kv_caches(self) -> None: 
"""Initialize the KV cache in the worker(s). @@ -535,6 +580,8 @@ def __del__(self): # Shutdown model executor when engine is garbage collected # Use getattr since __init__ can fail before the field is set if model_executor := getattr(self, "model_executor", None): + if self.is_nixl_initialized: + model_executor.collective_rpc("shutdown_nixl") model_executor.shutdown() def get_tokenizer_group( @@ -587,11 +634,14 @@ def _add_processed_request( prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> Optional[SequenceGroup]: """Add a processed request to the engine's request pool. return the created sequence group. """ if isinstance(params, SamplingParams) and params.n > 1: + if remote_prefill_params is not None: + raise ValueError("Remote prefill params are not supported for multi-step sampling") ParallelSampleSequenceGroup.add_request( request_id, self, @@ -609,12 +659,14 @@ def _add_processed_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) + if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: + next(self.seq_counter) # empty sequence for staging eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, - lora_request, prompt_adapter_request) + lora_request, prompt_adapter_request, remote_prefill_params) encoder_seq = (None if encoder_inputs is None else Sequence( seq_id, encoder_inputs, block_size, eos_token_id, lora_request, @@ -631,8 +683,12 @@ def _add_processed_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, - priority=priority) + priority=priority, + remote_prefill_params=remote_prefill_params, + ) elif isinstance(params, PoolingParams): + if remote_prefill_params is not None: + raise ValueError("Remote prefill params are not supported for pooling") seq_group = self._create_sequence_group_with_pooling( request_id, seq, @@ -703,6 +759,7 @@ def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, *, inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: @@ -794,6 +851,7 @@ def add_request( prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, priority=priority, + remote_prefill_params=remote_prefill_params, ) def _validate_token_prompt(self, prompt: PromptType, @@ -828,6 +886,7 @@ def _create_sequence_group_with_sampling( prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -863,7 +922,9 @@ def _create_sequence_group_with_sampling( prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, priority=priority, - draft_size=draft_size) + draft_size=draft_size, + remote_prefill_params=remote_prefill_params, + ) return seq_group @@ -1030,11 +1091,11 @@ def _process_model_outputs(self, # When we process only one request, no pop is required # (since later we will process all of the rest) (outputs, 
seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0] else: (outputs, seq_group_metadata_list, scheduler_outputs, is_async, is_last_step, is_first_step_output, - skip) = ctx.output_queue.popleft() + skip, remote_prefill_requests) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( @@ -1360,6 +1421,12 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Clear outputs for each new scheduler iteration ctx.request_outputs.clear() + ctx.remote_prefill_requests.clear() + + remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + running_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] + running_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current @@ -1372,7 +1439,41 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc - ) = self.scheduler[virtual_engine].schedule() + ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers) + + + # Separate remote prefill and running seq groups + for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups): + if seq_group_metadata.do_remote_prefill: + remote_prefill_seq_group_metadata_list.append(seq_group_metadata) + remote_prefill_scheduled_seq_groups.append(scheduled_seq_group) + else: + running_seq_group_metadata_list.append(seq_group_metadata) + running_scheduled_seq_groups.append(scheduled_seq_group) + + seq_group_metadata_list = running_seq_group_metadata_list + scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups + + # Send remote prefill requests before model execution + for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups): + assert len(scheduled_seq_group.seq_group.seqs) == 1 + assert self._nixl_agents_names + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + block_table = seq_group_metadata.block_tables[seq_id] + if len(block_table) == len(seq_group_metadata.computed_block_nums): + logger.debug("No blocks to prefill") + self._finished_prefills.add(seq_group_metadata.request_id) + continue + remote_prefill_request = RemotePrefillRequest( + request_id=seq_group_metadata.request_id, + # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway + prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks) + sampling_params=scheduled_seq_group.seq_group.sampling_params, + block_ids=block_table, + engine_id=self.engine_id, + computed_block_ids=seq_group_metadata.computed_block_nums, + ) + scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request) ctx.seq_group_metadata_list = seq_group_metadata_list ctx.scheduler_outputs = scheduler_outputs @@ -1427,8 +1528,45 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: 
execute_model_req.async_callback = self.async_callbacks[ virtual_engine] + # After model execution, we need to transfer the memory from the prefill to the decode + memory_transfer_reqs = [] + for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list): + remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params + if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: + assert len(scheduled_seq_group.seq_group.seqs) == 1 + req_id = scheduled_seq_group.seq_group.request_id + seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id + block_table = seq_group_metadata.block_tables[seq_id] + staging_block_ids = seq_group_metadata.block_tables[seq_id + 1] + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + computed_decode_block_ids = remote_prefill_params.decode_block_ids[:num_computed_blocks] + if computed_decode_block_ids: + kv_recv_req = MemoryTransferRequest( + request_id=req_id, + local_block_ids=block_table[:num_computed_blocks], + staging_block_ids=staging_block_ids[:num_computed_blocks], + remote_block_ids=computed_decode_block_ids, + remote_engine_id=remote_prefill_params.decode_engine_id, + notify_msg=req_id, + op_type=MemoryOpType.READ + ) + memory_transfer_reqs.append(kv_recv_req) + + kv_send_req = MemoryTransferRequest( + request_id=req_id, + local_block_ids=block_table[num_computed_blocks:], + staging_block_ids=staging_block_ids[num_computed_blocks:], + remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:], + remote_engine_id=remote_prefill_params.decode_engine_id, + notify_msg=req_id, + op_type=MemoryOpType.WRITE + ) + memory_transfer_reqs.append(kv_send_req) + + execute_model_req.memory_transfer_requests = memory_transfer_reqs + try: - outputs = self.model_executor.execute_model( + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( execute_model_req=execute_model_req) self._skip_scheduling_next_step = False except InputProcessingError as e: @@ -1455,7 +1593,26 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) # No outputs in this case - outputs = [] + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=[], + blocks_to_swap_in=[], + blocks_to_swap_out=[], + blocks_to_copy=[]) + + outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( + execute_model_req=execute_model_req) + + for req_id, notif_count in request_notif_counter.items(): + self._request_notif_counter[req_id] += notif_count + if self._request_notif_counter[req_id] > -1: + self._finished_prefills.add(req_id) + del self._request_notif_counter[req_id] + + for req_id, done_count in request_done_counter.items(): + self._request_done_counter[req_id] += done_count + if self._request_done_counter[req_id] > -1: + self._finished_transfers.add(req_id) + del self._request_done_counter[req_id] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -1515,7 +1672,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: # queued control plane messages, such as add/remove lora adapters. 
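
The completion bookkeeping added to step() above relies on per-request counters that start at minus the tensor-parallel size, so a request is only marked finished once every TP worker has reported its notification (or transfer completion). A minimal sketch of that accounting, assuming a hypothetical TP size of 2:

    from collections import defaultdict

    tp_size = 2                                           # hypothetical tensor-parallel size
    notif_counter = defaultdict(lambda: -tp_size)
    finished_prefills = set()

    # Two engine steps, each delivering one notification for "req-1" from one worker.
    for step_counts in ({"req-1": 1}, {"req-1": 1}):
        for req_id, count in step_counts.items():
            notif_counter[req_id] += count
            if notif_counter[req_id] > -1:                # all tp_size workers have reported
                finished_prefills.add(req_id)
                del notif_counter[req_id]

    print(finished_prefills)                              # {'req-1'}
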
logger.debug("Stopping remote worker execution loop.") self.model_executor.stop_remote_worker_execution_loop() - + return ctx.request_outputs def _abort_and_cache_schedule( diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index cafd8150bc01..d33ddd19bfe2 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -14,6 +14,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.utils import Device, deprecate_kwargs +from vllm.remote_prefill import RemotePrefillParams +from vllm.distributed.device_communicators.nixl import NixlMetadata VLLM_RPC_SUCCESS_STR = "SUCCESS" @@ -21,6 +23,9 @@ IPC_OUTPUT_EXT = "_output_socket" IPC_HEALTH_EXT = "_health_socket" IPC_DATA_EXT = "_data_socket" +IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket" +IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket" +IPC_METRICS_EXT = "_metrics_socket" class MQEngineDeadError(RuntimeError): @@ -36,6 +41,7 @@ class RPCProcessRequest: trace_headers: Optional[Mapping[str, str]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None priority: int = 0 + remote_prefill_params: Optional[RemotePrefillParams] = None @overload def __init__( @@ -78,6 +84,7 @@ def __init__( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, *, inputs: Optional[PromptType] = None, # DEPRECATED ) -> None: @@ -95,7 +102,7 @@ def __init__( self.trace_headers = trace_headers self.prompt_adapter_request = prompt_adapter_request self.priority = priority - + self.remote_prefill_params = remote_prefill_params @dataclass class RPCError: @@ -116,7 +123,7 @@ class RPCStartupRequest(Enum): @dataclass class RPCStartupResponse: tracing_enabled: bool - + nixl_metadata: Optional[bytes] = None class RPCUProfileRequest(Enum): START_PROFILE = 1 @@ -181,3 +188,13 @@ def ENGINE_DEAD_ERROR( return MQEngineDeadError( "Engine loop is not running. 
Inspect the stacktrace to " f"find the original error: {repr(error)}.") + +@dataclass +class KvMetrics: + request_active_slots: int + request_total_slots: int + kv_active_blocks: int + kv_total_blocks: int + num_requests_waiting: int + gpu_cache_usage_perc: float + gpu_prefix_cache_hit_rate: float diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index f058b13297bb..d16a2d3fa042 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -8,6 +8,7 @@ Optional, Union, cast, overload) import cloudpickle +import msgspec import psutil import zmq import zmq.asyncio @@ -18,14 +19,18 @@ from vllm import PoolingParams from vllm.config import DecodingConfig, ModelConfig, VllmConfig from vllm.core.scheduler import SchedulerOutputs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics import Stats # yapf conflicts with isort for this block # yapf: disable from vllm.engine.async_llm_engine import ( build_guided_decoding_logits_processor_async) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, RPC_REQUEST_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT, + RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest, + IPC_METRICS_EXT, RPCAdapterLoadedResponse, RPCError, RPCIsSleepingRequest, RPCIsSleepingResponse, @@ -34,7 +39,8 @@ RPCResetPrefixCacheRequest, RPCSleepRequest, RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) + RPCUProfileRequest, RPCWakeUpRequest, + KvMetrics) from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT @@ -48,6 +54,8 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import Device, deprecate_kwargs +from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback +from vllm.distributed.device_communicators.nixl import NixlMetadata logger = init_logger(__name__) @@ -93,6 +101,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, self._errored_with: Optional[BaseException] = None # Get the configs. + self.vllm_config = engine_config self.model_config = engine_config.model_config self.decoding_config = engine_config.decoding_config @@ -117,6 +126,10 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # Metrics. + self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL) + self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -131,8 +144,27 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, # Loop to check health of the LLMEngine periodically. # Started after the MQLLMEngine is ready. self.health_loop: Optional[asyncio.Task] = None + + # Loop to check metrics of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. 
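The metrics path added here is a one-way side channel: the engine pickles a KvMetrics snapshot and pushes it over a dedicated ZMQ IPC socket, and the client loop below pulls it and republishes. A toy sketch of that pattern, with a hypothetical socket path and a plain dict standing in for KvMetrics:

# Illustrative only; not part of the patch.
import pickle
import zmq

ipc_path = "ipc:///tmp/example_metrics_socket"  # hypothetical path

ctx = zmq.Context()
push = ctx.socket(zmq.PUSH)   # engine side: binds and pushes pickled metrics
push.bind(ipc_path)
pull = ctx.socket(zmq.PULL)   # client side: connects and pulls them
pull.connect(ipc_path)

payload = {"request_active_slots": 3, "kv_active_blocks": 128}  # stand-in for KvMetrics
push.send_multipart((pickle.dumps(payload),), copy=False)

metrics = pickle.loads(pull.recv_multipart()[0])
assert metrics["kv_active_blocks"] == 128

push.close()
pull.close()
ctx.term()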
+ self.metrics_loop: Optional[asyncio.Task] = None + self.metrics_publisher = None + self._engine_process = psutil.Process(engine_pid) + self.nixl_metadata: Optional[NixlMetadata] = None + self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL) + self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {} + if self.using_nixl_connector: + self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") + self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") + + + @property + def using_nixl_connector(self) -> bool: + return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector" + @staticmethod def is_unsupported_config(vllm_config: VllmConfig): # Pipeline parallel not yet supported @@ -182,6 +214,61 @@ async def run_heartbeat_loop(self, timeout: int): except Exception as e: self._set_errored(e) + async def run_remote_prefill_request_handler_loop(self): + try: + while True: + if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT): + frames = await self.remote_prefill_request_socket.recv(copy=False) + remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest) + await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request) + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.") + + async def run_metrics_loop(self, timeout: int): + """Background loop that continually checks to ensure the engine process + is still alive. + """ + try: + while True: + # Check if the engine process is running: + if not self._engine_process.is_running() or ( + self._engine_process.status() == psutil.STATUS_ZOMBIE): + # NB: is_running() returns True for zombies + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) " + "died.")) + break + + if await self.metrics_socket.poll(timeout=timeout): + # Metrics received- check the message + message: Frame = await self.metrics_socket.recv(copy=False) + metrics = pickle.loads(message.buffer) + if self.metrics_publisher is not None and isinstance( + metrics, KvMetrics + ): + self.metrics_publisher.publish(metrics.request_active_slots, + metrics.request_total_slots, + metrics.kv_active_blocks, + metrics.kv_total_blocks, + metrics.num_requests_waiting, + metrics.gpu_cache_usage_perc, + metrics.gpu_prefix_cache_hit_rate) + logger.debug("Metrics successful.") + + # TODO: Investigate sending whole stats object + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check metrics loop.") + + except psutil.NoSuchProcess: + self._set_errored( + RuntimeError( + f"Engine process (pid {self._engine_process.pid}) died.")) + + except Exception as e: + self._set_errored(e) + async def run_output_handler_loop(self): """Get RequestOutputs from Engine and stream to Request Queues""" @@ -283,12 +370,26 @@ async def setup(self): # Wait until server is ready. response = await self._wait_for_server_rpc(socket) + if response.nixl_metadata is not None: + assert self.using_nixl_connector + self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata) + self.tracing_flag = response.tracing_enabled # Start health_loop. 
if self.health_loop is None: self.health_loop = asyncio.create_task( self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + if self.using_nixl_connector: + self.remote_prefill_loop = asyncio.create_task( + self.run_remote_prefill_request_handler_loop()) + + # Start metrics_loop. + if self.metrics_loop is None: + self.metrics_loop = asyncio.create_task( + self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT)) + def close(self): """Destroy the ZeroMQ Context.""" @@ -298,6 +399,8 @@ def close(self): # Cancel background tasks. if self.health_loop is not None: self.health_loop.cancel() + if self.metrics_loop is not None: + self.metrics_loop.cancel() if self.output_loop is not None: self.output_loop.cancel() @@ -420,6 +523,9 @@ async def check_health(self): """ if self._errored_with is not None: raise self._errored_with + + async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata): + await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False) @property def is_running(self) -> bool: @@ -478,6 +584,7 @@ def generate( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, *, inputs: Optional[PromptType] = None # DEPRECATED ) -> AsyncGenerator[RequestOutput, None]: @@ -507,7 +614,8 @@ def generate( return self._process_request(prompt, sampling_params, request_id, lora_request, trace_headers, - prompt_adapter_request, priority) + prompt_adapter_request, priority, + remote_prefill_params) @overload def encode( @@ -591,6 +699,7 @@ async def _process_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ PoolingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" @@ -636,6 +745,12 @@ async def _process_request( else: lp_bytes = None + if remote_prefill_params is not None: + self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback + remote_prefill_params.remote_prefill_request_callback = None + else: + remote_prefill_request_callback = None + request_bytes = pickle.dumps( RPCProcessRequest( prompt=prompt, @@ -645,11 +760,11 @@ async def _process_request( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=priority, + remote_prefill_params=remote_prefill_params, )) # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) + parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,) await self.input_socket.send_multipart(parts, copy=False) # 4) Stream the RequestOutputs from the output queue. 
Note @@ -740,3 +855,6 @@ async def add_lora(self, lora_request: LoRARequest) -> None: # Raise on error, otherwise happily return None if isinstance(request_output, BaseException): raise request_output + + def set_metrics_publisher(self, metrics_publisher): + self.metrics_publisher = metrics_publisher diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 6ed5ae0a94f1..6205d10e36f4 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -3,11 +3,12 @@ import pickle import signal from contextlib import contextmanager -from typing import Iterator, List, Optional, Union +from typing import Iterator, List, Optional, Union, Dict import cloudpickle +import time import zmq - +import msgspec from vllm import AsyncEngineArgs, SamplingParams from vllm.config import VllmConfig from vllm.engine.llm_engine import LLMEngine @@ -15,8 +16,10 @@ # yapf: disable from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, - IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT, + RPCAbortRequest, + IPC_OUTPUT_EXT, IPC_METRICS_EXT, RPCAdapterLoadedResponse, RPCError, RPCIsSleepingRequest, RPCIsSleepingResponse, @@ -25,7 +28,9 @@ RPCResetPrefixCacheRequest, RPCSleepRequest, RPCStartupRequest, RPCStartupResponse, - RPCUProfileRequest, RPCWakeUpRequest) + RPCUProfileRequest, RPCWakeUpRequest, + IPC_REMOTE_NIXL_METADATA_EXT, + KvMetrics) # yapf: enable from vllm.logger import init_logger from vllm.outputs import RequestOutput @@ -33,12 +38,88 @@ maybe_register_config_serialize_by_value) from vllm.usage.usage_lib import UsageContext from vllm.worker.model_runner_base import InputProcessingError +from vllm.remote_prefill import RemotePrefillRequest +from vllm.distributed.device_communicators.nixl import NixlMetadata + +from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo +from dataclasses import dataclass, field logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) +class KvStatLogger(StatLoggerBase): + def __init__( + self, + max_num_seqs: int, + num_total_gpu_blocks: int, + metrics_socket + ): + # Must query initialized scheduler for max infos + self.request_total_slots = max_num_seqs + self.kv_total_blocks = num_total_gpu_blocks + self.metrics_socket = metrics_socket + + # KV metrics + self._send_kv_metrics(0, 0, 0, 0.0, 0.0) + + def log(self, stats: Stats) -> None: + self._send_kv_metrics( + stats.num_running_sys, + int(stats.gpu_cache_usage_sys * self.kv_total_blocks), + stats.num_waiting_sys, + stats.gpu_cache_usage_sys, + stats.gpu_prefix_cache_hit_rate + ) + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + pass + + def _send_kv_metrics( + self, + active_slots, + active_kv_blocks, + num_requests_waiting, + gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate, + ): + if not self.metrics_socket.closed: + metrics_bytes = pickle.dumps( + KvMetrics( + active_slots, + self.request_total_slots, + active_kv_blocks, + self.kv_total_blocks, + num_requests_waiting, + gpu_cache_usage_perc, + gpu_prefix_cache_hit_rate, + ) + ) + self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + +# TODO: Send entire stats object to the client +# class StatLogger(StatLoggerBase): +# def __init__( +# self, +# metrics_socket +# ): +# self.metrics_socket = metrics_socket + +# def log(self, stats: 
Stats) -> None: +# self._send_metrics(stats) + +# def info(self, type: str, obj: SupportsMetricsInfo) -> None: +# pass + +# def _send_metrics(self, stats: Stats): +# if not self.metrics_socket.closed: +# metrics_bytes = pickle.dumps(stats) +# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) + + + + class MQLLMEngine: """A multiprocessing wrapper for :class:`LLMEngine`. @@ -101,12 +182,37 @@ def __init__(self, self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send metrics back to client. + self.metrics_socket = self.ctx.socket(zmq.constants.PUSH) + self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}") + # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" # Error state. self._errored_with: Optional[BaseException] = None + self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH) + self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL) + if self.engine.is_nixl_initialized: + self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") + self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") + + + # Attach logger for continuous metrics publishing + self.kv_stat_logger = KvStatLogger( + self.engine.scheduler_config.max_num_seqs, + self.engine.cache_config.num_gpu_blocks, + self.metrics_socket + ) + self.engine.add_logger("kv_metrics", self.kv_stat_logger) + + # TODO investigate sending whole stats object + # self.general_stat_logger = StatLogger( + # self.metrics_socket + # ) + # self.engine.add_logger("general_metrics", self.general_stat_logger) + @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -192,8 +298,17 @@ def run_startup_loop(self) -> None: # Handle the query from the Client. if request == RPCStartupRequest.IS_SERVER_READY: tracing_enabled = self.engine.is_tracing_enabled() - response = RPCStartupResponse( - tracing_enabled=tracing_enabled) + + # Send nixl metadata to the client + if self.engine.is_nixl_initialized: + nixl_metadata = self.engine.get_nixl_metadata() + encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata) + response = RPCStartupResponse( + tracing_enabled=tracing_enabled, + nixl_metadata=encoded_nixl_metadata) + else: + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) except Exception as e: response = e @@ -206,6 +321,7 @@ def run_engine_loop(self): while True: if not self.engine.has_unfinished_requests(): + logger.debug("No unfinished requests") # Poll until there is work to do. 
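The startup handshake above piggybacks the NIXL agent metadata on RPCStartupResponse as msgspec-encoded bytes, which the client then decodes with an explicit type (see the setup() hunk earlier). A self-contained sketch of that round trip; ToyMetadata is a stand-in, since the real NixlMetadata fields live in nixl.py and are not shown in this patch excerpt:

# Illustrative only; not part of the patch.
from typing import List
import msgspec

class ToyMetadata(msgspec.Struct):
    engine_id: str
    agent_metadata: List[bytes]

encoded = msgspec.msgpack.encode(
    ToyMetadata(engine_id="prefill-0", agent_metadata=[b"\x00\x01"]))
decoded = msgspec.msgpack.decode(encoded, type=ToyMetadata)
assert decoded.engine_id == "prefill-0"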
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: # When there's no work, check on engine health and send @@ -249,6 +365,13 @@ def engine_step(self) -> List[RequestOutput]: def handle_new_input(self): """Handle new input from the socket""" try: + if self.engine.is_nixl_initialized: + while self.remote_nixl_metadata_socket.poll(timeout=0) != 0: + frames = self.remote_nixl_metadata_socket.recv(copy=False) + nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata) + logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id) + self.engine.add_remote_nixl_metadata(nixl_metadata) + while self.input_socket.poll(timeout=0) != 0: frames = self.input_socket.recv_multipart(copy=False) request = pickle.loads(frames[0].buffer) @@ -297,6 +420,11 @@ def _handle_process_request(self, request: RPCProcessRequest): self._send_outputs(rpc_err) try: + if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill: + def remote_prefill_request_callback(request: RemotePrefillRequest): + logger.debug("Sending remote prefill request: %s", request.request_id) + self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False) + request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback self.engine.add_request( request_id=request_id, prompt=request.prompt, @@ -304,7 +432,9 @@ def _handle_process_request(self, request: RPCProcessRequest): lora_request=request.lora_request, trace_headers=request.trace_headers, prompt_adapter_request=request.prompt_adapter_request, - priority=request.priority) + priority=request.priority, + remote_prefill_params=request.remote_prefill_params, + ) if self.log_requests: logger.info("Added request %s.", request.request_id) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index dd0b67df4f15..8d36bda3d280 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -41,6 +41,7 @@ from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, truncate_tool_call_ids, validate_request_params) +from vllm.remote_prefill import RemotePrefillParams logger = init_logger(__name__) @@ -122,6 +123,7 @@ async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Optional[Request] = None, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, ErrorResponse]: """ @@ -247,6 +249,7 @@ async def create_chat_completion( trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, priority=request.priority, + remote_prefill_params=remote_prefill_params, ) generators.append(generator) diff --git a/vllm/envs.py b/vllm/envs.py index f80bf878f79c..ef2543b23a88 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -107,6 +107,11 @@ VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 + VLLM_KV_CAPI_PATH: Optional[str] = None + VLLM_KV_NAMESPACE: Optional[str] = None + VLLM_KV_COMPONENT: Optional[str] = None + VLLM_WORKER_ID: Optional[int] = None + def get_default_cache_root(): @@ -704,6 +709,21 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # It can be changed with this variable if needed for some reason. 
"VLLM_XGRAMMAR_CACHE_MB": lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), + + # Path to the C API Library + "VLLM_KV_CAPI_PATH": + lambda: os.environ.get("VLLM_KV_CAPI_PATH", None), + + # Identifiers to publish KV related information + "VLLM_KV_NAMESPACE": + lambda: os.environ.get("VLLM_KV_NAMESPACE", None), + "VLLM_KV_COMPONENT": + lambda: os.environ.get("VLLM_KV_COMPONENT", None), + + # Worker ID used for identifying workers in distributed settings + "VLLM_WORKER_ID": + lambda: int(os.getenv("VLLM_WORKER_ID", "0")) + if "VLLM_WORKER_ID" in os.environ else None, } # end-env-vars-definition diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 23b450aeddac..d39a4c116b4d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -605,6 +605,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config + self.config = config + self.vocab_size = config.vocab_size if get_pp_group().is_first_rank: diff --git a/vllm/outputs.py b/vllm/outputs.py index 014e8d5d8823..7ffac8d1e138 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -6,16 +6,16 @@ from dataclasses import dataclass from typing import Generic, Optional, Union +import msgspec import torch from typing_extensions import TypeVar, deprecated from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, SequenceGroup, SequenceGroupBase, SequenceStatus) - @dataclass class CompletionOutput: """The output data of one completion output of a request. diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py new file mode 100644 index 000000000000..3acda50f0ec4 --- /dev/null +++ b/vllm/remote_prefill.py @@ -0,0 +1,66 @@ +from dataclasses import dataclass +from typing import Callable, Optional, List +from enum import Enum + +import msgspec + +from vllm.sampling_params import SamplingParams + + +class RemotePrefillRequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + """The request data of one remote prefill output of a request. + + Args: + engine_id: The unique ID of the engine. + request_id: The unique ID of the request. + prompt_token_ids: The token IDs of the prompt. + sampling_params: The sampling parameters. + block_ids: The block IDs of the request. + computed_block_ids: The computed block IDs of the request. + """ + engine_id: str + request_id: str + prompt_token_ids: List[int] + sampling_params: SamplingParams + block_ids: List[int] + computed_block_ids: List[int] + +class MemoryOpType(str, Enum): + WRITE = "WRITE" + READ = "READ" + + +class MemoryTransferRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """The request data of one memory transfer output of a request. + + Args: + request_id: The unique ID of the request. 
+ """ + request_id: str + local_block_ids: List[int] + staging_block_ids: List[int] + remote_block_ids: List[int] + remote_engine_id: str + notify_msg: str + op_type: MemoryOpType + + +RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None] + + +@dataclass +class RemotePrefillParams: + """Remote prefill parameters for text generation.""" + is_remote_prefill: bool = False + is_remote_decode: bool = False + decode_block_ids: Optional[List[int]] = None + decode_computed_block_ids: Optional[List[int]] = None + decode_engine_id: Optional[str] = None + remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None \ No newline at end of file diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 68ed99664947..17eea4155b83 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -103,7 +103,7 @@ class RequestOutputKind(Enum): DELTA = 1 # Do not return intermediate RequestOutput FINAL_ONLY = 2 - + class SamplingParams( msgspec.Struct, diff --git a/vllm/sequence.py b/vllm/sequence.py index 61867b025315..3ed720bd822f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,7 +9,7 @@ from collections.abc import Sequence as GenericSequence from dataclasses import dataclass, field from functools import reduce -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, List import msgspec import torch @@ -20,6 +20,7 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -59,13 +60,14 @@ class SequenceStatus(enum.IntEnum): """Status of a sequence.""" WAITING = 0 RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered + REMOTE_PREFILLING = 2 + SWAPPED = 3 + # Note: anything after SWAPPED (3) will be considered # as a finished status. - FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 + FINISHED_STOPPED = 4 + FINISHED_LENGTH_CAPPED = 5 + FINISHED_ABORTED = 6 + FINISHED_IGNORED = 7 @staticmethod def is_finished(status: "SequenceStatus") -> bool: @@ -417,6 +419,7 @@ def __init__( eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + remote_prefill_params: Optional[RemotePrefillParams] = None, ) -> None: self.seq_id = seq_id self.inputs = SingletonInputsAdapter(inputs) @@ -424,7 +427,7 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - + self.remote_prefill_params = remote_prefill_params self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -651,6 +654,7 @@ class SequenceGroup: model; equal to max number of tokens a step can generate for single-draft speculative decoding but larger than that for multi-draft SD (currently not supported). + remote_prefill_params: Remote prefill parameters. 
""" def __init__(self, @@ -665,7 +669,9 @@ def __init__(self, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - draft_size: int = 1) -> None: + draft_size: int = 1, + remote_prefill_params: Optional[RemotePrefillParams] = None, + ) -> None: self.request_id = request_id self.seqs = seqs self.first_seq = seqs[0] @@ -691,7 +697,7 @@ def __init__(self, self.encoder_seq = encoder_seq self.trace_headers = trace_headers self.priority = priority - + self.remote_prefill_params = remote_prefill_params self.cached_request_output = None @property @@ -940,6 +946,9 @@ class SequenceGroupMetadata( query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. + do_remote_prefill: True if remote prefill is required. + do_remote_decode: True if remote decode is required. + decode_memory_desc: The memory descriptor for the decoder blocks. lora_request: LoRA request. computed_block_nums: The block numbers that are already computed, used in prefix caching. @@ -979,6 +988,9 @@ class SequenceGroupMetadata( cross_block_table: Optional[list[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None token_chunk_size: Optional[int] = None + do_remote_prefill: bool = False + do_remote_decode: bool = False + decode_memory_desc: Optional[bytes] = None ### Stateful fields that are lazily defined. ### # The number of speculative tokens adopted in this request. @@ -1329,6 +1341,8 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[Callable] = None + # The memory transfer requests. + memory_transfer_requests: Optional[List[MemoryTransferRequest]] = None @property def is_first_multi_step(self) -> bool: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 9524a69f6b3a..3e0fff335a78 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1876,6 +1876,9 @@ def need_recv_kv(self, model_input, kv_caches) -> bool: if self.vllm_config.kv_transfer_config is None: return False + if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": + return False + prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling @@ -1901,6 +1904,9 @@ def need_send_kv(self, model_input, kv_caches) -> bool: if self.vllm_config.kv_transfer_config is None: return False + if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": + return False + prefill_meta = model_input.attn_metadata.prefill_metadata # check if the current run is profiling diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d59f20f49996..ea1a17cf72ea 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -2,7 +2,7 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Optional, Set, Tuple, Type, Union +from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any import torch import torch.distributed @@ -31,6 +31,9 @@ from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) +from vllm.distributed.device_communicators.nixl import DynamoNixlConnector +from vllm.remote_prefill import MemoryOpType + logger = init_logger(__name__) @@ -307,6 +310,46 @@ def initialize_cache(self, num_gpu_blocks: int, self._init_cache_engine() 
self._warm_up_model() + def initialize_nixl(self, engine_id: str) -> List[bytes]: + + # TODO ptarasiewicz nixl can also support DRAM + assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector" + + self.nixl_connector = DynamoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank? + assert len(self.cache_engine) == 1, "Only one cache engine is supported for now" + self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache) + return self.nixl_connector.agent_name + + def get_nixl_agent_metadata(self) -> bytes: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + return self.nixl_connector.get_agent_metadata() + + def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank? + return agent_name + + def get_nixl_kv_caches_base_addr(self) -> List[bytes]: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id] + + def _read_blocks(self, worker_input: WorkerInput) -> None: + for i, op_type in enumerate(worker_input.op_type): + if op_type == MemoryOpType.READ: + self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i]) + + def _write_blocks(self, worker_input: WorkerInput) -> None: + if not self.is_driver_worker: + torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize + + for i, op_type in enumerate(worker_input.op_type): + if op_type == MemoryOpType.WRITE: + self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i]) + + def shutdown_nixl(self) -> None: + assert self.nixl_connector is not None, "Nixl connector is not initialized" + self.nixl_connector.shutdown() + def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ @@ -368,6 +411,8 @@ def prepare_worker_input( blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, device=self.device, dtype=torch.int64).view(-1, 2) + + mem_transfer_reqs = execute_model_req.memory_transfer_requests or [] return WorkerInput( num_seq_groups=num_seq_groups, @@ -376,6 +421,12 @@ def prepare_worker_input( blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, num_steps=num_steps, + local_block_ids=[r.local_block_ids for r in mem_transfer_reqs], + staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs], + remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs], + remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs], + notify_msg=[r.notify_msg for r in mem_transfer_reqs], + op_type=[r.op_type for r in mem_transfer_reqs], ) @torch.inference_mode() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e5662e69343c..ee75155e7104 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -9,6 +9,7 @@ import cloudpickle import torch import 
torch.nn as nn +from collections import defaultdict from vllm.config import (ObservabilityConfig, VllmConfig, set_current_vllm_config) @@ -24,6 +25,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) +from vllm.distributed.device_communicators.nixl import DynamoNixlConnector +from vllm.remote_prefill import MemoryOpType logger = init_logger(__name__) @@ -55,6 +58,9 @@ def __init__( from vllm.platforms import current_platform self.current_platform = current_platform + self.nixl_connector: Optional[DynamoNixlConnector] = None + + @abstractmethod def init_device(self) -> None: """Initialize device state, such as loading the model or other on-device memory allocations. @@ -221,6 +227,13 @@ class WorkerInput: virtual_engine: int = 0 num_steps: int = 1 + local_block_ids: Optional[List[List[int]]] = None + staging_block_ids: Optional[List[List[int]]] = None + remote_block_ids: Optional[List[List[int]]] = None + remote_engine_id: Optional[List[str]] = None + notify_msg: Optional[List[str]] = None + op_type: Optional[List[MemoryOpType]] = None + @classmethod def from_broadcasted_tensor_dict( cls: Type["WorkerInput"], @@ -237,6 +250,12 @@ def from_broadcasted_tensor_dict( blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], num_steps=tensor_dict.pop("num_steps"), + local_block_ids=tensor_dict.pop("local_block_ids"), + staging_block_ids=tensor_dict.pop("staging_block_ids"), + remote_block_ids=tensor_dict.pop("remote_block_ids"), + remote_engine_id=tensor_dict.pop("remote_engine_id"), + notify_msg=tensor_dict.pop("notify_msg"), + op_type=tensor_dict.pop("op_type"), ) def as_broadcastable_tensor_dict( @@ -251,6 +270,12 @@ def as_broadcastable_tensor_dict( "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, "num_steps": self.num_steps, + "local_block_ids": self.local_block_ids, + "staging_block_ids": self.staging_block_ids, + "remote_block_ids": self.remote_block_ids, + "remote_engine_id": self.remote_engine_id, + "notify_msg": self.notify_msg, + "op_type": self.op_type, } return tensor_dict @@ -321,13 +346,16 @@ def _get_worker_input_from_broadcast( return None worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) - model_input = ( - self.model_runner.make_model_input_from_broadcasted_tensor_dict( - broadcast_data)) + if worker_input.num_seq_groups > 0: + model_input = ( + self.model_runner.make_model_input_from_broadcasted_tensor_dict( + broadcast_data)) - kwargs = extract_previous_hidden_states(broadcast_data) + kwargs = extract_previous_hidden_states(broadcast_data) - return model_input, worker_input, kwargs + return model_input, worker_input, kwargs + else: + return None, worker_input, {} def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest @@ -403,49 +431,89 @@ def execute_model( self.execute_worker(worker_input) # If there is no input, we don't need to execute the model. 
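Worth noting about the WorkerInput changes above: from_broadcasted_tensor_dict pops the new per-transfer keys unconditionally, so as_broadcastable_tensor_dict must always emit them, even on steps that carry no transfers. A small round-trip sketch, assuming the patched tree and using empty tensors and lists for brevity:

# Illustrative only; not part of the patch.
import torch
from vllm.worker.worker_base import WorkerInput

wi = WorkerInput(
    num_seq_groups=0,
    blocks_to_swap_in=torch.empty((0, 2), dtype=torch.int64),
    blocks_to_swap_out=torch.empty((0, 2), dtype=torch.int64),
    blocks_to_copy=torch.empty((0, 2), dtype=torch.int64),
    virtual_engine=0,
    num_steps=1,
    local_block_ids=[], staging_block_ids=[], remote_block_ids=[],
    remote_engine_id=[], notify_msg=[], op_type=[],
)
payload = wi.as_broadcastable_tensor_dict()
# Every new key must be present, even when there are no transfers,
# because from_broadcasted_tensor_dict pops them unconditionally.
roundtrip = WorkerInput.from_broadcasted_tensor_dict(payload)
assert roundtrip.notify_msg == []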
- if worker_input.num_seq_groups == 0: - return [] + if worker_input.num_seq_groups > 0: + + self._read_blocks(worker_input) + + intermediate_tensors = None + orig_model_execute_time = 0.0 + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + orig_model_execute_time = intermediate_tensors.tensors.get( + "model_execute_time", torch.tensor(0)).item() + + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) - intermediate_tensors = None - orig_model_execute_time = 0.0 - if not get_pp_group().is_first_rank: - intermediate_tensors = IntermediateTensors( - get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) + model_execute_time = time.perf_counter() - start_time + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + assert isinstance(output, IntermediateTensors) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return [None] if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - orig_model_execute_time = intermediate_tensors.tensors.get( - "model_execute_time", torch.tensor(0)).item() + and self.observability_config.collect_model_execute_time + and output is not None): + for o in output: + o.model_execute_time = (orig_model_execute_time + + model_execute_time) - output = self.model_runner.execute_model( - model_input=model_input, - kv_caches=self.kv_cache[worker_input.virtual_engine] - if self.kv_cache is not None else None, - intermediate_tensors=intermediate_tensors, - num_steps=num_steps, - **kwargs, - ) + self._write_blocks(worker_input) - model_execute_time = time.perf_counter() - start_time - if not get_pp_group().is_last_rank: - # output is IntermediateTensors - assert isinstance(output, IntermediateTensors) - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time): - output.tensors["model_execute_time"] = torch.tensor( - model_execute_time + orig_model_execute_time) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) - return [None] - if (self.observability_config is not None - and self.observability_config.collect_model_execute_time - and output is not None): - for o in output: - o.model_execute_time = (orig_model_execute_time + - model_execute_time) + else: + output = [] + + # collect kv transfer notifications from non driver workers + + if self.nixl_connector is not None: + new_notifs = self.nixl_connector.get_new_notifs() + rank = get_tp_group().rank + all_new_notifs = [new_notifs] + if rank > 0: + get_tp_group().send_object(new_notifs, dst=0) + else: + for i in range(1, get_tp_group().world_size): + all_new_notifs.append(get_tp_group().recv_object(src=i)) + request_notif_counter = defaultdict(int) + for notifs in all_new_notifs: + for req_ids in notifs.values(): + for req_id in req_ids: + request_notif_counter[req_id] += 1 + + if request_notif_counter: + logger.debug("Request notif 
counter: %s", request_notif_counter) + + request_done_counter = defaultdict(int) + for req_id in self.nixl_connector.get_done_tranfers(): + request_done_counter[req_id] += 1 + + else: + request_notif_counter = {} + request_done_counter = {} # output is List[SamplerOutput] - return output + return output, request_notif_counter, request_done_counter + + def _read_blocks(self, worker_input: WorkerInput) -> None: + pass + + def _write_blocks(self, worker_input: WorkerInput) -> None: + pass def _execute_model_spmd( self, From 2be9bdc6b489318eeabedacf72ae089e715f9e15 Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Fri, 4 Apr 2025 22:09:55 +0000 Subject: [PATCH 2/7] support mtp for dynamo remote prefill --- vllm/core/scheduler.py | 2 ++ vllm/engine/llm_engine.py | 8 ++--- vllm/remote_prefill.py | 18 +++++++++-- vllm/spec_decode/batch_expansion.py | 10 ++++-- vllm/spec_decode/mqa_scorer.py | 8 ++++- vllm/spec_decode/multi_step_worker.py | 6 +++- vllm/spec_decode/spec_decode_worker.py | 43 +++++++++++++++++++++----- vllm/worker/worker_base.py | 15 ++++----- 8 files changed, 85 insertions(+), 25 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fb8ce03b574e..008e3f968500 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -729,6 +729,8 @@ def _schedule_running( finished_prefills.remove(seq_group.request_id) assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] + # when there is one request complete remote prefill, we should run non_spec decode + ret.num_lookahead_slots = 0 # we computed all but the last token in prefill, we need to decode the first token on decode seq_group.update_num_computed_tokens(seq.get_len() - 1) seq.status = SequenceStatus.RUNNING diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 05c8a2d634d2..ca3f0d66141f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1566,7 +1566,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: execute_model_req.memory_transfer_requests = memory_transfer_reqs try: - outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( + outputs, remote_prefill_result = self.model_executor.execute_model( execute_model_req=execute_model_req) self._skip_scheduling_next_step = False except InputProcessingError as e: @@ -1599,16 +1599,16 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: blocks_to_swap_out=[], blocks_to_copy=[]) - outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( + outputs, remote_prefill_result = self.model_executor.execute_model( execute_model_req=execute_model_req) - for req_id, notif_count in request_notif_counter.items(): + for req_id, notif_count in remote_prefill_result.request_notif_counter.items(): self._request_notif_counter[req_id] += notif_count if self._request_notif_counter[req_id] > -1: self._finished_prefills.add(req_id) del self._request_notif_counter[req_id] - for req_id, done_count in request_done_counter.items(): + for req_id, done_count in remote_prefill_result.request_done_counter.items(): self._request_done_counter[req_id] += done_count if self._request_done_counter[req_id] > -1: self._finished_transfers.add(req_id) diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py index 3acda50f0ec4..67a620c06ece 100644 --- a/vllm/remote_prefill.py +++ b/vllm/remote_prefill.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass -from typing import Callable, Optional, List +from dataclasses import dataclass, field +from typing 
import Callable, Optional, List, Dict from enum import Enum import msgspec @@ -63,4 +63,16 @@ class RemotePrefillParams: decode_block_ids: Optional[List[int]] = None decode_computed_block_ids: Optional[List[int]] = None decode_engine_id: Optional[str] = None - remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None \ No newline at end of file + remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None + +@dataclass +class RemotePrefillResult: + """Remote prefill notifications and progress. + Args: + request_notif_counter: The recv kv cache notification consumed by decode worker. + request_done_counter: The send kv cache notification consumed by prefill worker. + is_pending_remote_prefill: Whether the current loop pending on kv cache transfer. + """ + request_notif_counter: Dict[str, int] = field(default_factory=dict) + request_done_counter: Dict[str, int] = field(default_factory=dict) + is_pending_remote_prefill: bool = False diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index e08ed742a522..13b941ff3882 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -14,6 +14,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len +from vllm.remote_prefill import RemotePrefillResult SeqId = int TargetSeqId = int @@ -77,9 +78,14 @@ def score_proposals( proposal_lens_list=proposal_lens_list, ) + remote_prefill_result = RemotePrefillResult() + target_sampler_output = self._scorer_worker.execute_model( execute_model_req=execute_model_req.clone( seq_group_metadata_list=target_seq_group_metadata_list)) + if isinstance(target_sampler_output, tuple) \ + and len(target_sampler_output) == 2: + target_sampler_output, remote_prefill_result = target_sampler_output assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] @@ -88,7 +94,7 @@ def score_proposals( return self._contract_batch_all_spec( target_sampler_output=target_sampler_output, proposals=proposals, - ) + ), remote_prefill_result else: # Batch has a mix of spec decode enabled and disabled seq groups return self._contract_batch( @@ -99,7 +105,7 @@ def score_proposals( non_spec_indices=non_spec_indices, spec_indices=spec_indices, k=execute_model_req.num_lookahead_slots, - ) + ), remote_prefill_result def _expand_batch( self, diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py index 6275c460ecef..02b8c81da1f3 100644 --- a/vllm/spec_decode/mqa_scorer.py +++ b/vllm/spec_decode/mqa_scorer.py @@ -4,6 +4,7 @@ SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) +from vllm.remote_prefill import RemotePrefillResult SeqId = int TargetSeqId = int @@ -63,10 +64,15 @@ def score_proposals( ) target_seq_group_metadata_list.append(new_seq_group_metadata) + remote_prefill_result = RemotePrefillResult() + target_sampler_output = self._scorer_worker.execute_model( execute_model_req=execute_model_req.clone( seq_group_metadata_list=target_seq_group_metadata_list)) + if isinstance(target_sampler_output, tuple) and len(target_sampler_output) == 2: + target_sampler_output, remote_prefill_result = target_sampler_output + target_sampler_output = target_sampler_output[0] k = execute_model_req.num_lookahead_slots @@ -156,4 +162,4 @@ def score_proposals( 
token_ids=all_tokens, logprobs=all_logprobs, hidden_states=hidden_states, - prompt_logprobs=prompt_logprobs) + prompt_logprobs=prompt_logprobs), remote_prefill_result diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d8d54918fa98..15c009abf1f5 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -90,6 +90,8 @@ def sampler_output( indices_of_seq_with_bonus_tokens) model_outputs = self.execute_model( execute_model_req=expanded_request) + if isinstance(model_outputs, tuple) and len(model_outputs) == 2: + model_outputs = model_outputs[0] else: # Here we run multi-step directly, with every step prepared # on the CPU. @@ -99,8 +101,10 @@ def sampler_output( if expanded_request.previous_hidden_states is not None: self.worker.model_runner.return_hidden_states = True for _ in range(sample_len): - model_output: List[SamplerOutput] = self.worker.execute_model( + model_output = self.worker.execute_model( execute_model_req=expanded_request) + if isinstance(model_output, tuple) and len(model_output) == 2: + model_output = model_output[0] assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a724beade129..13a6ee0fb199 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -48,6 +48,7 @@ split_batch_by_proposal_len) from vllm.utils import resolve_obj_by_qualname from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase +from vllm.remote_prefill import RemotePrefillResult logger = init_logger(__name__) @@ -391,6 +392,22 @@ def init_device(self) -> None: def load_model(self, *args, **kwargs): pass + def initialize_nixl(self, engine_id: str) -> List[bytes]: + return self.scorer_worker.initialize_nixl(engine_id) + + def get_nixl_agent_metadata(self) -> bytes: + return self.scorer_worker.get_nixl_agent_metadata() + + def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int +, int]]], num_blocks: int) -> str: + return self.scorer_worker.add_remote_nixl_metadata(engine_id, agents_metadata, kv_caches_base_addr, num_blocks) + + def get_nixl_kv_caches_base_addr(self) -> List[bytes]: + return self.scorer_worker.get_nixl_kv_caches_base_addr() + + def shutdown_nixl(self) -> None: + return self.scorer_worker.shutdown_nixl() + def _configure_model_sampler_for_spec_decode(self): """Configure model sampler to emit GPU tensors. This allows spec decode to keep data on device without transferring to CPU and serializing, @@ -669,8 +686,13 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, not called, meaning that the kv-cache in proposer for requests is not updated, so they cannot enable spec decode in the rest decoding. 
""" - + remote_prefill_result = RemotePrefillResult() sampler_output = self.scorer_worker.execute_model(execute_model_req) + if isinstance(sampler_output, tuple) and len(sampler_output) == 2: + sampler_output, remote_prefill_result = sampler_output + if remote_prefill_result.is_pending_remote_prefill: + return sampler_output, remote_prefill_result + assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -700,9 +722,11 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # We prepare the prefill hidden states here so that there no # additional complexity in worker for spec_decode vs non_spec_decode # flow and execute_model doesn't need additional modifications. + previous_hidden_states = hidden_states \ + if sampler_output.prefill_hidden_states is None \ + else sampler_output.prefill_hidden_states execute_model_req.previous_hidden_states = \ - prepare_prefill_hidden_states( - sampler_output.prefill_hidden_states) + prepare_prefill_hidden_states(previous_hidden_states) for i in range(self._num_spec_prefill_steps): execute_model_req.spec_step_idx = i self.proposer_worker.execute_model(execute_model_req) @@ -717,7 +741,7 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, sampler_output.sampled_token_probs = None sampler_output.sampled_token_ids = None sampler_output.logprobs = None - return sampler_output_to_return + return sampler_output_to_return, remote_prefill_result def _run_non_driver_rank(self) -> bool: """Run proposer and verifier model in non-driver workers. This is used @@ -736,7 +760,11 @@ def _run_non_driver_rank(self) -> bool: # In case of prefill, scorer_worker has to be run before proposer so # that the hidden states can be propagated to proposer when needed. if data["no_spec"]: - self.scorer_worker.execute_model() + sampler_output = self.scorer_worker.execute_model() + if isinstance(sampler_output, tuple) and len(sampler_output) == 2: + sampler_output, remote_prefill_result = sampler_output + if remote_prefill_result.is_pending_remote_prefill: + return True if not data["disable_all_speculation"]: # Even if num_lookahead_slots is zero, we want to run the @@ -789,9 +817,10 @@ def _run_speculative_decoding_step( "workers generate no tokens") execute_model_req.previous_hidden_states = None + remote_prefill_result = RemotePrefillResult() with Timer() as scoring_timer: - proposal_scores = self.scorer.score_proposals( + proposal_scores, remote_prefill_result = self.scorer.score_proposals( execute_model_req, proposals, ) @@ -831,7 +860,7 @@ def _run_speculative_decoding_step( prompt_logprobs=proposal_scores.prompt_logprobs if not self._disable_logprobs else None, k=execute_model_req.num_lookahead_slots, - stage_times=stage_times) + stage_times=stage_times), remote_prefill_result @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index ee75155e7104..9220574f4ea2 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -26,7 +26,7 @@ ModelRunnerBase, ModelRunnerInputBase) from vllm.distributed.device_communicators.nixl import DynamoNixlConnector -from vllm.remote_prefill import MemoryOpType +from vllm.remote_prefill import MemoryOpType, RemotePrefillResult logger = init_logger(__name__) @@ -429,12 +429,12 @@ def execute_model( kwargs["spec_step_idx"] = execute_model_req.spec_step_idx self.execute_worker(worker_input) + remote_prefill_result = RemotePrefillResult() # If there is no input, we don't need to execute the model. 
if worker_input.num_seq_groups > 0: - self._read_blocks(worker_input) - + remote_prefill_result.is_pending_remote_prefill = False intermediate_tensors = None orig_model_execute_time = 0.0 if not get_pp_group().is_first_rank: @@ -476,6 +476,7 @@ def execute_model( self._write_blocks(worker_input) else: + remote_prefill_result.is_pending_remote_prefill = True output = [] # collect kv transfer notifications from non driver workers @@ -503,11 +504,11 @@ def execute_model( for req_id in self.nixl_connector.get_done_tranfers(): request_done_counter[req_id] += 1 - else: - request_notif_counter = {} - request_done_counter = {} + remote_prefill_result.request_notif_counter = request_notif_counter + remote_prefill_result.request_done_counter = request_done_counter + # output is List[SamplerOutput] - return output, request_notif_counter, request_done_counter + return output, remote_prefill_result def _read_blocks(self, worker_input: WorkerInput) -> None: pass From c8c4faace2371fb6d1ce2f70e1586d29ace49f29 Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Fri, 11 Apr 2025 18:10:56 +0000 Subject: [PATCH 3/7] integrate nixl directly to vllm --- vllm/entrypoints/openai/api_server.py | 37 ++++- vllm/entrypoints/openai/protocol.py | 9 ++ .../openai/serving_remote_prefill.py | 145 ++++++++++++++++++ vllm/worker/worker_base.py | 4 + 4 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 vllm/entrypoints/openai/serving_remote_prefill.py diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6a8bdd060228..d808b85a5abe 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -58,9 +58,11 @@ EmbeddingResponseData, ErrorResponse, LoadLoRAAdapterRequest, + NixlMetadataRequest, PoolingChatRequest, PoolingCompletionRequest, PoolingRequest, PoolingResponse, + RemotePrefillGenerateRequest, RerankRequest, RerankResponse, ScoreRequest, ScoreResponse, TokenizeRequest, @@ -72,6 +74,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_remote_prefill import OpenAIServingRemotePrefill from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) @@ -386,6 +389,8 @@ def transcription(request: Request) -> OpenAIServingTranscription: def engine_client(request: Request) -> EngineClient: return request.app.state.engine_client +def remote_prefill(request: Request) -> OpenAIServingRemotePrefill: + return request.app.state.openai_serving_remote_prefill @router.get("/health") async def health(raw_request: Request) -> Response: @@ -473,7 +478,10 @@ async def create_chat_completion(request: ChatCompletionRequest, return base(raw_request).create_error_response( message="The model does not support Chat Completions API") - generator = await handler.create_chat_completion(request, raw_request) + remote_prefill_handler = remote_prefill(raw_request) + remote_prefill_params = remote_prefill_handler.get_remote_prefill_params(raw_request) + + generator = await handler.create_chat_completion(request, raw_request, remote_prefill_params) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), @@ -659,6 +667,26 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): async def do_rerank_v2(request: 
RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) +@router.get("/nixl_metadata") +async def get_nixl_metadata(raw_request: Request): + handler = remote_prefill(raw_request) + + nixl_metdata = handler.nixl_metadata() + return JSONResponse(content=nixl_metdata.model_dump()) + +@router.post("/remote_nixl_metadata", dependencies=[Depends(validate_json_request)]) +async def add_remote_nixl_metadata(request: NixlMetadataRequest ,raw_request: Request): + handler = remote_prefill(raw_request) + await handler.remote_nixl_metadata(request) + return Response(status_code=200) + +@router.post("/remote_prefill", dependencies=[Depends(validate_json_request)]) +@with_cancellation +@load_aware_call +async def remote_prefill_generate(request: RemotePrefillGenerateRequest, raw_request: Request): + handler = remote_prefill(raw_request) + await handler.remote_prefill(request) + return Response(status_code=200) TASK_HANDLERS: dict[str, dict[str, tuple]] = { "generate": { @@ -1011,6 +1039,13 @@ async def init_app_state( state.openai_serving_models, request_logger=request_logger, ) if model_config.runner_type == "transcription" else None + # remote prefill + state.openai_serving_remote_prefill = OpenAIServingRemotePrefill( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + ) state.task = model_config.task state.enable_server_load_tracking = args.enable_server_load_tracking diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4639b4cea06b..6b9ae71d0f8a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1716,3 +1716,12 @@ class TranscriptionResponseVerbose(OpenAIBaseModel): words: Optional[list[TranscriptionWord]] = None """Extracted words and their corresponding timestamps.""" + +class NixlMetadataResponse(BaseModel): + metadata: str + +class NixlMetadataRequest(BaseModel): + metadata: str + +class RemotePrefillGenerateRequest(BaseModel): + content: str diff --git a/vllm/entrypoints/openai/serving_remote_prefill.py b/vllm/entrypoints/openai/serving_remote_prefill.py new file mode 100644 index 000000000000..1b929fdd8249 --- /dev/null +++ b/vllm/entrypoints/openai/serving_remote_prefill.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import aiohttp +import time +from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import Sequence as GenericSequence +from typing import Optional, Union, cast +import msgspec +import threading + +import jinja2 +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.remote_prefill import (RemotePrefillParams, + RemotePrefillRequest) + +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (NixlMetadataRequest, + NixlMetadataResponse, + RemotePrefillGenerateRequest) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import merge_async_iterators +from vllm.inputs.data import TokensPrompt + +from 
vllm.distributed.device_communicators.nixl import NixlMetadata + +logger = init_logger(__name__) + + +class OpenAIServingRemotePrefill(OpenAIServing): + """OpenAI API for remote prefill. + + Handles the routes: + - /remote_nixl_metadata [POST] + - /nixl_metadata [GET] + - /remote_prefill [POST] + """ + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + ): + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + ) + + self.remote_prefill_endpoints = ["127.0.0.1:8090"] + + self._request_queue = asyncio.Queue() + + loop = asyncio.get_event_loop() + self._background_thread = threading.Thread(target=self.background_event_loop, daemon=True, args=(loop, )) + self._background_thread.start() + + def background_event_loop(self, loop): + asyncio.set_event_loop(loop) + loop.create_task(self._process_requests()) + + async def _process_requests(self): + while True: + request = await self._request_queue.get() + + sampling_params = request.sampling_params + sampling_params.max_tokens = 1 + sampling_params.min_tokens = 1 + + remote_prefill_params = RemotePrefillParams( + is_remote_decode=True, + decode_block_ids=request.block_ids, + decode_engine_id=request.engine_id, + ) + + async for _ in self.engine_client.generate( + request_id=request.request_id, + prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids), + sampling_params=sampling_params, + remote_prefill_params=remote_prefill_params, + ): + pass + + def nixl_metadata(self) -> NixlMetadataResponse: + """Get Nixl metadata""" + + metadata = str( + msgspec.json.encode(self.engine_client.nixl_metadata), encoding="utf-8" + ) + + return NixlMetadataResponse(metadata=metadata) + + async def remote_nixl_metadata( + self, + request: NixlMetadataRequest, + ): + """Add remote Nixl metadata""" + metadata = msgspec.json.decode( + request.metadata.encode(encoding="utf-8"), type=NixlMetadata + ) + + await self.engine_client.add_remote_nixl_metadata(metadata) + + async def remote_prefill(self, request: RemotePrefillGenerateRequest): + request = msgspec.json.decode( + request.content.encode(encoding="utf-8"), + type=RemotePrefillRequest, + ) + + await self._request_queue.put(request) + + def get_remote_prefill_request_callback(self): + # TODO: integrate prefill_queue to dynamo endpoint + async def callback(request: RemotePrefillRequest): + remote_prefill_url = f"http://{self.remote_prefill_endpoints[0]}/remote_prefill" + request = RemotePrefillGenerateRequest( + content=str(msgspec.json.encode(request), encoding="utf-8"), + ) + async with aiohttp.ClientSession() as session: + await session.post(remote_prefill_url, json=request.model_dump()) + + return callback + + def get_remote_prefill_params(self, request: Request): + remote_prefill_params = RemotePrefillParams( + is_remote_prefill=True, + remote_prefill_request_callback=self.get_remote_prefill_request_callback(), + ) + return remote_prefill_params \ No newline at end of file diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 9220574f4ea2..6dd1913d3e8f 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -495,6 +495,10 @@ def execute_model( for notifs in all_new_notifs: for req_ids in notifs.values(): for req_id in req_ids: + # the notification value is changed to bytes in + # nixl commit d40858a0545e285c1b5760909a763a2411d6c89f + if isinstance(req_id, bytes): + req_id = 
req_id.decode("utf-8") request_notif_counter[req_id] += 1 if request_notif_counter: From dac0543863aa7dfcefeb97341b67717f520b42cf Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Fri, 11 Apr 2025 20:48:55 +0000 Subject: [PATCH 4/7] support add and remove remote prefill endpoint apis --- vllm/entrypoints/openai/api_server.py | 18 +++- vllm/entrypoints/openai/protocol.py | 3 + .../openai/serving_remote_prefill.py | 96 +++++++++++++++---- 3 files changed, 98 insertions(+), 19 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d808b85a5abe..94d9db2ed3af 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -62,6 +62,7 @@ PoolingChatRequest, PoolingCompletionRequest, PoolingRequest, PoolingResponse, + RemotePrefillEpRequest, RemotePrefillGenerateRequest, RerankRequest, RerankResponse, ScoreRequest, ScoreResponse, @@ -685,7 +686,22 @@ async def add_remote_nixl_metadata(request: NixlMetadataRequest ,raw_request: Re @load_aware_call async def remote_prefill_generate(request: RemotePrefillGenerateRequest, raw_request: Request): handler = remote_prefill(raw_request) - await handler.remote_prefill(request) + handler.remote_prefill(request) + return Response(status_code=200) + +@router.post("/add_remote_prefill_eps", dependencies=[Depends(validate_json_request)]) +async def add_remote_prefill_eps(request: RemotePrefillEpRequest ,raw_request: Request): + handler = remote_prefill(raw_request) + try: + await handler.add_remote_prefill_eps(request) + except (ValueError) as e: + return JSONResponse(content={"error": str(e)}, status_code=400) + return Response(status_code=200) + +@router.post("/remove_remote_prefill_eps", dependencies=[Depends(validate_json_request)]) +async def remove_remote_prefill_eps(request: RemotePrefillEpRequest ,raw_request: Request): + handler = remote_prefill(raw_request) + await handler.remove_remote_prefill_eps(request) return Response(status_code=200) TASK_HANDLERS: dict[str, dict[str, tuple]] = { diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6b9ae71d0f8a..245e1b010912 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1725,3 +1725,6 @@ class NixlMetadataRequest(BaseModel): class RemotePrefillGenerateRequest(BaseModel): content: str + +class RemotePrefillEpRequest(BaseModel): + endpoints: list[str] \ No newline at end of file diff --git a/vllm/entrypoints/openai/serving_remote_prefill.py b/vllm/entrypoints/openai/serving_remote_prefill.py index 1b929fdd8249..c9a3ae2d5e3b 100644 --- a/vllm/entrypoints/openai/serving_remote_prefill.py +++ b/vllm/entrypoints/openai/serving_remote_prefill.py @@ -2,15 +2,12 @@ import asyncio import aiohttp -import time -from collections.abc import AsyncGenerator, AsyncIterator -from collections.abc import Sequence as GenericSequence -from typing import Optional, Union, cast +from collections import defaultdict +from fastapi import Request import msgspec import threading - -import jinja2 -from fastapi import Request +from typing import Optional +from urllib.parse import urlparse from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -22,16 +19,12 @@ # yapf: disable from vllm.entrypoints.openai.protocol import (NixlMetadataRequest, NixlMetadataResponse, + RemotePrefillEpRequest, RemotePrefillGenerateRequest) # yapf: enable -from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs +from 
vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger -from vllm.outputs import RequestOutput -from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import merge_async_iterators from vllm.inputs.data import TokensPrompt from vllm.distributed.device_communicators.nixl import NixlMetadata @@ -63,7 +56,9 @@ def __init__( request_logger=request_logger, ) - self.remote_prefill_endpoints = ["127.0.0.1:8090"] + self.remote_prefill_endpoint_map = defaultdict(int) + self.remote_prefill_endpoints = [] + self.counter = 0 self._request_queue = asyncio.Queue() @@ -117,18 +112,21 @@ async def remote_nixl_metadata( await self.engine_client.add_remote_nixl_metadata(metadata) - async def remote_prefill(self, request: RemotePrefillGenerateRequest): + def remote_prefill(self, request: RemotePrefillGenerateRequest): request = msgspec.json.decode( request.content.encode(encoding="utf-8"), type=RemotePrefillRequest, ) - await self._request_queue.put(request) + self._request_queue.put_nowait(request) def get_remote_prefill_request_callback(self): # TODO: integrate prefill_queue to dynamo endpoint async def callback(request: RemotePrefillRequest): - remote_prefill_url = f"http://{self.remote_prefill_endpoints[0]}/remote_prefill" + endpoint = self.remote_prefill_endpoints[self.counter % len(self.remote_prefill_endpoints)] + self.counter = (self.counter + 1) % (2** 31 - 1) + remote_prefill_url = f"{endpoint}/remote_prefill" + logger.debug(f"Remote prefill endpoint: {remote_prefill_url}") request = RemotePrefillGenerateRequest( content=str(msgspec.json.encode(request), encoding="utf-8"), ) @@ -138,8 +136,70 @@ async def callback(request: RemotePrefillRequest): return callback def get_remote_prefill_params(self, request: Request): + if len(self.remote_prefill_endpoints) == 0: + return None + remote_prefill_params = RemotePrefillParams( is_remote_prefill=True, remote_prefill_request_callback=self.get_remote_prefill_request_callback(), ) - return remote_prefill_params \ No newline at end of file + return remote_prefill_params + + def _update_remote_prefill_endpoints(self): + """Calculate remote prefill endpoints""" + if not self.remote_prefill_endpoint_map: + self.remote_prefill_endpoints = [] + + self.remote_prefill_endpoints = [ep for ep in self.remote_prefill_endpoint_map.keys() \ + if self.remote_prefill_endpoint_map[ep] == 1] + #TODO: let's clean up the map with the value = 0 + # and we should remote the nixl connections + logger.info(f"Remote prefill endpoints: {self.remote_prefill_endpoints}") + return self.remote_prefill_endpoints + + async def add_remote_prefill_ep(self, ep: str): + add_remote_nixl_metadata_url = f"{ep}/remote_nixl_metadata" + get_remote_nixl_metadata_url = f"{ep}/nixl_metadata" + metadata = NixlMetadataRequest( + metadata=self.nixl_metadata().metadata, + ) + async with aiohttp.ClientSession() as session: + async with session.post(add_remote_nixl_metadata_url, json=metadata.model_dump()) as resp: + if resp.status != 200: + raise ValueError(f"add local nixl metadata to remote failed with status: {resp.status}") + + async with session.get(get_remote_nixl_metadata_url) as response: + if response.status != 200: + raise ValueError(f"get remote nixl metadata failed with status: {response.status}") + response_data = await response.json() + metadata = 
NixlMetadataResponse(**response_data) + request = NixlMetadataRequest( + metadata=metadata.metadata + ) + await self.remote_nixl_metadata(request) + + async def add_remote_prefill_eps(self, request: RemotePrefillEpRequest): + if not request.endpoints or len(request.endpoints) == 0: + raise ValueError("Empty URL") + endpoints = [parsed for parsed in map(urlparse, request.endpoints) \ + if all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]] + endpoints = [f"{x.scheme}://{x.netloc}" for x in endpoints] + if len(endpoints) == 0: + raise ValueError(f"No valid endpoints: {request.endpoints}") + for ep in endpoints: + try: + await self.add_remote_prefill_ep(ep) + self.remote_prefill_endpoint_map[ep] = 1 + except ValueError as e: + logger.error(f"Failed to add remote prefill endpoint {ep}: {e}") + continue + self._update_remote_prefill_endpoints() + + async def remove_remote_prefill_eps(self, request: RemotePrefillEpRequest): + if not request.endpoints or len(request.endpoints) == 0: + logger.error("No remote prefill endpoint to be removed") + return + for ep in request.endpoints: + if ep in self.remote_prefill_endpoint_map: + self.remote_prefill_endpoint_map[ep] = 0 + self._update_remote_prefill_endpoints() From a3c9862d3f282b808510d5b9e8dc8cbc457b8a6a Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Sat, 12 Apr 2025 00:57:17 +0000 Subject: [PATCH 5/7] not require for decode to add nixl metadata from prefill --- vllm/entrypoints/openai/serving_remote_prefill.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/entrypoints/openai/serving_remote_prefill.py b/vllm/entrypoints/openai/serving_remote_prefill.py index c9a3ae2d5e3b..3e7630ec7fdd 100644 --- a/vllm/entrypoints/openai/serving_remote_prefill.py +++ b/vllm/entrypoints/openai/serving_remote_prefill.py @@ -159,7 +159,6 @@ def _update_remote_prefill_endpoints(self): async def add_remote_prefill_ep(self, ep: str): add_remote_nixl_metadata_url = f"{ep}/remote_nixl_metadata" - get_remote_nixl_metadata_url = f"{ep}/nixl_metadata" metadata = NixlMetadataRequest( metadata=self.nixl_metadata().metadata, ) @@ -168,16 +167,6 @@ async def add_remote_prefill_ep(self, ep: str): if resp.status != 200: raise ValueError(f"add local nixl metadata to remote failed with status: {resp.status}") - async with session.get(get_remote_nixl_metadata_url) as response: - if response.status != 200: - raise ValueError(f"get remote nixl metadata failed with status: {response.status}") - response_data = await response.json() - metadata = NixlMetadataResponse(**response_data) - request = NixlMetadataRequest( - metadata=metadata.metadata - ) - await self.remote_nixl_metadata(request) - async def add_remote_prefill_eps(self, request: RemotePrefillEpRequest): if not request.endpoints or len(request.endpoints) == 0: raise ValueError("Empty URL") From a7eca8fee7128fd8549eaea3ca9cc923b2aa8e94 Mon Sep 17 00:00:00 2001 From: Rain Jiang Date: Sat, 12 Apr 2025 04:28:27 +0000 Subject: [PATCH 6/7] skip sample on remote prefill worker --- vllm/worker/model_runner.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3e0fff335a78..85dfa673b320 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1631,6 +1631,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ModelInputForGPUWithSamplingMetadata) _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder + 
_fake_sample_output: Optional[SamplerOutput] = None
+
     def make_model_input_from_broadcasted_tensor_dict(
         self,
         tensor_dict: Dict[str, Any],
@@ -1822,11 +1824,16 @@ def execute_model(
         if model_input.async_callback is not None:
             model_input.async_callback()

-        # Sample the next token.
-        output: SamplerOutput = self.model.sample(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
+        # On the producer side of prefill/decode disaggregation the sampled
+        # tokens are not needed, so sampling is skipped here.
+        if self.need_skip_sampling() and self._fake_sample_output is not None:
+            output = self._fake_sample_output
+        else:
+            # Sample the next token.
+            output: SamplerOutput = self.model.sample(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time
                 and output is not None):
@@ -1859,6 +1866,13 @@ def execute_model(

             output.hidden_states = hidden_states

+        # Cache a real sampler output to reuse when sampling is skipped.
+        if (self._fake_sample_output is None
+            and output is not None
+            and self.need_skip_sampling()):
+
+            self._fake_sample_output = output
+
         return [output]

     def need_recv_kv(self, model_input, kv_caches) -> bool:
@@ -1889,6 +1903,16 @@ def need_recv_kv(self, model_input, kv_caches) -> bool:
         return self.vllm_config.kv_transfer_config.is_kv_consumer and (
             not is_profile_run) and is_prefill_run

+    def need_skip_sampling(self) -> bool:
+        """
+        Check whether the sampling step should be skipped on this worker.
+        """
+
+        if self.vllm_config.kv_transfer_config is None:
+            return False
+
+        return self.vllm_config.kv_transfer_config.get_from_extra_config("skip_sampling", False)
+
     def need_send_kv(self, model_input, kv_caches) -> bool:
         """Check if we need to send kv-cache to the other worker.
         We need to send KV when

From 7cf1c0ea6273c097e412acd9c43a2017cf31d75e Mon Sep 17 00:00:00 2001
From: Changqi Lu
Date: Mon, 21 Apr 2025 10:50:21 +0800
Subject: [PATCH 7/7] fix synchronize before transfer blocks

Signed-off-by: Changqi Lu
---
 vllm/worker/worker.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index ea1a17cf72ea..be9c995b788b 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -339,8 +339,7 @@ def _read_blocks(self, worker_input: WorkerInput) -> None:
             self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i])

     def _write_blocks(self, worker_input: WorkerInput) -> None:
-        if not self.is_driver_worker:
-            torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize
+        torch.cuda.synchronize() # make sure that the blocks are ready before the transfer
         for i, op_type in enumerate(worker_input.op_type):
             if op_type == MemoryOpType.WRITE:
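
Editor's note on usage: patches 3 and 4 define a small HTTP registration protocol between a decode worker and one or more remote prefill workers. The decode side is told about a prefill worker via its own /add_remote_prefill_eps route, which pushes the decode worker's NIXL metadata to the prefill worker's /remote_nixl_metadata route (after patch 5, metadata flows only in that direction); remote prefill requests are then round-robined to /remote_prefill. The sketch below is a minimal wiring example under assumed addresses and model name — a decode worker at http://localhost:8000 and a prefill worker at http://localhost:8090 are illustrative values, not defined by the patches.

# Sketch: register a prefill worker with a decode worker, then issue a normal
# chat completion against the decode worker. URLs and model name are
# illustrative assumptions.
import asyncio

import aiohttp

DECODE_URL = "http://localhost:8000"   # decode-side api_server (assumption)
PREFILL_URL = "http://localhost:8090"  # prefill-side api_server (assumption)


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # 1. Register the prefill worker with the decode worker. Internally
        #    this POSTs the decode worker's NIXL metadata to
        #    {PREFILL_URL}/remote_nixl_metadata.
        async with session.post(
                f"{DECODE_URL}/add_remote_prefill_eps",
                json={"endpoints": [PREFILL_URL]}) as resp:
            resp.raise_for_status()

        # 2. Optionally inspect the decode worker's serialized NIXL metadata.
        async with session.get(f"{DECODE_URL}/nixl_metadata") as resp:
            metadata = (await resp.json())["metadata"]
            print("nixl metadata prefix:", metadata[:80])

        # 3. A regular OpenAI-style request to the decode worker; prefill is
        #    forwarded to the registered endpoint via /remote_prefill.
        async with session.post(
                f"{DECODE_URL}/v1/chat/completions",
                json={
                    "model": "my-model",
                    "messages": [{"role": "user", "content": "Hello"}],
                }) as resp:
            print(await resp.json())


asyncio.run(main())

On the prefill side, patch 6 adds the matching knob: launching the producer with "skip_sampling": true in kv_connector_extra_config makes need_skip_sampling() return True, so the worker reuses a cached sampler output instead of sampling after each prefill.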