From 5c78497aec6c50fcd06c5a99237558e6f99a2a76 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Wed, 9 Jul 2025 15:08:25 -0700 Subject: [PATCH 01/18] eep basic Signed-off-by: Rui Qiao --- experimental/bench.sh | 11 + experimental/nvshmem.patch | 92 +++++++ experimental/serve_deepseek_v2.sh | 31 +++ experimental/test_scale.py | 61 +++++ experimental/test_stateless_pg.py | 93 +++++++ vllm/config.py | 14 + vllm/distributed/eplb/eplb_state.py | 256 +++++++++++++++--- vllm/distributed/eplb/rebalance_execute.py | 131 +++++++++ vllm/entrypoints/openai/api_server.py | 63 +++++ vllm/executor/uniproc_executor.py | 9 + vllm/model_executor/layers/fused_moe/layer.py | 39 ++- vllm/model_executor/models/deepseek_v2.py | 24 +- vllm/model_executor/models/interfaces.py | 7 + vllm/v1/engine/__init__.py | 10 + vllm/v1/engine/async_llm.py | 71 +++++ vllm/v1/engine/coordinator.py | 35 ++- vllm/v1/engine/core.py | 67 ++++- vllm/v1/engine/core_client.py | 212 ++++++++++++++- vllm/v1/engine/utils.py | 217 ++++++++++++++- vllm/v1/executor/multiproc_executor.py | 11 + vllm/v1/worker/gpu_model_runner.py | 34 ++- vllm/v1/worker/gpu_worker.py | 121 ++++++++- 22 files changed, 1544 insertions(+), 65 deletions(-) create mode 100644 experimental/bench.sh create mode 100644 experimental/nvshmem.patch create mode 100644 experimental/serve_deepseek_v2.sh create mode 100644 experimental/test_scale.py create mode 100644 experimental/test_stateless_pg.py diff --git a/experimental/bench.sh b/experimental/bench.sh new file mode 100644 index 00000000000..c81a2eab07a --- /dev/null +++ b/experimental/bench.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite-Chat" +HOST="localhost" +PORT=8006 + +vllm bench serve \ + --model $MODEL_NAME \ + --host $HOST \ + --port $PORT \ + --num-prompts 5 diff --git a/experimental/nvshmem.patch b/experimental/nvshmem.patch new file mode 100644 index 00000000000..5ebdaea58dd --- /dev/null +++ b/experimental/nvshmem.patch @@ -0,0 +1,92 @@ +From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001 +From: Yongji Wu +Date: Tue, 20 May 2025 13:41:12 -0700 +Subject: [PATCH] fix reinit issues due to states not cleaned up + +fix double free +--- + src/host/init/init.cu | 10 ++++++++++ + .../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++ + src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++ + 3 files changed, 30 insertions(+) + +diff --git a/src/host/init/init.cu b/src/host/init/init.cu +index b1c5dbf..1fecb4b 100644 +--- a/src/host/init/init.cu ++++ b/src/host/init/init.cu +@@ -43,6 +43,8 @@ + #include "internal/host/nvshmemi_types.h" + #include "internal/host/shared_memory.h" + #include "internal/host/nvshmemi_symmetric_heap.hpp" ++// eep-dev ++#include "internal/host/nvshmemi_mem_transport.hpp" + + extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d; + static std::map registered_device_states; +@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) { + /* Multi-init Multi-fini*/ + nvshmemi_state = NULL; + nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0; ++ ++ // eep-dev ++ nvshmemi_mem_p2p_transport::destroy_instance(); ++ nvshmemi_mem_remote_transport::destroy_instance(); ++ free(nvshmemi_default_session); ++ nvshmemi_default_session = nullptr; ++ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false; ++ + nvshmemi_is_device_state_ready = false; + } else + nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle); +diff --git 
a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp +index 2495844..e4f408a 100644 +--- a/src/include/internal/host/nvshmemi_mem_transport.hpp ++++ b/src/include/internal/host/nvshmemi_mem_transport.hpp +@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final { + return p2p_objref_; + } + } ++ // eep-dev ++ static void destroy_instance(void) { ++ if (p2p_objref_ != nullptr) { ++ delete p2p_objref_; ++ p2p_objref_ = nullptr; ++ } ++ } + + void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj); + +@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final { + } + } + ++ // eep-dev ++ static void destroy_instance(void) { ++ if (remote_objref_ != nullptr) { ++ delete remote_objref_; ++ remote_objref_ = nullptr; ++ } ++ } ++ + int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size); + /* On-demand registration and release of memory */ + int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx, +diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp +index a1fa748..788fa96 100644 +--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp ++++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp +@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi + // Discover the network for bootstrap, if not done previously. + // This code needs to be stateful to be able to be called multiple times by the caller + BOOTSTRAP_CHECK(bootstrap_net_init()); ++ // eep-dev ++ if (handle->pre_init_ops != nullptr) { ++ BOOTSTRAP_PTR_FREE(handle->pre_init_ops); ++ handle->pre_init_ops = nullptr; ++ } + if (handle->pre_init_ops == nullptr) { + BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1); + handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id; +-- +2.43.0 + diff --git a/experimental/serve_deepseek_v2.sh b/experimental/serve_deepseek_v2.sh new file mode 100644 index 00000000000..8f1d7cf6bb9 --- /dev/null +++ b/experimental/serve_deepseek_v2.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Serve DeepSeek V2 model with vLLM +# This script demonstrates how to serve the DeepSeek V2 model using vLLM's V1 engine + +# MODEL_NAME="gaunernst/DeepSeek-V2-Lite-Chat-FP8" +MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite-Chat" +HOST="0.0.0.0" +PORT=8006 + +DATA_PARALLEL_SIZE=3 +DATA_PARALLEL_SIZE_LOCAL=$DATA_PARALLEL_SIZE + +export VLLM_USE_V1=1 +export VLLM_ALL2ALL_BACKEND="pplx" +export VLLM_USE_DEEP_GEMM=1 + +# Launch the vLLM server +vllm serve $MODEL_NAME --trust-remote-code \ + --disable-log-requests \ + --host $HOST \ + --port $PORT \ + --tensor-parallel-size 1 \ + --enable-expert-parallel \ + --enable-eplb \ + --num-redundant-experts 32 \ + --enforce-eager \ + --data-parallel-backend ray \ + --data-parallel-size $DATA_PARALLEL_SIZE \ + --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \ + --data-parallel-start-rank 0 \ No newline at end of file diff --git a/experimental/test_scale.py b/experimental/test_scale.py new file mode 100644 index 00000000000..0c453159889 --- /dev/null +++ b/experimental/test_scale.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import json +import sys + +import requests + + +def test_scale(host, port, new_dp_size): + url = f"http://{host}:{port}/scale" + payload = {"new_data_parallel_size": new_dp_size} + headers = {"Content-Type": "application/json"} + + 
print(f"Sending scale request to {url}") + print(f"Payload: {json.dumps(payload, indent=2)}") + + try: + response = requests.post(url, + json=payload, + headers=headers, + timeout=300) + + print(f"Status Code: {response.status_code}") + print(f"Response: {response.text}") + + if response.status_code == 200: + print("Scale up/down request successful!") + return True + else: + print("Scale up/down request failed!") + return False + + except requests.exceptions.RequestException as e: + print(f"Request failed: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser( + description="Test scale up/down functionality") + parser.add_argument("--host", default="localhost", help="API server host") + parser.add_argument("--port", + type=int, + default=8006, + help="API server port") + parser.add_argument("--new_dp_size", + type=int, + default=2, + help="New data parallel size") + + args = parser.parse_args() + + success = test_scale(args.host, args.port, args.new_dp_size) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/experimental/test_stateless_pg.py b/experimental/test_stateless_pg.py new file mode 100644 index 00000000000..452fe1a8595 --- /dev/null +++ b/experimental/test_stateless_pg.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch.multiprocessing import spawn + +from vllm.distributed.utils import ( + stateless_destroy_torch_distributed_process_group, + stateless_init_torch_distributed_process_group) + + +def worker_process(rank: int, world_size: int, host: str, port1: int, + port2: int): + torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Create first process group with all workers + pg1 = stateless_init_torch_distributed_process_group(host=host, + port=port1, + rank=rank, + world_size=world_size, + backend="gloo") + + # Create second process group with worldsize-1 workers (excluding last rank) + pg2 = None + if rank < world_size - 1: + pg2 = stateless_init_torch_distributed_process_group( + host=host, + port=port2, + rank=rank, + world_size=world_size - 1, + backend="gloo") + + # Test both groups work simultaneously + tensor1 = torch.tensor([rank], dtype=torch.float32) + torch.distributed.all_reduce(tensor1, group=pg1) + expected1 = sum(range(world_size)) + assert tensor1.item( + ) == expected1, f"PG1 failed: got {tensor1.item()}, expected {expected1}" + print(f"Rank {rank}: PG1 all_reduce passed") + + if pg2 is not None: + tensor2 = torch.tensor([rank], dtype=torch.float32) + torch.distributed.all_reduce(tensor2, group=pg2) + expected2 = sum(range(world_size - 1)) + assert tensor2.item() == expected2, ( + f"PG2 failed: got {tensor2.item()}, expected {expected2}") + print(f"Rank {rank}: PG2 all_reduce passed") + + # Destroy first process group + stateless_destroy_torch_distributed_process_group(pg1) + print(f"Rank {rank}: PG1 destroyed") + + # Last rank exits here + if rank == world_size - 1: + print(f"Rank {rank}: Exiting") + return + + # Test second group still works after destroying + # first group and last rank exit + tensor3 = torch.tensor([rank * 10], dtype=torch.float32) + torch.distributed.all_reduce(tensor3, group=pg2) + expected3 = sum(i * 10 for i in range(world_size - 1)) + assert tensor3.item() == expected3, ( + f"PG2 after PG1 destroy failed: got {tensor3.item()}, " + f"expected {expected3}") + print(f"Rank {rank}: PG2 after PG1 destroy passed") + + # Clean up + if pg2 is not None: + 
stateless_destroy_torch_distributed_process_group(pg2) + print(f"Rank {rank}: PG2 destroyed") + + +def test_stateless_process_groups(): + assert not torch.distributed.is_initialized( + ), "torch.distributed should not be initialized" + + world_size = 4 + host = "127.0.0.1" + port1 = 29600 + port2 = 29601 + + print(f"Testing stateless process groups with world_size={world_size}") + + spawn(worker_process, + args=(world_size, host, port1, port2), + nprocs=world_size, + join=True) + + print("Test completed successfully!") + + +if __name__ == "__main__": + test_stateless_process_groups() diff --git a/vllm/config.py b/vllm/config.py index 508e09174cc..e0154d48955 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1954,6 +1954,20 @@ def has_unfinished_dp(dp_group: "ProcessGroup", aggregated_has_unfinished = bool(tensor.item()) return aggregated_has_unfinished + # eep-dev + @staticmethod + def sync_kv_cache_memory(dp_group: "ProcessGroup", + kv_cache_memory: int) -> None: + if kv_cache_memory == -1: + kv_cache_memory = torch.iinfo(torch.int64).max + tensor = torch.tensor([kv_cache_memory], + dtype=torch.int64, + device="cpu") + # we cannot use broadcast for stateless dp group since it depends + # on global rank + torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) + return tensor.item() + def compute_hash(self): """ Provide a hash that uniquely identifies all the configs diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 6b0a126ca9b..1a29b3b12a6 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -29,12 +29,18 @@ import time from collections.abc import Sequence from dataclasses import dataclass +# eep-dev +from typing import Optional, Union import torch -from torch.distributed import all_gather, all_reduce +# eep-dev +from torch.distributed import ProcessGroup, all_gather, all_reduce from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_ep_group, get_node_count +from vllm.distributed.parallel_state import (get_ep_group, get_node_count, + in_the_same_node_as) +# eep-dev +from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -172,6 +178,9 @@ def build( model: MixtureOfExperts, device: torch.device, parallel_config: ParallelConfig, + global_expert_load: Optional[torch.Tensor] = None, + old_global_expert_indices: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None, ) -> "EplbState": """ Build the initial EPLB state. 
@@ -185,8 +194,12 @@ def build( physical_to_logical_map_list, device=device, ) + # TODO(yongji): hard-wired to make sure the tensor does not get resized + MAX_PHYSICAL_EXPERT_FACTOR = 2 + max_slots_per_logical_expert = (model.num_logical_experts * + MAX_PHYSICAL_EXPERT_FACTOR) logical_to_physical_map = torch.full( - (model.num_logical_experts, model.num_redundant_experts + 1), + (model.num_logical_experts, max_slots_per_logical_expert), -1, device=device, ) @@ -235,11 +248,64 @@ def build( expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) + # eep-dev + if global_expert_load is not None: + ep_group = get_ep_group().device_group + assert global_expert_load.shape == (model.num_moe_layers, + model.num_logical_experts) + assert global_expert_load.dtype == torch.int64 + + num_replicas = model.num_physical_experts + num_groups = model.num_expert_groups + num_nodes = get_node_count() + num_gpus = ep_group.size() + + if num_gpus % num_nodes != 0: + num_nodes = 1 + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert max_physical_slots <= logical_to_physical_map.shape[-1] + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + (0, logical_to_physical_map.shape[-1] - max_physical_slots), + value=-1, + ) + physical_to_logical_map = new_physical_to_logical_map.to(device) + logical_to_physical_map.copy_(new_logical_to_physical_map) + logical_replica_count.copy_(new_logical_replica_count) + model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) + if global_expert_load is not None: + rearrange_expert_weights_inplace( + old_global_expert_indices, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + False, + rank_mapping, + ) + expert_rearrangement_step = 0 return cls( physical_to_logical_map, @@ -337,7 +403,10 @@ def step(self, def rearrange(self, model: MixtureOfExperts, - is_profile: bool = False) -> None: + is_profile: bool = False, + execute_shuffle: bool = True, + global_expert_load: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None) -> None: """ Rearrange the experts according to the current load. 
""" @@ -353,42 +422,82 @@ def rearrange(self, logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") - # This mapping is only used here, so we do not store it in the state - physical_expert_start = ep_rank * model.num_local_physical_experts - physical_expert_end = (physical_expert_start + - model.num_local_physical_experts) - # (num_moe_layers, num_local_physical_experts) - local_physical_to_logical_map = self.physical_to_logical_map[ - :, - physical_expert_start:physical_expert_end, - ] + # eep-dev + if global_expert_load is None: + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] - # Map the local physical expert load to global logical experts - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - model.num_moe_layers, - model.num_logical_experts, - dtype=self.expert_load_window.dtype, - device=self.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=local_physical_to_logical_map.unsqueeze(0).expand_as( - self.expert_load_window).long(), - src=self.expert_load_window, - ) + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) - # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = logical_expert_load_window.sum(dim=0) - all_reduce(global_expert_load_window, group=ep_group) + if not execute_shuffle: + metadata = torch.tensor( + [ + model.num_moe_layers, model.num_logical_experts, + self.physical_to_logical_map.shape[1] + ], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.broadcast(metadata, + group=get_ep_group().cpu_group, + group_src=0) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) + + if not execute_shuffle: + # (num_moe_layers, old_num_physical_experts) + old_global_expert_indices = self.physical_to_logical_map + torch.distributed.broadcast(old_global_expert_indices, + group=ep_group, + group_src=0) + return global_expert_load_window + else: + assert execute_shuffle + global_expert_load_window = global_expert_load # TODO(bowen): Treat differently for prefill and decode nodes num_replicas = model.num_physical_experts num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() + if rank_mapping is not None and len(rank_mapping) == ep_group.size(): + # eep-dev + # NOTE(yongji): scale down, we need to rebalance the experts on + # remaining GPUs, transfer the experts while we haven't shutdown + # the GPUs to be released. 
+ cpu_group = get_ep_group().cpu_group + num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) + num_gpus = sum(new_rank != -1 + for new_rank in rank_mapping.values()) + num_replicas = num_replicas // ep_group.size( + ) * num_gpus # handle num replicas change + else: + num_nodes = get_node_count() + num_gpus = ep_group.size() if num_gpus % num_nodes != 0: + # eep-dev + self.num_nodes = 1 logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" @@ -414,10 +523,24 @@ def rearrange(self, model.expert_weights, ep_group, is_profile, + rank_mapping, ) if not is_profile: - self.physical_to_logical_map.copy_(new_physical_to_logical_map) + if self.physical_to_logical_map.shape[ + 1] != new_physical_to_logical_map.shape[1]: + self.physical_to_logical_map = new_physical_to_logical_map.to( + self.physical_to_logical_map.device) + else: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert max_physical_slots <= self.logical_to_physical_map.shape[-1] + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + (0, + self.logical_to_physical_map.shape[-1] - max_physical_slots), + value=-1, + ) self.logical_to_physical_map.copy_(new_logical_to_physical_map) self.logical_replica_count.copy_(new_logical_replica_count) @@ -430,3 +553,70 @@ def rearrange(self, " (profile) " if is_profile else " ", time_end - time_start, ) + + @staticmethod + def recv_state() -> tuple[torch.Tensor, torch.Tensor]: + """ + Receive the expert load and old placement from the master rank. + """ + ep_group = get_ep_group() + metadata = torch.empty(3, dtype=torch.int32, device="cpu") + torch.distributed.broadcast(metadata, + group=ep_group.cpu_group, + group_src=0) + num_moe_layers, num_logical_experts, num_old_physical_experts = ( + metadata.tolist()) + global_expert_load = torch.zeros( + (num_moe_layers, num_logical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + all_reduce(global_expert_load, group=ep_group.device_group) + old_global_expert_indices = torch.empty( + (num_moe_layers, num_old_physical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + torch.distributed.broadcast(old_global_expert_indices, + group=ep_group.device_group, + group_src=0) + + return global_expert_load, old_global_expert_indices + + +# eep-dev +def _node_count_with_rank_mapping( + pg: Union[ProcessGroup, StatelessProcessGroup], + rank_mapping: dict[int, int], +) -> int: + if isinstance(pg, ProcessGroup): + world_size = torch.distributed.get_world_size(group=pg) + else: + world_size = pg.world_size + + if world_size == 1: + return 1 + + # Build node assignment map + node_assignment = [0] * world_size # rank -> node_id + next_node_id = 0 + + for current_rank in range(world_size): + if node_assignment[current_rank] != 0: + continue # Already assigned to a node + + assert current_rank in rank_mapping + if rank_mapping[current_rank] == -1: + continue # Pending shutdown + + # Assign current rank to a new node + next_node_id += 1 + node_assignment[current_rank] = next_node_id + + # Find all ranks on the same node as current_rank + same_node_flags = in_the_same_node_as(pg, current_rank) + for other_rank, is_same_node in enumerate(same_node_flags): + if is_same_node and node_assignment[other_rank] == 0: + node_assignment[other_rank] = next_node_id + + return next_node_id diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py 
index 2ef8587b559..bffa4f31d6a 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -8,6 +8,8 @@ from collections.abc import Iterable, MutableSequence, Sequence from functools import partial +# eep-dev +from typing import Optional import torch from torch.distributed import (P2POp, ProcessGroup, all_gather, @@ -127,6 +129,9 @@ def shuffle_layer( dst_global = local2global(dst) if is_received_locally[dst]: continue + # eep-dev + if old_indices[src_global] == -1 or new_indices[dst_global] == -1: + continue if old_indices[src_global] == new_indices[dst_global]: is_received_locally[dst] = True for weight, buffer in zip(expert_weights, @@ -139,6 +144,9 @@ def shuffle_layer( experts_send_loc: dict[int, int] = {} for src in range(num_local_experts): expert = old_indices[local2global(src)] + # eep-dev + if expert == -1: + continue if expert in experts_send_loc: continue experts_send_loc[expert] = src @@ -181,6 +189,9 @@ def shuffle_layer( if is_received_locally[dst]: continue expert = new_indices[local2global(dst)] + # eep-dev + if expert == -1: + continue if expert in experts_recv_loc: continue experts_recv_loc[expert] = dst @@ -227,6 +238,9 @@ def shuffle_layer( weight[dst].copy_(buffer[dst]) else: expert = new_indices[local2global(dst)] + # eep-dev + if expert == -1: + continue src = experts_recv_loc[expert] for weight, buffer in zip(expert_weights, expert_weights_buffer): weight[dst].copy_(buffer[src]) @@ -238,6 +252,8 @@ def rearrange_expert_weights_inplace( expert_weights: Sequence[Iterable[torch.Tensor]], ep_group: ProcessGroup, is_profile: bool = False, + # eep-dev + rank_mapping: Optional[dict[int, int]] = None, ) -> None: """ Rearranges the expert weights in place according to the new expert indices. @@ -257,6 +273,27 @@ def rearrange_expert_weights_inplace( This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. """ + # eep-dev + if rank_mapping is not None: + if len(rank_mapping) == ep_group.size(): + # scale down + new_global_expert_indices = \ + _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices, + rank_mapping, + ) + else: + # scale up + old_global_expert_indices = \ + _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices, + rank_mapping, + ep_group.size(), + ) + + assert old_global_expert_indices.shape[ + 1] == new_global_expert_indices.shape[1] + num_moe_layers, num_physical_experts = old_global_expert_indices.shape assert len(expert_weights) == num_moe_layers @@ -304,4 +341,98 @@ def rearrange_expert_weights_inplace( ) +# eep-dev +def _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices: torch.Tensor, + rank_mapping: dict[int, int], + new_ep_size: int, +) -> torch.Tensor: + """ + Map the old global expert indices to the new global expert indices. + + Args: + old_global_expert_indices: + Shape (num_layers, old_ep_size * num_local_physical_experts). + rank_mapping: Mapping from old rank to new rank. + new_ep_size: New expert parallelism size. + + Returns: + Mapped expert indices with shape + (num_layers, new_ep_size * num_local_physical_experts). 
+ """ + num_layers, old_num_physical_experts = old_global_expert_indices.shape + + if not rank_mapping: + # If no rank mapping, return the original tensor + return old_global_expert_indices + + # Get sizes from parameters and rank_mapping + old_ep_size = len(rank_mapping) + num_local_physical_experts = old_num_physical_experts // old_ep_size + new_num_physical_experts = new_ep_size * num_local_physical_experts + + # Create mapped tensor with new shape, initialized to -1 + mapped_expert_indices = torch.full( + (num_layers, new_num_physical_experts), + fill_value=-1, + dtype=old_global_expert_indices.dtype, + device=old_global_expert_indices.device, + ) + + # Handle rank mapping (scale up/down with rank changes) + for old_rank in range(old_ep_size): + new_rank = rank_mapping.get(old_rank) + if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size: + # This old rank exists in the new world + old_start_idx = old_rank * num_local_physical_experts + old_end_idx = (old_rank + 1) * num_local_physical_experts + new_start_idx = new_rank * num_local_physical_experts + new_end_idx = (new_rank + 1) * num_local_physical_experts + + mapped_expert_indices[:, new_start_idx:new_end_idx] = \ + old_global_expert_indices[:, old_start_idx:old_end_idx] + # If new_rank is None or >= new_ep_size, the experts remain -1 + # (scale down case) + + return mapped_expert_indices + + +# eep-dev +def _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices: torch.Tensor, + rank_mapping: dict[int, int], +) -> torch.Tensor: + num_layers, new_num_physical_experts = new_global_expert_indices.shape + + if not rank_mapping: + # If no rank mapping, return the original tensor + return new_global_expert_indices + + # Get sizes from parameters and rank_mapping + old_ep_size = len(rank_mapping) + new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values()) + num_local_physical_experts = new_num_physical_experts // new_ep_size + old_num_physical_experts = old_ep_size * num_local_physical_experts + + mapped_expert_indices = torch.full( + (num_layers, old_num_physical_experts), + fill_value=-1, + dtype=new_global_expert_indices.dtype, + device=new_global_expert_indices.device, + ) + + for old_rank in range(old_ep_size): + new_rank = rank_mapping[old_rank] + if new_rank >= 0 and new_rank < new_ep_size: + old_start_idx = old_rank * num_local_physical_experts + old_end_idx = (old_rank + 1) * num_local_physical_experts + new_start_idx = new_rank * num_local_physical_experts + new_end_idx = (new_rank + 1) * num_local_physical_experts + + mapped_expert_indices[:, old_start_idx:old_end_idx] = \ + new_global_expert_indices[:, new_start_idx:new_end_idx] + + return mapped_expert_indices + + __all__ = ["rearrange_expert_weights_inplace"] diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2f8b31c8a7b..95b8586a39f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -630,6 +630,11 @@ async def create_chat_completion(request: ChatCompletionRequest, return base(raw_request).create_error_response( message="The model does not support Chat Completions API") + if raw_request.app.state.scaling: + raise HTTPException( + status_code=503, + detail="The model is currently scaling. 
Please try again later.") + generator = await handler.create_chat_completion(request, raw_request) if isinstance(generator, ErrorResponse): @@ -668,6 +673,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Completions API") + if raw_request.app.state.scaling: + raise HTTPException( + status_code=503, + detail="The model is currently scaling. Please try again later.") + try: generator = await handler.create_completion(request, raw_request) except OverflowError as e: @@ -704,6 +714,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Embeddings API") + if raw_request.app.state.scaling: + raise HTTPException( + status_code=503, + detail="The model is currently scaling. Please try again later.") + generator = await handler.create_embedding(request, raw_request) if isinstance(generator, ErrorResponse): @@ -1032,6 +1047,53 @@ async def is_sleeping(raw_request: Request): return JSONResponse(content={"is_sleeping": is_sleeping}) +# eep-dev +@router.post("/scale", dependencies=[Depends(validate_json_request)]) +async def scale(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, + detail="Invalid JSON format") from e # noqa: B904 + + new_data_parallel_size = body.get("new_data_parallel_size") + drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes + + if new_data_parallel_size is None: + raise HTTPException(status_code=400, + detail="new_data_parallel_size is required") + + if not isinstance(new_data_parallel_size, + int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, + detail="new_data_parallel_size must be a positive integer") + + if not isinstance(drain_timeout, int) or drain_timeout <= 0: + raise HTTPException(status_code=400, + detail="drain_timeout must be a positive integer") + + # Set scaling flag to prevent new requests + raw_request.app.state.scaling = True + client = engine_client(raw_request) + try: + await client.scale(new_data_parallel_size, drain_timeout) + return JSONResponse({ + "message": + f"Scaled up to {new_data_parallel_size} " + "data parallel engines", + }) + except TimeoutError as e: + raise HTTPException( + status_code=408, + detail="Scale up failed due to request drain timeout " + f"after {drain_timeout} seconds") from e + except Exception as e: + raise HTTPException(status_code=500, detail="Scale up failed") from e + finally: + raw_request.app.state.scaling = False + + @router.post("/invocations", dependencies=[Depends(validate_json_request)], responses={ @@ -1586,6 +1648,7 @@ async def init_app_state( state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 + state.scaling = False def create_server_socket(addr: tuple[str, int]) -> socket.socket: diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7ebeb4a2255..100f14b57a9 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, run_method) +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -62,6 +63,14 @@ def check_health(self) -> None: # 
it's running. return + # eep-dev + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self.driver_worker.reinitialize_distributed(reconfig_request) + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + return + UniProcExecutorAsync = UniProcExecutor diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36ac75a8df4..13b6a78fe20 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -239,7 +239,8 @@ def select_gemm_impl( moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: - assert self.fused_experts == fused_experts + # eep-dev + # assert self.fused_experts == fused_experts if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): @@ -348,8 +349,10 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) return self.forward( x=x, @@ -366,7 +369,12 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def forward_cuda( self, @@ -385,6 +393,10 @@ def forward_cuda( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( @@ -398,7 +410,12 @@ def forward_cuda( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts( @@ -739,7 +756,8 @@ def __init__( if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import ( Fp8MoEMethod) - if not isinstance(quant_method, Fp8MoEMethod): + if not isinstance(quant_method, + (Fp8MoEMethod, UnquantizedFusedMoEMethod)): # TODO: Add support for additional quantization methods. 
# The implementation for other quantization methods does not # contain essential differences, but the current quant API @@ -823,6 +841,15 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + # eep-dev + def update_expert_map(self): + # ep_size and ep_rank should already be updated + with self.expert_map.device: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2fa1294b79b..5093a0c6487 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -770,6 +770,25 @@ def set_eplb_state( logical_replica_count=logical_replica_count, ) + # eep-dev + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, DeepseekV2MoE): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -925,9 +944,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): + if (hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx+i}."): diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 3863d8454bf..48bd3a70585 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -508,6 +508,13 @@ def set_eplb_state( """ ... + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + ... + def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: return isinstance(model, MixtureOfExperts) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 921ccd708cd..78e765615f1 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -177,3 +177,13 @@ class EngineCoreRequestType(enum.Enum): UTILITY = b'\x03' # Sentinel used within EngineCoreProc. 
EXECUTOR_FAILED = b'\x04' + + +# eep-dev +class ReconfigureDistributedRequest(msgspec.Struct): + new_data_parallel_size: int + new_data_parallel_rank: int # -1 means keep current rank + new_data_parallel_rank_local: int # -1 means keep current local rank + # for NCCL/GLOO initialization + new_data_parallel_master_ip: str + new_data_parallel_master_port: int \ No newline at end of file diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaa..51ef348ff7c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import time from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Any, Optional, Union @@ -132,6 +133,7 @@ def __init__( for stat_logger in self.stat_loggers[0]: stat_logger.log_engine_initialized() self.output_handler: Optional[asyncio.Task] = None + self.scaling = False try: # Start output handler eagerly if we are in the asyncio eventloop. asyncio.get_running_loop() @@ -608,6 +610,75 @@ async def collective_rpc(self, return await self.engine_core.collective_rpc_async( method, timeout, args, kwargs) + # eep-dev + async def wait_for_requests_to_drain(self, drain_timeout: int = 300): + """Wait for all requests to be drained.""" + start_time = time.time() + while time.time() - start_time < drain_timeout: + if not self.engine_core.dp_engines_running(): + logger.info("Engines are idle, requests have been drained") + return + + logger.info( + "Engines are still running, waiting for requests to drain...") + await asyncio.sleep(1) # Wait 1 second before checking again + + raise TimeoutError(f"Timeout reached after {drain_timeout} seconds " + "waiting for requests to drain.") + + async def scale(self, + new_data_parallel_size: int, + drain_timeout: int = 300): + """ + Scale up or down the data parallel size by adding or removing + engine cores. 
+ Args: + new_data_parallel_size: The new number of data parallel workers + drain_timeout: + Maximum time to wait for requests to drain (seconds) + """ + from vllm.v1.engine.core_client import RayDPClient + + if not isinstance(self.engine_core, RayDPClient): + raise NotImplementedError( + "Scale up/down only supported by RayDPClient") + + self.scaling = True + old_data_parallel_size = \ + self.vllm_config.parallel_config.data_parallel_size + try: + logger.info( + "Waiting for requests to drain before " + "scaling up to %s engines...", new_data_parallel_size) + await self.wait_for_requests_to_drain(drain_timeout) + logger.info( + "Requests have been drained, proceeding with scale " + "to %s engines", new_data_parallel_size) + if new_data_parallel_size > old_data_parallel_size: + await self.engine_core.scale_up(new_data_parallel_size) + else: + await self.engine_core.scale_down(new_data_parallel_size) + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + + # recreate stat loggers + if new_data_parallel_size > old_data_parallel_size: + stat_loggers: list[ + list[StatLoggerBase]] = setup_default_loggers( + vllm_config=self.vllm_config, + log_stats=self.log_stats, + engine_num=new_data_parallel_size, + custom_stat_loggers=None, + ) + num_new_engines = len(stat_loggers) - len(self.stat_loggers) + self.stat_loggers.extend(stat_loggers[-num_new_engines:]) + else: + for _ in range(old_data_parallel_size - + new_data_parallel_size): + self.stat_loggers.pop() + finally: + self.scaling = False + @property def is_running(self) -> bool: # Is None before the loop is started. diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index b3e7a2e85b8..4af46902a8b 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -200,11 +200,44 @@ def process_input_socket(self, front_publish_address: str, # Ignore subscription messages. continue + # eep-dev + decoded = msgspec.msgpack.decode(buffer) + if isinstance(decoded, list) and len( + decoded) == 2 and decoded[0] == "SCALE_UP": + # Handle scale up notification + new_engine_count = decoded[1] + current_count = len(self.engines) + if new_engine_count > current_count: + for _ in range(new_engine_count - current_count): + self.engines.append(EngineState()) + # NOTE(yongji): handle the case + # where newly started engines have current_wave = 0 + # if existing engines just finished a wave + # and engine_running isn't updated yet at + # CoordinatorProc requests routed to newly started + # engines may not wake up existing engines, as long + # as 0 < request.wave < existing engines' + # current_wave + # we note that 0 is the wave number for the new + # engine + self.engines_running = False + logger.info( + "DPCoordinator scaled up from %s to %s " + "engines", + current_count, new_engine_count) + else: + self.engines = self.engines[:new_engine_count] + logger.info( + "DPCoordinator scaled down from %s to %s " + "engines", + current_count, new_engine_count) + continue # Skip normal engine notification processing + # We received a message on the front-end XPUB socket, # from an API server sending a new request while the # engines are paused, so that we can wake the other # engines. 
- engine_to_exclude, wave = msgspec.msgpack.decode(buffer) + engine_to_exclude, wave = decoded if not self.engines_running: if wave < self.current_wave: # If the wave number is stale, ensure the message diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11..a2489648f1d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -31,8 +31,10 @@ from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler +# eep-dev from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, UtilityOutput) + EngineCoreRequestType, + ReconfigureDistributedRequest, UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor @@ -77,6 +79,9 @@ def __init__(self, self.model_executor.register_failure_callback( executor_fail_callback) + # eep-dev + self.available_gpu_memory_for_kv_cache = -1 + # Setup KV Caches and update CacheConfig after profiling. num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ self._initialize_kv_caches(vllm_config) @@ -137,9 +142,19 @@ def _initialize_kv_caches( # Get all kv cache needed by the model kv_cache_specs = self.model_executor.get_kv_cache_specs() - # Profiles the peak memory usage of the model to determine how much - # memory can be allocated for kv cache. - available_gpu_memory = self.model_executor.determine_available_memory() + # eep-dev + if os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1": + dp_group = getattr(self, "dp_group", None) + assert dp_group is not None + kv_cache_memory = ParallelConfig.sync_kv_cache_memory( + dp_group, -1) + available_gpu_memory = [kv_cache_memory] * len(kv_cache_specs) + else: + # Profiles the peak memory usage of the model to determine how much + # memory can be allocated for kv cache. 
+ available_gpu_memory = ( + self.model_executor.determine_available_memory()) + self.available_gpu_memory_for_kv_cache = available_gpu_memory[0] assert len(kv_cache_specs) == len(available_gpu_memory) # Get the kv cache tensor size @@ -853,7 +868,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig): local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 - assert 0 <= local_dp_rank <= dp_rank < dp_size + # assert 0 <= local_dp_rank <= dp_rank < dp_size if vllm_config.kv_transfer_config is not None: # modify the engine_id and append the local_dp_rank to it to ensure @@ -977,6 +992,48 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + # eep-dev + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + stateless_destroy_torch_distributed_process_group(self.dp_group) + self.shutdown() + + parallel_config = self.vllm_config.parallel_config + old_dp_size = parallel_config.data_parallel_size + parallel_config.data_parallel_size = \ + reconfig_request.new_data_parallel_size + if reconfig_request.new_data_parallel_rank != -1: + parallel_config.data_parallel_rank = \ + reconfig_request.new_data_parallel_rank + # local rank specifies device visibility, it should not be changed + assert reconfig_request.new_data_parallel_rank_local == -1 + parallel_config.data_parallel_master_ip = \ + reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port + if reconfig_request.new_data_parallel_rank != -2: + self.dp_rank = parallel_config.data_parallel_rank + self.dp_group = parallel_config.stateless_init_dp_group() + reconfig_request.new_data_parallel_master_port = \ + parallel_config.data_parallel_master_port + + self.model_executor.reinitialize_distributed(reconfig_request) + if reconfig_request.new_data_parallel_size > old_dp_size: + assert self.available_gpu_memory_for_kv_cache > 0 + # broadcast KV cache available memory for _initialize_kv_caches + # on new EngineCore + ParallelConfig.sync_kv_cache_memory( + self.dp_group, self.available_gpu_memory_for_kv_cache) + # NOTE(yongji): newly joined workers require dummy_run even + # CUDA graph is not used + self.model_executor.collective_rpc("compile_or_warm_up_model") + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) + else: + logger.info("Distributed environment reinitialized for DP rank %s", + self.dp_rank) + class DPEngineCoreActor(DPEngineCoreProc): """ diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index dafaa15f777..365cb372f1d 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,12 +23,14 @@ from vllm.lora.request import LoRARequest from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, UtilityOutput) + EngineCoreRequestType, + ReconfigureDistributedRequest, UtilityOutput) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager, launch_core_engines) +from vllm.v1.engine.utils import (CoreEngine, CoreEngineActorManager, + CoreEngineProcManager, EngineZmqAddresses, + 
launch_core_engines) from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr @@ -910,13 +912,29 @@ async def run_engine_stats_update_task(): events = await poller.poll() if not self.engines_running and len(events) == 2 or ( events[0][0] == first_req_rcv_socket): - # Send a message to notify the coordinator that + # eep-dev + # Check if this is a regular request notification or + # scale up notification + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + + # Check if this is a scale up notification + if len(buf) > 4 and buf[4:].startswith(b"SCALE_UP"): + # Extract new engine count from the first 4 bytes + new_engine_count = int.from_bytes( + buf[:4], "little") + # Send scale up notification to coordinator + scale_msg = msgspec.msgpack.encode( + ("SCALE_UP", new_engine_count)) + await socket.send(scale_msg) + continue + + # Regular request notification - send a message to + # notify the coordinator that # we're sending a request while the engines are # paused, so that it can wake the others up # (to run dummy EP loop). self.engines_running = True - buf = first_req_rcv_socket.recv( - flags=zmq.NOBLOCK).result() target_eng_index = int.from_bytes(buf, "little") msg = msgspec.msgpack.encode( (target_eng_index, self.current_wave)) @@ -1047,3 +1065,185 @@ async def _abort_requests(self, request_ids: list[str], engine: EngineIdentity) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) + + +class RayDPClient(DPAsyncMPClient): + """ + Ray-based client for multi-proc, multi-engine (data parallel) + EngineCore. + """ + + def __init__( + self, + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, + ): + super().__init__(vllm_config, executor_class, log_stats, + client_addresses, client_index) + + def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, + local_start_index: int, input_address: str, + output_address: str, + executor_class: type[Executor], log_stats: bool): + """Self-contained client mode, launch engine and coordinator process + as needed.""" + + parallel_config = vllm_config.parallel_config + assert parallel_config.data_parallel_rank == 0 + assert local_start_index == 0 + + addresses = EngineZmqAddresses( + inputs=[input_address], + outputs=[output_address], + ) + + if len(self.core_engines) > 1: + coordinator = DPCoordinator(parallel_config) + self.resources.coordinator = coordinator + addresses.coordinator_input, addresses.coordinator_output = ( + coordinator.get_engine_socket_addresses()) + + # Start all engines. 
+ self.resources.engine_manager = (CoreEngineActorManager( + vllm_config=vllm_config, + addresses=addresses, + executor_class=executor_class, + log_stats=log_stats)) + + async def _send_reconfig_message( + self, reconfig_request: ReconfigureDistributedRequest, + engine: CoreEngine) -> asyncio.Future: + """Send reconfiguration message and return the result future without + waiting for completion.""" + call_id = uuid.uuid1().int >> 64 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( + (self.client_index, call_id, "reinitialize_distributed", + (reconfig_request, )))) + await self._send_input_message(message, engine, reconfig_request) + self._ensure_output_queue_task() + return future + + async def scale_up(self, new_data_parallel_size: int) -> None: + """Scale up the data parallel size by creating new engine cores + and reconfiguring existing ones.""" + current_dp_size = len(self.core_engines) + + if new_data_parallel_size <= current_dp_size: + return + + # Phase 1: Send reconfigure messages to all existing engines and wait + # for them to be sent + reconfig_futures = [] + # one for stateless group in EngineCore, one for worker's distributed + # world group + self.vllm_config.parallel_config.data_parallel_master_port += 2 + for engine in self.core_engines: + reconfig_request = ReconfigureDistributedRequest( + new_data_parallel_size=new_data_parallel_size, + new_data_parallel_rank=-1, # Keep original rank + new_data_parallel_rank_local=-1, # Keep original local rank + new_data_parallel_master_ip=self.vllm_config.parallel_config. + data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config. + data_parallel_master_port) + future = await self._send_reconfig_message(reconfig_request, + engine) + reconfig_futures.append(future) + + logger.info("All reconfigure messages sent, starting engine creation") + + # Phase 2: Create new engines now that reconfig messages have been sent + # self.resources.engine_manager is guaranteed to be + # CoreEngineActorManager for RayDPClient + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + self.resources.engine_manager.scale_up(self.vllm_config, + new_data_parallel_size) + + # Create new CoreEngine objects for the new engines + new_engine_identities = set() + for i in range(current_dp_size, new_data_parallel_size): + # TODO(yongji): check if the engine is local + new_engine = CoreEngine(index=i, local=False) + self.core_engines.append(new_engine) + new_engine_identities.add(new_engine.identity) + + # Wait for ready messages from new engines on the input socket + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while new_engine_identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError( + "Timed out waiting for new engines to send initial " + "message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + new_engine_identities.discard(identity) + + # Phase 3: Wait for all existing engines to complete reconfiguration + logger.info("Waiting for existing engines to complete reconfiguration") + await asyncio.gather(*reconfig_futures) + + # Notify coordinator about scale up through existing + # stats_update_task connection + self._ensure_stats_update_task() + scale_up_marker = (new_data_parallel_size).to_bytes( + 4, "little") + b"SCALE_UP" + await self.first_req_send_socket.send(scale_up_marker) + + # Update the parallel config + 
self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + logger.info( + "[Elastic EP] Scale up completed, new data parallel size: %s", + new_data_parallel_size) + + async def scale_down(self, new_data_parallel_size: int) -> None: + """Scale down the data parallel size by shutting down and + reconfiguring existing engine cores.""" + current_dp_size = len(self.core_engines) + + if new_data_parallel_size >= current_dp_size: + return + + # one for stateless group in EngineCore, one for worker's distributed + # world group + self.vllm_config.parallel_config.data_parallel_master_port += 2 + + reconfig_futures = [] + for old_dp_rank, engine in enumerate(self.core_engines): + reconfig_request = ReconfigureDistributedRequest( + new_data_parallel_size=new_data_parallel_size, + new_data_parallel_rank=-1, # Keep original rank + new_data_parallel_rank_local=-1, # Keep original local rank + new_data_parallel_master_ip=self.vllm_config.parallel_config. + data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config. + data_parallel_master_port) + if old_dp_rank >= new_data_parallel_size: + reconfig_request.new_data_parallel_rank = -2 + future = await self._send_reconfig_message(reconfig_request, + engine) + reconfig_futures.append(future) + + for _ in range(new_data_parallel_size, current_dp_size): + self.core_engines.pop() + + await asyncio.gather(*reconfig_futures) + + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + self.resources.engine_manager.scale_down(current_dp_size, + new_data_parallel_size) + + self._ensure_stats_update_task() + scale_up_marker = (new_data_parallel_size).to_bytes( + 4, "little") + b"SCALE_UP" + await self.first_req_send_socket.send(scale_up_marker) + + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + logger.info( + "[Elastic EP] Scale down completed, new data parallel size: %s", + new_data_parallel_size) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index ae104bd6eb9..4dd8a35aa4c 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -174,16 +174,22 @@ def __init__( self.local_engine_actors: list[ray.ActorHandle] = [] self.remote_engine_actors: list[ray.ActorHandle] = [] + + env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") + self.env_vars_dict = { + name: os.environ[name] + for name in env_vars_list if name in os.environ + } + runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) + + # eep-dev + self.addresses = addresses + self.executor_class = executor_class + self.log_stats = log_stats dp_size = vllm_config.parallel_config.data_parallel_size local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size - env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor") - env_vars_dict = { - name: os.environ[name] - for name in env_vars_set if name in os.environ - } - runtime_env = RuntimeEnv(env_vars=env_vars_dict) if ray.is_initialized(): logger.info( @@ -208,13 +214,14 @@ def __init__( assert len(placement_groups) == dp_size, ( "Number of placement groups must match data parallel size") + self.placement_group_is_local = [] refs = [] for index in range(dp_size): local_index = local_dp_ranks[index] dp_vllm_config = copy.deepcopy(vllm_config) pg = placement_groups[index] dp_vllm_config.parallel_config.placement_group = pg - local_client = index < local_engine_count + on_head_node = index < local_engine_count actor = 
ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, @@ -223,14 +230,15 @@ def __init__( runtime_env=runtime_env).remote(vllm_config=dp_vllm_config, executor_class=executor_class, log_stats=log_stats, - local_client=local_client, + on_head_node=on_head_node, addresses=addresses, dp_rank=index, local_dp_rank=local_index) - if local_client: + if on_head_node: self.local_engine_actors.append(actor) else: self.remote_engine_actors.append(actor) + self.placement_group_is_local.append(on_head_node) refs.append(actor.wait_for_init.remote()) ray.get(refs) @@ -254,6 +262,7 @@ def create_dp_placement_groups( local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local + nodes = list_nodes() nodes = sorted(list_nodes(), key=lambda node: node.node_ip != dp_master_ip) assert nodes[0].node_ip == dp_master_ip, ( @@ -305,6 +314,196 @@ def create_dp_placement_groups( local_dp_ranks.append(i) return placement_groups, local_dp_ranks + @staticmethod + def scale_up_create_dp_placement_groups( + old_vllm_config: VllmConfig, new_data_parallel_size: int + ) -> tuple[list["PlacementGroup"], list[int]]: + import ray + from ray._private.state import (available_resources_per_node, + total_resources_per_node) + from ray.util.state import list_nodes + + old_dp_size = old_vllm_config.parallel_config.data_parallel_size + num_new_engines = new_data_parallel_size - old_dp_size + + if num_new_engines <= 0: + return [], [] + + dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip + world_size = old_vllm_config.parallel_config.world_size + + nodes = list_nodes() + nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip) + assert nodes[0].node_ip == dp_master_ip, ( + "The first node must be the head node") + assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( + "There can only be one head node") + + available_resources = available_resources_per_node() + total_resources = total_resources_per_node() + + placement_groups = [] + local_dp_ranks = [] + engines_created = 0 + + for node in nodes: + if engines_created >= num_new_engines: + break + + node_ip = node.node_ip + node_id = node.node_id + available_gpus = int(available_resources[node_id]["GPU"]) + + # Get total GPUs on this node from the node's resources + # Ray stores node resources with node ID as key + total_gpus = int(total_resources[node_id]["GPU"]) + + # Calculate used GPUs and used engines on this node + used_gpus = max(0, total_gpus - available_gpus) + used_engines_on_node = used_gpus // world_size + + # Calculate how many new engines this node can accommodate + available_engine_count = available_gpus // world_size + + # Create placement groups for new engines on this node + for i in range(available_engine_count): + if engines_created >= num_new_engines: + break + + rank = old_dp_size + engines_created + + # Create bundles with node constraint for master node + if node_ip == dp_master_ip: + bundles = [{ + "GPU": 1.0, + "node:" + dp_master_ip: 0.001 + }] * world_size + [{ + "CPU": 1.0 + }] + else: + bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + + pg = ray.util.placement_group( + name=f"dp_rank_{rank}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + + # Local rank starts from the number of engines already used + # on this node + local_rank = used_engines_on_node + i + local_dp_ranks.append(local_rank) + engines_created += 1 + + return placement_groups, local_dp_ranks + + def scale_up(self, old_vllm_config: 
VllmConfig, + new_data_parallel_size: int) -> None: + import copy + + import ray + from ray.runtime_env import RuntimeEnv + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + + from vllm.v1.engine.core import DPEngineCoreActor + + old_dp_size = len(self.local_engine_actors) + len( + self.remote_engine_actors) + + if new_data_parallel_size <= old_dp_size: + return + + placement_groups, local_dp_ranks = \ + self.scale_up_create_dp_placement_groups( + old_vllm_config, new_data_parallel_size) + + world_size = old_vllm_config.parallel_config.world_size + dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip + new_local_engines = 0 + + runtime_env = RuntimeEnv(env_vars=self.env_vars_dict + | {"VLLM_EEP_RECONFIGURE_LAUNCH": "1"}) + for i, (pg, + local_rank) in enumerate(zip(placement_groups, + local_dp_ranks)): + rank = old_dp_size + i + dp_vllm_config = copy.deepcopy(old_vllm_config) + dp_vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + dp_vllm_config.parallel_config.placement_group = pg + + # Check if this placement group is on the head node + on_head_node = any( + bundle.get("node:" + dp_master_ip, 0) > 0 + for bundle in pg.bundle_specs) + + if on_head_node: + new_local_engines += 1 + # Update data_parallel_size_local + dp_vllm_config.parallel_config.data_parallel_size_local = ( + old_vllm_config.parallel_config.data_parallel_size_local + + new_local_engines) + + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env).remote( + vllm_config=dp_vllm_config, + executor_class=self.executor_class, + log_stats=self.log_stats, + on_head_node=on_head_node, + addresses=self.addresses, + dp_rank=rank, + local_dp_rank=local_rank) + + if on_head_node: + self.local_engine_actors.append(actor) + else: + self.remote_engine_actors.append(actor) + self.created_placement_groups.append(pg) + self.placement_group_is_local.append(on_head_node) + + ray.get([ + actor.wait_for_init.remote() + for actor in (self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 else []) + + self.remote_engine_actors[-(len(placement_groups) - + new_local_engines):] + ]) + + actors = (self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 else []) + \ + self.remote_engine_actors[-(len(placement_groups) - + new_local_engines):] + + for actor in actors: + self.run_refs.append(actor.run.remote()) + + old_vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + # Update old_vllm_config with new data_parallel_size_local if any new + # local engines were added + if new_local_engines > 0: + old_vllm_config.parallel_config.data_parallel_size_local += \ + new_local_engines + + def scale_down(self, old_data_parallel_size: int, + new_data_parallel_size: int) -> None: + import ray + assert old_data_parallel_size > new_data_parallel_size + for _ in range(old_data_parallel_size - new_data_parallel_size): + pg = self.created_placement_groups.pop() + is_local = self.placement_group_is_local.pop() + if is_local: + self.local_engine_actors.pop() + else: + self.remote_engine_actors.pop() + ray.util.remove_placement_group(pg) + def get_run_refs(self): return self.run_refs diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index b06b7cc804d..fe80f22e966 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ 
b/vllm/v1/executor/multiproc_executor.py @@ -31,6 +31,8 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_mp_context, get_open_port) +# eep-dev +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase @@ -269,6 +271,15 @@ def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return + # eep-dev + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self.collective_rpc("reinitialize_distributed", + args=(reconfig_request, )) + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + return + @property def max_concurrent_batches(self) -> int: return self.parallel_config.pipeline_parallel_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ef03626cf14..c65a73690bf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1766,8 +1766,37 @@ def propose_ngram_draft_token_ids( draft_token_ids.append(drafter_output.tolist()) return draft_token_ids - def load_model(self) -> None: + def load_model(self, reconfigure: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) + # eep-dev + if reconfigure: + from vllm.distributed.parallel_state import get_ep_group + num_local_physical_experts = torch.empty(1, + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = num_local_physical_experts.item() + new_ep_size = get_ep_group().world_size + global_expert_load, old_global_expert_indices = ( + EplbState.recv_state()) + num_logical_experts = global_expert_load.shape[1] + self.parallel_config.num_redundant_experts = ( + num_local_physical_experts * new_ep_size - num_logical_experts) + assert old_global_expert_indices.shape[ + 1] % num_local_physical_experts == 0 + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + else: + global_expert_load = None + old_global_expert_indices = None + rank_mapping = None + with DeviceMemoryProfiler() as m: # noqa: SIM117 time_before_load = time.perf_counter() model_loader = get_model_loader(self.load_config) @@ -1811,6 +1840,9 @@ def load_model(self) -> None: self.model, self.device, self.parallel_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, ) def save_tensorized_model( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 916052ca5eb..813332ad785 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -23,6 +23,8 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling +# eep-dev +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats @@ -182,8 +184,9 @@ def load_model(self) -> None: else: from contextlib import nullcontext context = nullcontext() + reconfigure = os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1" with context: - self.model_runner.load_model() + 
self.model_runner.load_model(reconfigure=reconfigure) @torch.inference_mode() def determine_available_memory(self) -> int: @@ -347,6 +350,122 @@ def check_health(self) -> None: # worker will always be healthy as long as it's running. return + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + from vllm.config import set_current_vllm_config + from vllm.distributed.parallel_state import ( + cleanup_dist_env_and_memory, get_dp_group, get_ep_group, + prepare_communication_buffer_for_model) + from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoEParallelConfig) + + old_ep_size = get_ep_group().world_size + old_ep_rank = get_ep_group().rank + new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( + ).world_size * get_pp_group().world_size + if new_ep_size < old_ep_size: + # scale down + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding " + "before scaling down...") + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + self.model_runner.eplb_state.rearrange(self.model_runner.model, + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping) + torch.cuda.synchronize() + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + cleanup_dist_env_and_memory() + + if reconfig_request.new_data_parallel_rank == -2: + assert old_ep_rank >= new_ep_size + # shutdown + return + + # Update parallel config with provided reconfig_request + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = \ + reconfig_request.new_data_parallel_size + # Only update rank if new value is provided (-1 means keep current) + if reconfig_request.new_data_parallel_rank != -1: + parallel_config.data_parallel_rank = \ + reconfig_request.new_data_parallel_rank + if reconfig_request.new_data_parallel_rank_local != -1: + parallel_config.data_parallel_rank_local = \ + reconfig_request.new_data_parallel_rank_local + parallel_config.data_parallel_master_ip = \ + reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port + + with set_current_vllm_config(self.vllm_config): + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank) + moe_modules = [ + module for module in self.model_runner.model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all(module.moe_config.num_local_experts == num_local_experts + for module in moe_modules), ( + "All MoE modules must have the same number of experts") + new_ep_size = get_ep_group().world_size + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + if new_ep_size < old_ep_size: + num_local_physical_experts = num_local_experts + new_physical_experts = \ + self.model_runner.eplb_state.physical_to_logical_map.shape[1] + parallel_config.num_redundant_experts = ( + new_physical_experts - + self.model_runner.eplb_state.logical_replica_count.shape[1]) + else: + 
num_local_physical_experts = torch.tensor([num_local_experts], + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = num_local_physical_experts.item() + new_physical_experts = num_local_physical_experts * new_ep_size + global_expert_load = self.model_runner.eplb_state.rearrange( + self.model_runner.model, execute_shuffle=False) + parallel_config.num_redundant_experts = ( + new_physical_experts - global_expert_load.shape[1]) + prepare_communication_buffer_for_model(self.model_runner.model) + self.model_runner.model.update_physical_experts_metadata( + num_physical_experts=new_physical_experts, + num_local_physical_experts=num_local_physical_experts) + if new_ep_size > old_ep_size: + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding " + "after scaling up...") + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + self.model_runner.eplb_state.rearrange( + self.model_runner.model, + execute_shuffle=True, + global_expert_load=global_expert_load, + rank_mapping=rank_mapping) + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + self.model_runner.eplb_state.expert_rearrangement_step = 0 + def save_sharded_state( self, path: str, From cbab40d660613db690bec48ff338e10403cff368 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Wed, 9 Jul 2025 16:27:42 -0700 Subject: [PATCH 02/18] fixes Signed-off-by: Rui Qiao --- vllm/entrypoints/openai/api_server.py | 12 ++++++------ vllm/model_executor/layers/fused_moe/layer.py | 2 +- .../layers/fused_moe/pplx_prepare_finalize.py | 11 +++++++---- vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/coordinator.py | 6 ++---- vllm/v1/engine/core.py | 3 +-- vllm/v1/engine/core_client.py | 17 ++++++++++------- vllm/v1/engine/utils.py | 18 +++++++++--------- vllm/v1/worker/cpu_model_runner.py | 2 +- vllm/v1/worker/gpu_worker.py | 5 ++++- 10 files changed, 42 insertions(+), 36 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 95b8586a39f..62866314b02 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1080,16 +1080,16 @@ async def scale(raw_request: Request): await client.scale(new_data_parallel_size, drain_timeout) return JSONResponse({ "message": - f"Scaled up to {new_data_parallel_size} " + f"Scaled to {new_data_parallel_size} " "data parallel engines", }) except TimeoutError as e: - raise HTTPException( - status_code=408, - detail="Scale up failed due to request drain timeout " - f"after {drain_timeout} seconds") from e + raise HTTPException(status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds") from e except Exception as e: - raise HTTPException(status_code=500, detail="Scale up failed") from e + logger.error("Scale failed: %s", e) + raise HTTPException(status_code=500, detail="Scale failed") from e finally: raw_request.app.state.scaling = False diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 13b6a78fe20..a9882f8429a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -223,7 +223,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__() self.fused_experts = fused_experts # type: ignore - 
self.topk_indices_dtype = None + self.topk_indices_dtype = torch.uint32 self.moe = moe self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 66c892ede11..6c222114106 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -78,7 +78,9 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.int32 + # FIXME(rui): this needs to be int32, + # see https://github.com/vllm-project/vllm/pull/20166 + return torch.uint32 def num_dispatchers(self) -> int: return self.num_dispatchers_ @@ -100,9 +102,10 @@ def prepare( hidden_dim = a1.size(-1) # K assert topk_ids.size(0) == num_tokens - assert expert_map is None, """with expert map, -1 id is used for - non-local token; this causes error when casting ids to the - topk_indices_dtype() uint32""" + # FIXME(rui) + # assert expert_map is None, """with expert map, -1 id is used for + # non-local token; this causes error when casting ids to the + # topk_indices_dtype() uint32""" # Is this always going to be a1.device? device = a1.device diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 51ef348ff7c..cd8f940a51a 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -638,7 +638,7 @@ async def scale(self, Maximum time to wait for requests to drain (seconds) """ from vllm.v1.engine.core_client import RayDPClient - + if not isinstance(self.engine_core, RayDPClient): raise NotImplementedError( "Scale up/down only supported by RayDPClient") diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 4af46902a8b..1601373d93f 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -223,14 +223,12 @@ def process_input_socket(self, front_publish_address: str, self.engines_running = False logger.info( "DPCoordinator scaled up from %s to %s " - "engines", - current_count, new_engine_count) + "engines", current_count, new_engine_count) else: self.engines = self.engines[:new_engine_count] logger.info( "DPCoordinator scaled down from %s to %s " - "engines", - current_count, new_engine_count) + "engines", current_count, new_engine_count) continue # Skip normal engine notification processing # We received a message on the front-end XPUB socket, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a2489648f1d..cca8edf9b83 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -146,8 +146,7 @@ def _initialize_kv_caches( if os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None - kv_cache_memory = ParallelConfig.sync_kv_cache_memory( - dp_group, -1) + kv_cache_memory = ParallelConfig.sync_kv_cache_memory(dp_group, -1) available_gpu_memory = [kv_cache_memory] * len(kv_cache_specs) else: # Profiles the peak memory usage of the model to determine how much diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 365cb372f1d..9cd46018fa0 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -28,7 +28,7 @@ from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.engine.utils import (CoreEngine, 
CoreEngineActorManager, +from vllm.v1.engine.utils import (CoreEngineActorManager, CoreEngineProcManager, EngineZmqAddresses, launch_core_engines) from vllm.v1.executor.abstract import Executor @@ -94,6 +94,8 @@ def make_async_mp_client( # External load balancer - client per DP rank. return DPAsyncMPClient(*client_args) # Internal load balancer - client balances to all DP ranks. + if parallel_config.data_parallel_backend == "ray": + return RayDPClient(*client_args) return DPLBAsyncMPClient(*client_args) return AsyncMPClient(*client_args) @@ -1115,7 +1117,7 @@ def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, async def _send_reconfig_message( self, reconfig_request: ReconfigureDistributedRequest, - engine: CoreEngine) -> asyncio.Future: + engine: EngineIdentity) -> asyncio.Future: """Send reconfiguration message and return the result future without waiting for completion.""" call_id = uuid.uuid1().int >> 64 @@ -1160,17 +1162,17 @@ async def scale_up(self, new_data_parallel_size: int) -> None: # Phase 2: Create new engines now that reconfig messages have been sent # self.resources.engine_manager is guaranteed to be # CoreEngineActorManager for RayDPClient - assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, + CoreEngineActorManager) self.resources.engine_manager.scale_up(self.vllm_config, new_data_parallel_size) # Create new CoreEngine objects for the new engines new_engine_identities = set() for i in range(current_dp_size, new_data_parallel_size): - # TODO(yongji): check if the engine is local - new_engine = CoreEngine(index=i, local=False) + new_engine = i.to_bytes(2, "little") self.core_engines.append(new_engine) - new_engine_identities.add(new_engine.identity) + new_engine_identities.add(new_engine) # Wait for ready messages from new engines on the input socket sync_input_socket = zmq.Socket.shadow(self.input_socket) @@ -1233,7 +1235,8 @@ async def scale_down(self, new_data_parallel_size: int) -> None: await asyncio.gather(*reconfig_futures) - assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + assert isinstance(self.resources.engine_manager, + CoreEngineActorManager) self.resources.engine_manager.scale_down(current_dp_size, new_data_parallel_size) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 4dd8a35aa4c..9cfe37475cf 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -221,7 +221,7 @@ def __init__( dp_vllm_config = copy.deepcopy(vllm_config) pg = placement_groups[index] dp_vllm_config.parallel_config.placement_group = pg - on_head_node = index < local_engine_count + local_client = index < local_engine_count actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, @@ -230,15 +230,15 @@ def __init__( runtime_env=runtime_env).remote(vllm_config=dp_vllm_config, executor_class=executor_class, log_stats=log_stats, - on_head_node=on_head_node, + local_client=local_client, addresses=addresses, dp_rank=index, local_dp_rank=local_index) - if on_head_node: + if local_client: self.local_engine_actors.append(actor) else: self.remote_engine_actors.append(actor) - self.placement_group_is_local.append(on_head_node) + self.placement_group_is_local.append(local_client) refs.append(actor.wait_for_init.remote()) ray.get(refs) @@ -435,11 +435,11 @@ def scale_up(self, old_vllm_config: VllmConfig, dp_vllm_config.parallel_config.placement_group = pg # Check if this placement group is on 
the head node - on_head_node = any( + local_client = any( bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in pg.bundle_specs) - if on_head_node: + if local_client: new_local_engines += 1 # Update data_parallel_size_local dp_vllm_config.parallel_config.data_parallel_size_local = ( @@ -455,17 +455,17 @@ def scale_up(self, old_vllm_config: VllmConfig, vllm_config=dp_vllm_config, executor_class=self.executor_class, log_stats=self.log_stats, - on_head_node=on_head_node, + local_client=local_client, addresses=self.addresses, dp_rank=rank, local_dp_rank=local_rank) - if on_head_node: + if local_client: self.local_engine_actors.append(actor) else: self.remote_engine_actors.append(actor) self.created_placement_groups.append(pg) - self.placement_group_is_local.append(on_head_node) + self.placement_group_is_local.append(local_client) ray.get([ actor.wait_for_init.remote() diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 410a54e7466..87f28661adc 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -50,7 +50,7 @@ def replace_tensor(obj: Any, cpu_attr_name: str, if k.endswith("_cpu") and isinstance(v, torch.Tensor): replace_tensor(self.input_batch.block_table, k, k[:-4]) - def load_model(self) -> None: + def load_model(self, reconfigure: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) self.model = get_model(vllm_config=self.vllm_config) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 813332ad785..96c5dfe7f9d 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -372,6 +372,7 @@ def reinitialize_distributed( old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 for old_ep_rank in range(old_ep_size) } + assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange(self.model_runner.model, execute_shuffle=True, global_expert_load=None, @@ -427,6 +428,7 @@ def reinitialize_distributed( module.moe_config.moe_parallel_config = module.moe_parallel_config if new_ep_size < old_ep_size: num_local_physical_experts = num_local_experts + assert self.model_runner.eplb_state is not None new_physical_experts = \ self.model_runner.eplb_state.physical_to_logical_map.shape[1] parallel_config.num_redundant_experts = ( @@ -441,6 +443,7 @@ def reinitialize_distributed( group_src=0) num_local_physical_experts = num_local_physical_experts.item() new_physical_experts = num_local_physical_experts * new_ep_size + assert self.model_runner.eplb_state is not None global_expert_load = self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=False) parallel_config.num_redundant_experts = ( @@ -457,6 +460,7 @@ def reinitialize_distributed( old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) } + assert self.model_runner.eplb_state is not None self.model_runner.eplb_state.rearrange( self.model_runner.model, execute_shuffle=True, @@ -464,7 +468,6 @@ def reinitialize_distributed( rank_mapping=rank_mapping) if get_ep_group().rank == 0: logger.info("[Elastic EP] Expert resharding completed!") - self.model_runner.eplb_state.expert_rearrangement_step = 0 def save_sharded_state( self, From ea8742465ae3db8ab42abdd6a96f6193f44589df Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:01:38 -0700 Subject: [PATCH 03/18] clean up placement_group functions Signed-off-by: Rui Qiao --- vllm/v1/engine/utils.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 
deletions(-) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 9cfe37475cf..cb33f514400 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -250,6 +250,9 @@ def __init__( def create_dp_placement_groups( vllm_config: VllmConfig ) -> tuple[list["PlacementGroup"], list[int]]: + """ + Create placement groups for data parallel. + """ import ray from ray._private.state import available_resources_per_node @@ -258,7 +261,7 @@ def create_dp_placement_groups( logger.info("Creating placement groups for data parallel") dp_master_ip = \ vllm_config.parallel_config.data_parallel_master_ip - dp_size = vllm_config.parallel_config.data_parallel_size + num_pg_to_create = vllm_config.parallel_config.data_parallel_size local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local @@ -302,7 +305,7 @@ def create_dp_placement_groups( local_dp_ranks.append(i) else: for i in range(available_engine_count): - if len(placement_groups) == dp_size: + if len(placement_groups) == num_pg_to_create: break bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( @@ -315,18 +318,21 @@ def create_dp_placement_groups( return placement_groups, local_dp_ranks @staticmethod - def scale_up_create_dp_placement_groups( + def add_dp_placement_groups( old_vllm_config: VllmConfig, new_data_parallel_size: int ) -> tuple[list["PlacementGroup"], list[int]]: + """ + Add placement groups for new data parallel size. + """ import ray from ray._private.state import (available_resources_per_node, total_resources_per_node) from ray.util.state import list_nodes old_dp_size = old_vllm_config.parallel_config.data_parallel_size - num_new_engines = new_data_parallel_size - old_dp_size + num_pg_to_create = new_data_parallel_size - old_dp_size - if num_new_engines <= 0: + if num_pg_to_create <= 0: return [], [] dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip @@ -344,10 +350,10 @@ def scale_up_create_dp_placement_groups( placement_groups = [] local_dp_ranks = [] - engines_created = 0 + num_pg_created = 0 for node in nodes: - if engines_created >= num_new_engines: + if num_pg_created >= num_pg_to_create: break node_ip = node.node_ip @@ -367,10 +373,10 @@ def scale_up_create_dp_placement_groups( # Create placement groups for new engines on this node for i in range(available_engine_count): - if engines_created >= num_new_engines: + if num_pg_created >= num_pg_to_create: break - rank = old_dp_size + engines_created + rank = old_dp_size + num_pg_created # Create bundles with node constraint for master node if node_ip == dp_master_ip: @@ -394,7 +400,7 @@ def scale_up_create_dp_placement_groups( # on this node local_rank = used_engines_on_node + i local_dp_ranks.append(local_rank) - engines_created += 1 + num_pg_created += 1 return placement_groups, local_dp_ranks @@ -416,7 +422,7 @@ def scale_up(self, old_vllm_config: VllmConfig, return placement_groups, local_dp_ranks = \ - self.scale_up_create_dp_placement_groups( + self.add_dp_placement_groups( old_vllm_config, new_data_parallel_size) world_size = old_vllm_config.parallel_config.world_size From cb183c27284bb9291e284ba937336075d1114190 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:06:58 -0700 Subject: [PATCH 04/18] cleanup Signed-off-by: Rui Qiao --- vllm/v1/engine/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index cb33f514400..bd79f5a8999 100644 --- a/vllm/v1/engine/utils.py +++ 
b/vllm/v1/engine/utils.py @@ -415,11 +415,12 @@ def scale_up(self, old_vllm_config: VllmConfig, from vllm.v1.engine.core import DPEngineCoreActor - old_dp_size = len(self.local_engine_actors) + len( - self.remote_engine_actors) + old_data_parallel_size = len(self.local_engine_actors) + \ + len(self.remote_engine_actors) - if new_data_parallel_size <= old_dp_size: - return + assert new_data_parallel_size > old_data_parallel_size, ( + "New data parallel size must be greater than old data parallel " + "size for scale up") placement_groups, local_dp_ranks = \ self.add_dp_placement_groups( @@ -434,7 +435,7 @@ def scale_up(self, old_vllm_config: VllmConfig, for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)): - rank = old_dp_size + i + rank = old_data_parallel_size + i dp_vllm_config = copy.deepcopy(old_vllm_config) dp_vllm_config.parallel_config.data_parallel_size = \ new_data_parallel_size From 4b543ea0b9a11b8bc611d437102d39ec08b91a00 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:10:33 -0700 Subject: [PATCH 05/18] cleanup Signed-off-by: Rui Qiao --- vllm/config.py | 1 - vllm/distributed/eplb/eplb_state.py | 8 -------- vllm/distributed/eplb/rebalance_execute.py | 9 --------- vllm/entrypoints/openai/api_server.py | 1 - vllm/executor/uniproc_executor.py | 1 - vllm/model_executor/layers/fused_moe/layer.py | 2 -- vllm/model_executor/models/deepseek_v2.py | 1 - vllm/v1/engine/__init__.py | 1 - vllm/v1/engine/async_llm.py | 1 - vllm/v1/engine/coordinator.py | 1 - vllm/v1/engine/core.py | 4 ---- vllm/v1/engine/core_client.py | 1 - vllm/v1/engine/utils.py | 1 - vllm/v1/executor/multiproc_executor.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 1 - vllm/v1/worker/gpu_worker.py | 1 - 16 files changed, 36 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index e0154d48955..5f53c486eac 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1954,7 +1954,6 @@ def has_unfinished_dp(dp_group: "ProcessGroup", aggregated_has_unfinished = bool(tensor.item()) return aggregated_has_unfinished - # eep-dev @staticmethod def sync_kv_cache_memory(dp_group: "ProcessGroup", kv_cache_memory: int) -> None: diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 1a29b3b12a6..b88deb9a84d 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -29,17 +29,14 @@ import time from collections.abc import Sequence from dataclasses import dataclass -# eep-dev from typing import Optional, Union import torch -# eep-dev from torch.distributed import ProcessGroup, all_gather, all_reduce from vllm.config import ParallelConfig from vllm.distributed.parallel_state import (get_ep_group, get_node_count, in_the_same_node_as) -# eep-dev from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -248,7 +245,6 @@ def build( expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) - # eep-dev if global_expert_load is not None: ep_group = get_ep_group().device_group assert global_expert_load.shape == (model.num_moe_layers, @@ -422,7 +418,6 @@ def rearrange(self, logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") - # eep-dev if global_expert_load is None: # This mapping is only used here, so we do not store it in the state physical_expert_start = ep_rank * model.num_local_physical_experts @@ -481,7 +476,6 @@ def rearrange(self, num_replicas = model.num_physical_experts 
num_groups = model.num_expert_groups if rank_mapping is not None and len(rank_mapping) == ep_group.size(): - # eep-dev # NOTE(yongji): scale down, we need to rebalance the experts on # remaining GPUs, transfer the experts while we haven't shutdown # the GPUs to be released. @@ -496,7 +490,6 @@ def rearrange(self, num_gpus = ep_group.size() if num_gpus % num_nodes != 0: - # eep-dev self.num_nodes = 1 logger.warning_once( f"num_gpus % num_nodes != 0, " @@ -584,7 +577,6 @@ def recv_state() -> tuple[torch.Tensor, torch.Tensor]: return global_expert_load, old_global_expert_indices -# eep-dev def _node_count_with_rank_mapping( pg: Union[ProcessGroup, StatelessProcessGroup], rank_mapping: dict[int, int], diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index bffa4f31d6a..081c887eec4 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -8,7 +8,6 @@ from collections.abc import Iterable, MutableSequence, Sequence from functools import partial -# eep-dev from typing import Optional import torch @@ -129,7 +128,6 @@ def shuffle_layer( dst_global = local2global(dst) if is_received_locally[dst]: continue - # eep-dev if old_indices[src_global] == -1 or new_indices[dst_global] == -1: continue if old_indices[src_global] == new_indices[dst_global]: @@ -144,7 +142,6 @@ def shuffle_layer( experts_send_loc: dict[int, int] = {} for src in range(num_local_experts): expert = old_indices[local2global(src)] - # eep-dev if expert == -1: continue if expert in experts_send_loc: @@ -189,7 +186,6 @@ def shuffle_layer( if is_received_locally[dst]: continue expert = new_indices[local2global(dst)] - # eep-dev if expert == -1: continue if expert in experts_recv_loc: @@ -238,7 +234,6 @@ def shuffle_layer( weight[dst].copy_(buffer[dst]) else: expert = new_indices[local2global(dst)] - # eep-dev if expert == -1: continue src = experts_recv_loc[expert] @@ -252,7 +247,6 @@ def rearrange_expert_weights_inplace( expert_weights: Sequence[Iterable[torch.Tensor]], ep_group: ProcessGroup, is_profile: bool = False, - # eep-dev rank_mapping: Optional[dict[int, int]] = None, ) -> None: """ @@ -273,7 +267,6 @@ def rearrange_expert_weights_inplace( This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. 
""" - # eep-dev if rank_mapping is not None: if len(rank_mapping) == ep_group.size(): # scale down @@ -341,7 +334,6 @@ def rearrange_expert_weights_inplace( ) -# eep-dev def _map_old_expert_indices_with_rank_mapping( old_global_expert_indices: torch.Tensor, rank_mapping: dict[int, int], @@ -397,7 +389,6 @@ def _map_old_expert_indices_with_rank_mapping( return mapped_expert_indices -# eep-dev def _map_new_expert_indices_with_rank_mapping( new_global_expert_indices: torch.Tensor, rank_mapping: dict[int, int], diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 62866314b02..bbbbe660536 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1047,7 +1047,6 @@ async def is_sleeping(raw_request: Request): return JSONResponse(content={"is_sleeping": is_sleeping}) -# eep-dev @router.post("/scale", dependencies=[Depends(validate_json_request)]) async def scale(raw_request: Request): try: diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 100f14b57a9..4401db38178 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -63,7 +63,6 @@ def check_health(self) -> None: # it's running. return - # eep-dev def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest) -> None: self.driver_worker.reinitialize_distributed(reconfig_request) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a9882f8429a..384bef6ef02 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -239,7 +239,6 @@ def select_gemm_impl( moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: - # eep-dev # assert self.fused_experts == fused_experts if (prepare_finalize.activation_format == @@ -841,7 +840,6 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels - # eep-dev def update_expert_map(self): # ep_size and ep_rank should already be updated with self.expert_map.device: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 5093a0c6487..465eb8c6d53 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -770,7 +770,6 @@ def set_eplb_state( logical_replica_count=logical_replica_count, ) - # eep-dev def update_physical_experts_metadata( self, num_physical_experts: int, diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 78e765615f1..6021055f6e4 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -179,7 +179,6 @@ class EngineCoreRequestType(enum.Enum): EXECUTOR_FAILED = b'\x04' -# eep-dev class ReconfigureDistributedRequest(msgspec.Struct): new_data_parallel_size: int new_data_parallel_rank: int # -1 means keep current rank diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index cd8f940a51a..efb3e2c1488 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -610,7 +610,6 @@ async def collective_rpc(self, return await self.engine_core.collective_rpc_async( method, timeout, args, kwargs) - # eep-dev async def wait_for_requests_to_drain(self, drain_timeout: int = 300): """Wait for all requests to be drained.""" start_time = time.time() diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 1601373d93f..99c5455a51f 100644 --- a/vllm/v1/engine/coordinator.py +++ 
b/vllm/v1/engine/coordinator.py @@ -200,7 +200,6 @@ def process_input_socket(self, front_publish_address: str, # Ignore subscription messages. continue - # eep-dev decoded = msgspec.msgpack.decode(buffer) if isinstance(decoded, list) and len( decoded) == 2 and decoded[0] == "SCALE_UP": diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index cca8edf9b83..ab077f1443a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -31,7 +31,6 @@ from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler -# eep-dev from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, UtilityOutput) @@ -79,7 +78,6 @@ def __init__(self, self.model_executor.register_failure_callback( executor_fail_callback) - # eep-dev self.available_gpu_memory_for_kv_cache = -1 # Setup KV Caches and update CacheConfig after profiling. @@ -142,7 +140,6 @@ def _initialize_kv_caches( # Get all kv cache needed by the model kv_cache_specs = self.model_executor.get_kv_cache_specs() - # eep-dev if os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None @@ -991,7 +988,6 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) - # eep-dev def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest) -> None: stateless_destroy_torch_distributed_process_group(self.dp_group) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9cd46018fa0..f4fbcaeba7a 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -914,7 +914,6 @@ async def run_engine_stats_update_task(): events = await poller.poll() if not self.engines_running and len(events) == 2 or ( events[0][0] == first_req_rcv_socket): - # eep-dev # Check if this is a regular request notification or # scale up notification buf = first_req_rcv_socket.recv( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index bd79f5a8999..da0dd2c63ba 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -182,7 +182,6 @@ def __init__( } runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) - # eep-dev self.addresses = addresses self.executor_class = executor_class self.log_stats = log_stats diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index fe80f22e966..c5d3926a4e9 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -31,7 +31,6 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_mp_context, get_open_port) -# eep-dev from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import ModelRunnerOutput @@ -271,7 +270,6 @@ def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return - # eep-dev def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest) -> None: self.collective_rpc("reinitialize_distributed", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c65a73690bf..801e163cb1b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1768,7 +1768,6 @@ def propose_ngram_draft_token_ids( def load_model(self, 
reconfigure: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) - # eep-dev if reconfigure: from vllm.distributed.parallel_state import get_ep_group num_local_physical_experts = torch.empty(1, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 96c5dfe7f9d..b8625742b38 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -23,7 +23,6 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling -# eep-dev from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput From a9cbefe9f44a27d3a21c3b62876abbe34da85f0d Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:16:54 -0700 Subject: [PATCH 06/18] cleanup Signed-off-by: Rui Qiao --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 5f53c486eac..f31c2a2c619 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1956,7 +1956,7 @@ def has_unfinished_dp(dp_group: "ProcessGroup", @staticmethod def sync_kv_cache_memory(dp_group: "ProcessGroup", - kv_cache_memory: int) -> None: + kv_cache_memory: int) -> int: if kv_cache_memory == -1: kv_cache_memory = torch.iinfo(torch.int64).max tensor = torch.tensor([kv_cache_memory], From 637aca25e551f433edafaf12b688bc020d105a7c Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:26:13 -0700 Subject: [PATCH 07/18] cleanup Signed-off-by: Rui Qiao --- experimental/test_stateless_pg.py | 93 ------------------------------- 1 file changed, 93 deletions(-) delete mode 100644 experimental/test_stateless_pg.py diff --git a/experimental/test_stateless_pg.py b/experimental/test_stateless_pg.py deleted file mode 100644 index 452fe1a8595..00000000000 --- a/experimental/test_stateless_pg.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch -from torch.multiprocessing import spawn - -from vllm.distributed.utils import ( - stateless_destroy_torch_distributed_process_group, - stateless_init_torch_distributed_process_group) - - -def worker_process(rank: int, world_size: int, host: str, port1: int, - port2: int): - torch.cuda.set_device(rank % torch.cuda.device_count()) - - # Create first process group with all workers - pg1 = stateless_init_torch_distributed_process_group(host=host, - port=port1, - rank=rank, - world_size=world_size, - backend="gloo") - - # Create second process group with worldsize-1 workers (excluding last rank) - pg2 = None - if rank < world_size - 1: - pg2 = stateless_init_torch_distributed_process_group( - host=host, - port=port2, - rank=rank, - world_size=world_size - 1, - backend="gloo") - - # Test both groups work simultaneously - tensor1 = torch.tensor([rank], dtype=torch.float32) - torch.distributed.all_reduce(tensor1, group=pg1) - expected1 = sum(range(world_size)) - assert tensor1.item( - ) == expected1, f"PG1 failed: got {tensor1.item()}, expected {expected1}" - print(f"Rank {rank}: PG1 all_reduce passed") - - if pg2 is not None: - tensor2 = torch.tensor([rank], dtype=torch.float32) - torch.distributed.all_reduce(tensor2, group=pg2) - expected2 = sum(range(world_size - 1)) - assert tensor2.item() == expected2, ( - f"PG2 failed: got {tensor2.item()}, expected {expected2}") - print(f"Rank {rank}: PG2 all_reduce 
passed") - - # Destroy first process group - stateless_destroy_torch_distributed_process_group(pg1) - print(f"Rank {rank}: PG1 destroyed") - - # Last rank exits here - if rank == world_size - 1: - print(f"Rank {rank}: Exiting") - return - - # Test second group still works after destroying - # first group and last rank exit - tensor3 = torch.tensor([rank * 10], dtype=torch.float32) - torch.distributed.all_reduce(tensor3, group=pg2) - expected3 = sum(i * 10 for i in range(world_size - 1)) - assert tensor3.item() == expected3, ( - f"PG2 after PG1 destroy failed: got {tensor3.item()}, " - f"expected {expected3}") - print(f"Rank {rank}: PG2 after PG1 destroy passed") - - # Clean up - if pg2 is not None: - stateless_destroy_torch_distributed_process_group(pg2) - print(f"Rank {rank}: PG2 destroyed") - - -def test_stateless_process_groups(): - assert not torch.distributed.is_initialized( - ), "torch.distributed should not be initialized" - - world_size = 4 - host = "127.0.0.1" - port1 = 29600 - port2 = 29601 - - print(f"Testing stateless process groups with world_size={world_size}") - - spawn(worker_process, - args=(world_size, host, port1, port2), - nprocs=world_size, - join=True) - - print("Test completed successfully!") - - -if __name__ == "__main__": - test_stateless_process_groups() From 26af1a86867d0f9f8713732dd6efcdb199a2898d Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:38:59 -0700 Subject: [PATCH 08/18] ray cleanup Signed-off-by: Rui Qiao --- vllm/v1/executor/multiproc_executor.py | 9 --------- vllm/v1/executor/ray_distributed_executor.py | 8 ++++++++ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index c5d3926a4e9..b06b7cc804d 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -31,7 +31,6 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_mp_context, get_open_port) -from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase @@ -270,14 +269,6 @@ def check_health(self) -> None: self.collective_rpc("check_health", timeout=10) return - def reinitialize_distributed( - self, reconfig_request: ReconfigureDistributedRequest) -> None: - self.collective_rpc("reinitialize_distributed", - args=(reconfig_request, )) - if reconfig_request.new_data_parallel_rank == -2: - self.shutdown() - return - @property def max_concurrent_batches(self) -> int: return self.parallel_config.pipeline_parallel_size diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index 257564793cf..f58b4c90ed4 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -6,6 +6,7 @@ from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput @@ -60,3 +61,10 @@ def execute_model( # When PP is used, we return a FutureWrapper immediately so that # the scheduler can yield to the next batch. 
return FutureWrapper(refs[0]) + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self._run_workers("reinitialize_distributed", reconfig_request) + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + return From a804a0f62f821918f0c53b92cc750644c323b018 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 15:47:49 -0700 Subject: [PATCH 09/18] reorg dir Signed-off-by: Rui Qiao --- .../online_serving/elastic_ep}/bench.sh | 0 .../online_serving/elastic_ep/scale.py | 20 ++++++------------- .../elastic_ep}/serve_deepseek_v2.sh | 0 3 files changed, 6 insertions(+), 14 deletions(-) rename {experimental => examples/online_serving/elastic_ep}/bench.sh (100%) rename experimental/test_scale.py => examples/online_serving/elastic_ep/scale.py (66%) rename {experimental => examples/online_serving/elastic_ep}/serve_deepseek_v2.sh (100%) diff --git a/experimental/bench.sh b/examples/online_serving/elastic_ep/bench.sh similarity index 100% rename from experimental/bench.sh rename to examples/online_serving/elastic_ep/bench.sh diff --git a/experimental/test_scale.py b/examples/online_serving/elastic_ep/scale.py similarity index 66% rename from experimental/test_scale.py rename to examples/online_serving/elastic_ep/scale.py index 0c453159889..9e82fcaec1e 100644 --- a/experimental/test_scale.py +++ b/examples/online_serving/elastic_ep/scale.py @@ -18,10 +18,7 @@ def test_scale(host, port, new_dp_size): print(f"Payload: {json.dumps(payload, indent=2)}") try: - response = requests.post(url, - json=payload, - headers=headers, - timeout=300) + response = requests.post(url, json=payload, headers=headers, timeout=300) print(f"Status Code: {response.status_code}") print(f"Response: {response.text}") @@ -39,17 +36,12 @@ def test_scale(host, port, new_dp_size): def main(): - parser = argparse.ArgumentParser( - description="Test scale up/down functionality") + parser = argparse.ArgumentParser(description="Test scale up/down functionality") parser.add_argument("--host", default="localhost", help="API server host") - parser.add_argument("--port", - type=int, - default=8006, - help="API server port") - parser.add_argument("--new_dp_size", - type=int, - default=2, - help="New data parallel size") + parser.add_argument("--port", type=int, default=8006, help="API server port") + parser.add_argument( + "--new_dp_size", type=int, default=2, help="New data parallel size" + ) args = parser.parse_args() diff --git a/experimental/serve_deepseek_v2.sh b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh similarity index 100% rename from experimental/serve_deepseek_v2.sh rename to examples/online_serving/elastic_ep/serve_deepseek_v2.sh From f3c0360b9bd61c0dcd5541504ef9ac56d4dbe6d1 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 16:10:18 -0700 Subject: [PATCH 10/18] minor refactor Signed-off-by: Rui Qiao --- examples/online_serving/elastic_ep/scale.py | 6 +++--- vllm/config.py | 4 ++-- vllm/v1/engine/core.py | 15 +++++++++------ vllm/v1/engine/utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 10 +++++++--- vllm/v1/worker/gpu_worker.py | 4 ++-- 6 files changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/online_serving/elastic_ep/scale.py b/examples/online_serving/elastic_ep/scale.py index 9e82fcaec1e..ae9c1e85eb9 100644 --- a/examples/online_serving/elastic_ep/scale.py +++ b/examples/online_serving/elastic_ep/scale.py @@ -9,7 +9,7 @@ import requests -def test_scale(host, port, new_dp_size): +def scale(host, port, 
new_dp_size): url = f"http://{host}:{port}/scale" payload = {"new_data_parallel_size": new_dp_size} headers = {"Content-Type": "application/json"} @@ -40,12 +40,12 @@ def main(): parser.add_argument("--host", default="localhost", help="API server host") parser.add_argument("--port", type=int, default=8006, help="API server port") parser.add_argument( - "--new_dp_size", type=int, default=2, help="New data parallel size" + "--new-dp-size", type=int, default=2, help="New data parallel size" ) args = parser.parse_args() - success = test_scale(args.host, args.port, args.new_dp_size) + success = scale(args.host, args.port, args.new_dp_size) sys.exit(0 if success else 1) diff --git a/vllm/config.py b/vllm/config.py index f31c2a2c619..657bb7216cd 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1955,8 +1955,8 @@ def has_unfinished_dp(dp_group: "ProcessGroup", return aggregated_has_unfinished @staticmethod - def sync_kv_cache_memory(dp_group: "ProcessGroup", - kv_cache_memory: int) -> int: + def sync_kv_cache_memory_size(dp_group: "ProcessGroup", + kv_cache_memory: int) -> int: if kv_cache_memory == -1: kv_cache_memory = torch.iinfo(torch.int64).max tensor = torch.tensor([kv_cache_memory], diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ab077f1443a..7cd1942bc72 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -140,11 +140,13 @@ def _initialize_kv_caches( # Get all kv cache needed by the model kv_cache_specs = self.model_executor.get_kv_cache_specs() - if os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1": + if os.environ.get("VLLM_EEP_SCALE_UP_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None - kv_cache_memory = ParallelConfig.sync_kv_cache_memory(dp_group, -1) - available_gpu_memory = [kv_cache_memory] * len(kv_cache_specs) + kv_cache_memory_size = ParallelConfig.sync_kv_cache_memory_size( + dp_group, -1) + available_gpu_memory = [kv_cache_memory_size] * \ + len(kv_cache_specs) else: # Profiles the peak memory usage of the model to determine how much # memory can be allocated for kv cache. @@ -1015,9 +1017,10 @@ def reinitialize_distributed( self.model_executor.reinitialize_distributed(reconfig_request) if reconfig_request.new_data_parallel_size > old_dp_size: assert self.available_gpu_memory_for_kv_cache > 0 - # broadcast KV cache available memory for _initialize_kv_caches - # on new EngineCore - ParallelConfig.sync_kv_cache_memory( + # pass available_gpu_memory_for_kv_cache from existing + # engine-cores to new engine-cores so they can directly + # use it in _initialize_kv_caches() rather than profiling. 
+ ParallelConfig.sync_kv_cache_memory_size( self.dp_group, self.available_gpu_memory_for_kv_cache) # NOTE(yongji): newly joined workers require dummy_run even # CUDA graph is not used diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index da0dd2c63ba..b166691c321 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -430,7 +430,7 @@ def scale_up(self, old_vllm_config: VllmConfig, new_local_engines = 0 runtime_env = RuntimeEnv(env_vars=self.env_vars_dict - | {"VLLM_EEP_RECONFIGURE_LAUNCH": "1"}) + | {"VLLM_EEP_SCALE_UP_LAUNCH": "1"}) for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 801e163cb1b..163a7c0bcd5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1766,9 +1766,13 @@ def propose_ngram_draft_token_ids( draft_token_ids.append(drafter_output.tolist()) return draft_token_ids - def load_model(self, reconfigure: bool = False) -> None: + def load_model(self, eep_scale_up: bool = False) -> None: + """ + Args: + eep_scale_up: the model loading is for elastic EP scale up. + """ logger.info("Starting to load model %s...", self.model_config.model) - if reconfigure: + if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group num_local_physical_experts = torch.empty(1, dtype=torch.int32, @@ -1776,7 +1780,7 @@ def load_model(self, reconfigure: bool = False) -> None: torch.distributed.broadcast(num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0) - num_local_physical_experts = num_local_physical_experts.item() + num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size global_expert_load, old_global_expert_indices = ( EplbState.recv_state()) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b8625742b38..4bda6886d7a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -183,9 +183,9 @@ def load_model(self) -> None: else: from contextlib import nullcontext context = nullcontext() - reconfigure = os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1" + eep_scale_up = os.environ.get("VLLM_EEP_SCALE_UP_LAUNCH") == "1" with context: - self.model_runner.load_model(reconfigure=reconfigure) + self.model_runner.load_model(eep_scale_up=eep_scale_up) @torch.inference_mode() def determine_available_memory(self) -> int: From 65075360d694500e81f584d06eba23efb2a9b334 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 16:55:11 -0700 Subject: [PATCH 11/18] fix repeated scale up Signed-off-by: Rui Qiao --- vllm/v1/engine/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 7cd1942bc72..f465cc4127e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -143,10 +143,10 @@ def _initialize_kv_caches( if os.environ.get("VLLM_EEP_SCALE_UP_LAUNCH") == "1": dp_group = getattr(self, "dp_group", None) assert dp_group is not None - kv_cache_memory_size = ParallelConfig.sync_kv_cache_memory_size( - dp_group, -1) - available_gpu_memory = [kv_cache_memory_size] * \ - len(kv_cache_specs) + self.available_gpu_memory_for_kv_cache = \ + ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) + available_gpu_memory = [self.available_gpu_memory_for_kv_cache + ] * len(kv_cache_specs) else: # Profiles the peak memory usage of the model to determine how much # memory can be allocated for kv cache. 
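A note on the mechanism the two patches above converge on: engine-cores that are already running keep the available_gpu_memory_for_kv_cache they measured during profiling and publish it through ParallelConfig.sync_kv_cache_memory_size(), while engine-cores launched with VLLM_EEP_SCALE_UP_LAUNCH=1 call the same helper with -1 (mapped to int64 max in the config.py hunk) and skip profiling entirely. The sketch below simulates that sentinel-plus-reduction idea in plain Python; the MIN reduction, the function body, and the per-rank numbers are illustrative assumptions, not the helper's actual implementation.

# Illustration only: simulates the sentinel + reduction idea behind
# ParallelConfig.sync_kv_cache_memory_size() without a torch.distributed group.
import torch

INT64_MAX = torch.iinfo(torch.int64).max

def sync_kv_cache_memory_size(per_rank_values: list[int]) -> int:
    """Each rank contributes its KV-cache budget in bytes, or -1 if it has
    none yet (a freshly launched engine-core during elastic EP scale up)."""
    # -1 becomes int64 max, the neutral element of a MIN reduction.
    contributions = [INT64_MAX if v == -1 else v for v in per_rank_values]
    # In vLLM this would be a collective over the DP group; here we just
    # take the min over the simulated ranks.
    return min(contributions)

# Two existing engine-cores measured ~60 GiB and ~58 GiB for KV cache;
# two newly launched engine-cores contribute the -1 sentinel.
ranks = [60 << 30, 58 << 30, -1, -1]
print(sync_kv_cache_memory_size(ranks))  # -> 58 GiB, the smallest real value

With a reduction like this, every rank, old or new, ends up sizing its KV cache from the smallest budget reported by an existing engine-core, which appears to be why the "fix repeated scale up" patch stores the synced value on self.available_gpu_memory_for_kv_cache: an engine-core that itself joined through a scale up can then act as the source for the next one.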
From 11feb5c09cfe2bcf0735eea6b274241be49b40f7 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Thu, 10 Jul 2025 17:03:17 -0700 Subject: [PATCH 12/18] move nvshmem.patch Signed-off-by: Rui Qiao --- {experimental => tools/ep_kernels/elastic_ep}/nvshmem.patch | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {experimental => tools/ep_kernels/elastic_ep}/nvshmem.patch (100%) diff --git a/experimental/nvshmem.patch b/tools/ep_kernels/elastic_ep/nvshmem.patch similarity index 100% rename from experimental/nvshmem.patch rename to tools/ep_kernels/elastic_ep/nvshmem.patch From 2ec9ddc248088b4f3df44a9ce8093b5c849e8224 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Sun, 13 Jul 2025 16:01:26 -0700 Subject: [PATCH 13/18] factor out RayDPClient Signed-off-by: Rui Qiao --- vllm/v1/engine/async_llm.py | 6 ---- vllm/v1/engine/core_client.py | 58 +++++------------------------------ 2 files changed, 7 insertions(+), 57 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index efb3e2c1488..7d33ff335d2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -636,12 +636,6 @@ async def scale(self, drain_timeout: Maximum time to wait for requests to drain (seconds) """ - from vllm.v1.engine.core_client import RayDPClient - - if not isinstance(self.engine_core, RayDPClient): - raise NotImplementedError( - "Scale up/down only supported by RayDPClient") - self.scaling = True old_data_parallel_size = \ self.vllm_config.parallel_config.data_parallel_size diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f4fbcaeba7a..ecfaa42725a 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -29,8 +29,7 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.utils import (CoreEngineActorManager, - CoreEngineProcManager, EngineZmqAddresses, - launch_core_engines) + CoreEngineProcManager, launch_core_engines) from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr @@ -94,8 +93,6 @@ def make_async_mp_client( # External load balancer - client per DP rank. return DPAsyncMPClient(*client_args) # Internal load balancer - client balances to all DP ranks. - if parallel_config.data_parallel_backend == "ray": - return RayDPClient(*client_args) return DPLBAsyncMPClient(*client_args) return AsyncMPClient(*client_args) @@ -166,6 +163,12 @@ def dp_engines_running(self) -> bool: running state.""" raise NotImplementedError + async def scale_up(self, new_data_parallel_size: int) -> None: + raise NotImplementedError + + async def scale_down(self, new_data_parallel_size: int) -> None: + raise NotImplementedError + async def get_output_async(self) -> EngineCoreOutputs: raise NotImplementedError @@ -1067,53 +1070,6 @@ async def _abort_requests(self, request_ids: list[str], await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) - -class RayDPClient(DPAsyncMPClient): - """ - Ray-based client for multi-proc, multi-engine (data parallel) - EngineCore. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - client_addresses: Optional[dict[str, str]] = None, - client_index: int = 0, - ): - super().__init__(vllm_config, executor_class, log_stats, - client_addresses, client_index) - - def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool, - local_start_index: int, input_address: str, - output_address: str, - executor_class: type[Executor], log_stats: bool): - """Self-contained client mode, launch engine and coordinator process - as needed.""" - - parallel_config = vllm_config.parallel_config - assert parallel_config.data_parallel_rank == 0 - assert local_start_index == 0 - - addresses = EngineZmqAddresses( - inputs=[input_address], - outputs=[output_address], - ) - - if len(self.core_engines) > 1: - coordinator = DPCoordinator(parallel_config) - self.resources.coordinator = coordinator - addresses.coordinator_input, addresses.coordinator_output = ( - coordinator.get_engine_socket_addresses()) - - # Start all engines. - self.resources.engine_manager = (CoreEngineActorManager( - vllm_config=vllm_config, - addresses=addresses, - executor_class=executor_class, - log_stats=log_stats)) - async def _send_reconfig_message( self, reconfig_request: ReconfigureDistributedRequest, engine: EngineIdentity) -> asyncio.Future: From 07e6719225f7f0351d527bb28b14c74137b83228 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Mon, 14 Jul 2025 13:18:41 -0700 Subject: [PATCH 14/18] use middleware Signed-off-by: Rui Qiao --- vllm/entrypoints/openai/api_server.py | 58 +++++++++++++++++++-------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 4b4d6db19f2..49994c39851 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -628,11 +628,6 @@ async def create_chat_completion(request: ChatCompletionRequest, return base(raw_request).create_error_response( message="The model does not support Chat Completions API") - if raw_request.app.state.scaling: - raise HTTPException( - status_code=503, - detail="The model is currently scaling. Please try again later.") - generator = await handler.create_chat_completion(request, raw_request) if isinstance(generator, ErrorResponse): @@ -671,11 +666,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Completions API") - if raw_request.app.state.scaling: - raise HTTPException( - status_code=503, - detail="The model is currently scaling. Please try again later.") - try: generator = await handler.create_completion(request, raw_request) except OverflowError as e: @@ -712,11 +702,6 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): return base(raw_request).create_error_response( message="The model does not support Embeddings API") - if raw_request.app.state.scaling: - raise HTTPException( - status_code=503, - detail="The model is currently scaling. 
Please try again later.") - generator = await handler.create_embedding(request, raw_request) if isinstance(generator, ErrorResponse): @@ -1046,7 +1031,8 @@ async def scale(raw_request: Request): detail="drain_timeout must be a positive integer") # Set scaling flag to prevent new requests - raw_request.app.state.scaling = True + global _scaling_state + _scaling_state = True client = engine_client(raw_request) try: await client.scale(new_data_parallel_size, drain_timeout) @@ -1063,7 +1049,7 @@ async def scale(raw_request: Request): logger.error("Scale failed: %s", e) raise HTTPException(status_code=500, detail="Scale failed") from e finally: - raw_request.app.state.scaling = False + _scaling_state = False # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers @@ -1264,6 +1250,41 @@ async def send_with_request_id(message: Message) -> None: return self.app(scope, receive, send_with_request_id) +# Global variable to track scaling state +_scaling_state = False + + +class ScalingMiddleware: + """ + Middleware that checks if the model is currently scaling and + returns a 503 Service Unavailable response if it is. + + This middleware applies to all HTTP requests and prevents + processing when the model is in a scaling state. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] != "http": + return self.app(scope, receive, send) + + # Check global scaling state + global _scaling_state + if _scaling_state: + # Return 503 Service Unavailable response + response = JSONResponse(content={ + "error": + "The model is currently scaling. Please try again later." + }, + status_code=503) + return response(scope, receive, send) + + return self.app(scope, receive, send) + + def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: @@ -1452,6 +1473,9 @@ async def validation_exception_handler(_: Request, if args.enable_request_id_headers: app.add_middleware(XRequestIdMiddleware) + # Add scaling middleware to check for scaling state + app.add_middleware(ScalingMiddleware) + if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: logger.warning("CAUTION: Enabling log response in the API Server. 
" "This can include sensitive information and should be " From 0aec9461e7fac31c4de0b601e0aee071c44b760e Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Mon, 14 Jul 2025 16:38:32 -0700 Subject: [PATCH 15/18] rename Signed-off-by: Rui Qiao --- vllm/v1/engine/core.py | 2 +- vllm/v1/engine/core_client.py | 26 ++++++++++++++----------- vllm/v1/engine/utils.py | 36 +++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 28 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f465cc4127e..0a624356c28 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -866,7 +866,7 @@ def _init_data_parallel(self, vllm_config: VllmConfig): local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 - # assert 0 <= local_dp_rank <= dp_rank < dp_size + assert 0 <= local_dp_rank <= dp_rank < dp_size if vllm_config.kv_transfer_config is not None: # modify the engine_id and append the local_dp_rank to it to ensure diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index ecfaa42725a..c4b939ffb5e 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1088,10 +1088,12 @@ async def _send_reconfig_message( async def scale_up(self, new_data_parallel_size: int) -> None: """Scale up the data parallel size by creating new engine cores and reconfiguring existing ones.""" - current_dp_size = len(self.core_engines) + cur_data_parallel_size = len(self.core_engines) - if new_data_parallel_size <= current_dp_size: - return + assert new_data_parallel_size > cur_data_parallel_size, ( + f"new_data_parallel_size {new_data_parallel_size} must be greater " + f"than cur_data_parallel_size {cur_data_parallel_size} " + "for scale up") # Phase 1: Send reconfigure messages to all existing engines and wait # for them to be sent @@ -1124,7 +1126,7 @@ async def scale_up(self, new_data_parallel_size: int) -> None: # Create new CoreEngine objects for the new engines new_engine_identities = set() - for i in range(current_dp_size, new_data_parallel_size): + for i in range(cur_data_parallel_size, new_data_parallel_size): new_engine = i.to_bytes(2, "little") self.core_engines.append(new_engine) new_engine_identities.add(new_engine) @@ -1160,17 +1162,19 @@ async def scale_up(self, new_data_parallel_size: int) -> None: async def scale_down(self, new_data_parallel_size: int) -> None: """Scale down the data parallel size by shutting down and reconfiguring existing engine cores.""" - current_dp_size = len(self.core_engines) + cur_data_parallel_size = len(self.core_engines) - if new_data_parallel_size >= current_dp_size: - return + assert new_data_parallel_size <= cur_data_parallel_size, ( + f"new_data_parallel_size {new_data_parallel_size} must be less " + f"than cur_data_parallel_size {cur_data_parallel_size} " + "for scale down") # one for stateless group in EngineCore, one for worker's distributed # world group self.vllm_config.parallel_config.data_parallel_master_port += 2 reconfig_futures = [] - for old_dp_rank, engine in enumerate(self.core_engines): + for cur_dp_rank, engine in enumerate(self.core_engines): reconfig_request = ReconfigureDistributedRequest( new_data_parallel_size=new_data_parallel_size, new_data_parallel_rank=-1, # Keep original rank @@ -1179,20 +1183,20 @@ async def scale_down(self, new_data_parallel_size: int) -> None: data_parallel_master_ip, new_data_parallel_master_port=self.vllm_config.parallel_config. 
data_parallel_master_port) - if old_dp_rank >= new_data_parallel_size: + if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = -2 future = await self._send_reconfig_message(reconfig_request, engine) reconfig_futures.append(future) - for _ in range(new_data_parallel_size, current_dp_size): + for _ in range(new_data_parallel_size, cur_data_parallel_size): self.core_engines.pop() await asyncio.gather(*reconfig_futures) assert isinstance(self.resources.engine_manager, CoreEngineActorManager) - self.resources.engine_manager.scale_down(current_dp_size, + self.resources.engine_manager.scale_down(cur_data_parallel_size, new_data_parallel_size) self._ensure_stats_update_task() diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index b166691c321..cc5a5116b4f 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -403,7 +403,7 @@ def add_dp_placement_groups( return placement_groups, local_dp_ranks - def scale_up(self, old_vllm_config: VllmConfig, + def scale_up(self, cur_vllm_config: VllmConfig, new_data_parallel_size: int) -> None: import copy @@ -414,19 +414,20 @@ def scale_up(self, old_vllm_config: VllmConfig, from vllm.v1.engine.core import DPEngineCoreActor - old_data_parallel_size = len(self.local_engine_actors) + \ + cur_data_parallel_size = len(self.local_engine_actors) + \ len(self.remote_engine_actors) - assert new_data_parallel_size > old_data_parallel_size, ( - "New data parallel size must be greater than old data parallel " - "size for scale up") + assert new_data_parallel_size > cur_data_parallel_size, ( + f"New data parallel size {new_data_parallel_size} must be greater " + f"than current data parallel size {cur_data_parallel_size} " + "for scale up") placement_groups, local_dp_ranks = \ self.add_dp_placement_groups( - old_vllm_config, new_data_parallel_size) + cur_vllm_config, new_data_parallel_size) - world_size = old_vllm_config.parallel_config.world_size - dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip + world_size = cur_vllm_config.parallel_config.world_size + dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip new_local_engines = 0 runtime_env = RuntimeEnv(env_vars=self.env_vars_dict @@ -434,8 +435,8 @@ def scale_up(self, old_vllm_config: VllmConfig, for i, (pg, local_rank) in enumerate(zip(placement_groups, local_dp_ranks)): - rank = old_data_parallel_size + i - dp_vllm_config = copy.deepcopy(old_vllm_config) + rank = cur_data_parallel_size + i + dp_vllm_config = copy.deepcopy(cur_vllm_config) dp_vllm_config.parallel_config.data_parallel_size = \ new_data_parallel_size dp_vllm_config.parallel_config.placement_group = pg @@ -449,7 +450,7 @@ def scale_up(self, old_vllm_config: VllmConfig, new_local_engines += 1 # Update data_parallel_size_local dp_vllm_config.parallel_config.data_parallel_size_local = ( - old_vllm_config.parallel_config.data_parallel_size_local + + cur_vllm_config.parallel_config.data_parallel_size_local + new_local_engines) actor = ray.remote(DPEngineCoreActor).options( @@ -489,19 +490,22 @@ def scale_up(self, old_vllm_config: VllmConfig, for actor in actors: self.run_refs.append(actor.run.remote()) - old_vllm_config.parallel_config.data_parallel_size = \ + cur_vllm_config.parallel_config.data_parallel_size = \ new_data_parallel_size # Update old_vllm_config with new data_parallel_size_local if any new # local engines were added if new_local_engines > 0: - old_vllm_config.parallel_config.data_parallel_size_local += \ + 
cur_vllm_config.parallel_config.data_parallel_size_local += \ new_local_engines - def scale_down(self, old_data_parallel_size: int, + def scale_down(self, cur_data_parallel_size: int, new_data_parallel_size: int) -> None: import ray - assert old_data_parallel_size > new_data_parallel_size - for _ in range(old_data_parallel_size - new_data_parallel_size): + assert cur_data_parallel_size > new_data_parallel_size, ( + f"cur_data_parallel_size {cur_data_parallel_size} must be greater " + f"than new_data_parallel_size {new_data_parallel_size} " + "for scale down") + for _ in range(cur_data_parallel_size - new_data_parallel_size): pg = self.created_placement_groups.pop() is_local = self.placement_group_is_local.pop() if is_local: From 91799cdee694253eb205736b3d249d641663dedc Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Mon, 14 Jul 2025 17:18:25 -0700 Subject: [PATCH 16/18] msgspec for SCALE_DP Signed-off-by: Rui Qiao --- vllm/v1/engine/coordinator.py | 6 ++++-- vllm/v1/engine/core_client.py | 22 +++++++++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 99c5455a51f..ab4eba9b86d 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -201,8 +201,8 @@ def process_input_socket(self, front_publish_address: str, continue decoded = msgspec.msgpack.decode(buffer) - if isinstance(decoded, list) and len( - decoded) == 2 and decoded[0] == "SCALE_UP": + if isinstance(decoded, (list, tuple)) and len( + decoded) == 2 and decoded[0] == "SCALE_EP": # Handle scale up notification new_engine_count = decoded[1] current_count = len(self.engines) @@ -230,6 +230,8 @@ def process_input_socket(self, front_publish_address: str, "engines", current_count, new_engine_count) continue # Skip normal engine notification processing + logger.info("Received scale up notification: %s", decoded) + # We received a message on the front-end XPUB socket, # from an API server sending a new request while the # engines are paused, so that we can wake the other diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c4b939ffb5e..f4c75708e8e 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -923,15 +923,19 @@ async def run_engine_stats_update_task(): flags=zmq.NOBLOCK).result() # Check if this is a scale up notification - if len(buf) > 4 and buf[4:].startswith(b"SCALE_UP"): - # Extract new engine count from the first 4 bytes - new_engine_count = int.from_bytes( - buf[:4], "little") + decoded = msgspec.msgpack.decode(buf) + if isinstance(decoded, (list, tuple)) and len( + decoded) == 2 and decoded[0] == "SCALE_DP": + # Extract new engine count from the decoded message + new_engine_count = decoded[1] # Send scale up notification to coordinator scale_msg = msgspec.msgpack.encode( - ("SCALE_UP", new_engine_count)) + ("SCALE_DP", new_engine_count)) await socket.send(scale_msg) continue + logger.error( + "Received invalid scale up notification: %s", + decoded) # Regular request notification - send a message to # notify the coordinator that @@ -1148,8 +1152,8 @@ async def scale_up(self, new_data_parallel_size: int) -> None: # Notify coordinator about scale up through existing # stats_update_task connection self._ensure_stats_update_task() - scale_up_marker = (new_data_parallel_size).to_bytes( - 4, "little") + b"SCALE_UP" + scale_up_marker = msgspec.msgpack.encode( + ("SCALE_DP", new_data_parallel_size)) await self.first_req_send_socket.send(scale_up_marker) # Update the 
parallel config @@ -1200,8 +1204,8 @@ async def scale_down(self, new_data_parallel_size: int) -> None: new_data_parallel_size) self._ensure_stats_update_task() - scale_up_marker = (new_data_parallel_size).to_bytes( - 4, "little") + b"SCALE_UP" + scale_up_marker = msgspec.msgpack.encode( + ("SCALE_DP", new_data_parallel_size)) await self.first_req_send_socket.send(scale_up_marker) self.vllm_config.parallel_config.data_parallel_size = \ From 86bc80db003aab01a97885edb375a7fc84aa3aa8 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Mon, 14 Jul 2025 17:57:26 -0700 Subject: [PATCH 17/18] int32 Signed-off-by: Rui Qiao --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py | 4 +--- vllm/v1/engine/coordinator.py | 2 -- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3a047d5fbf8..0776111d806 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -237,7 +237,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: FusedMoEConfig): super().__init__() self.fused_experts = fused_experts # type: ignore - self.topk_indices_dtype = torch.uint32 + self.topk_indices_dtype = None self.moe = moe self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 20ad10304bb..5a23a9f1ab0 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -83,9 +83,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - # FIXME(rui): this needs to be int32, - # see https://github.com/vllm-project/vllm/pull/20166 - return torch.uint32 + return torch.int32 def num_dispatchers(self) -> int: return self.num_dispatchers_ diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index ab4eba9b86d..b10f06bf8f2 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -230,8 +230,6 @@ def process_input_socket(self, front_publish_address: str, "engines", current_count, new_engine_count) continue # Skip normal engine notification processing - logger.info("Received scale up notification: %s", decoded) - # We received a message on the front-end XPUB socket, # from an API server sending a new request while the # engines are paused, so that we can wake the other From d504cbbda6e067618971f90fced8926228232db3 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Tue, 15 Jul 2025 15:21:45 -0700 Subject: [PATCH 18/18] up Signed-off-by: Rui Qiao --- vllm/distributed/eplb/rebalance_execute.py | 13 ++++--------- vllm/engine/protocol.py | 7 +++++++ vllm/model_executor/layers/fused_moe/layer.py | 3 --- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 081c887eec4..f8a7d1170bb 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -266,6 +266,7 @@ def rearrange_expert_weights_inplace( is_profile (bool): If `True`, do not perform any actual weight copy. This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. 
+ rank_mapping: A dictionary mapping old rank to new rank. """ if rank_mapping is not None: if len(rank_mapping) == ep_group.size(): @@ -353,10 +354,7 @@ def _map_old_expert_indices_with_rank_mapping( (num_layers, new_ep_size * num_local_physical_experts). """ num_layers, old_num_physical_experts = old_global_expert_indices.shape - - if not rank_mapping: - # If no rank mapping, return the original tensor - return old_global_expert_indices + assert rank_mapping, "Rank mapping is required" # Get sizes from parameters and rank_mapping old_ep_size = len(rank_mapping) @@ -375,7 +373,7 @@ def _map_old_expert_indices_with_rank_mapping( for old_rank in range(old_ep_size): new_rank = rank_mapping.get(old_rank) if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size: - # This old rank exists in the new world + # This old rank exists in the new configuration old_start_idx = old_rank * num_local_physical_experts old_end_idx = (old_rank + 1) * num_local_physical_experts new_start_idx = new_rank * num_local_physical_experts @@ -394,10 +392,7 @@ def _map_new_expert_indices_with_rank_mapping( rank_mapping: dict[int, int], ) -> torch.Tensor: num_layers, new_num_physical_experts = new_global_expert_indices.shape - - if not rank_mapping: - # If no rank mapping, return the original tensor - return new_global_expert_indices + assert rank_mapping, "Rank mapping is required" # Get sizes from parameters and rank_mapping old_ep_size = len(rank_mapping) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 8688fcc82cd..e7f11901c4d 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -324,3 +324,10 @@ async def is_sleeping(self) -> bool: async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" ... + + @abstractmethod + async def scale(self, + new_data_parallel_size: int, + drain_timeout: int = 300) -> None: + """Scale the engine""" + ... diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0776111d806..6e5d6de6351 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -252,9 +252,6 @@ def select_gemm_impl( prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: - - # assert self.fused_experts == fused_experts - if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe)
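Taken together, the serving-side patches in this series leave a client with two things to handle when resizing: POST /scale with a new_data_parallel_size (plus a drain timeout that the new EngineClient.scale() signature defaults to 300 seconds), and 503 responses from ScalingMiddleware on every other endpoint until the resize completes. Below is a minimal client-side sketch, assuming a server started as in the example serve script and reachable on localhost:8006; the retry cadence, timeouts, and model name are illustrative assumptions, not part of the series.

import time
import requests

BASE = "http://localhost:8006"

def scale_to(new_dp_size: int) -> None:
    # The call returns once in-flight requests have drained and the engines
    # have been reconfigured to the new data parallel size.
    resp = requests.post(f"{BASE}/scale",
                         json={"new_data_parallel_size": new_dp_size},
                         timeout=600)
    resp.raise_for_status()

def completion_with_retry(prompt: str, retries: int = 30) -> dict:
    # While a resize is in flight, ScalingMiddleware answers requests with 503
    # ("The model is currently scaling. Please try again later."), so back off
    # and retry instead of failing outright.
    for _ in range(retries):
        resp = requests.post(f"{BASE}/v1/completions",
                             json={"model": "deepseek-ai/DeepSeek-V2-Lite-Chat",
                                   "prompt": prompt},  # must match the served model
                             timeout=60)
        if resp.status_code != 503:
            resp.raise_for_status()
            return resp.json()
        time.sleep(2.0)
    raise RuntimeError("server still scaling after retries")

if __name__ == "__main__":
    scale_to(2)
    print(completion_with_retry("Hello")["choices"][0]["text"])

Routing the 503 through a middleware rather than per-endpoint guards is what lets the "use middleware" patch delete the scaling checks from the chat, completion, and embedding handlers while still covering every route, including ones added later.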