diff --git a/examples/online_serving/elastic_ep/bench.sh b/examples/online_serving/elastic_ep/bench.sh new file mode 100644 index 00000000000..c81a2eab07a --- /dev/null +++ b/examples/online_serving/elastic_ep/bench.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite-Chat" +HOST="localhost" +PORT=8006 + +vllm bench serve \ + --model $MODEL_NAME \ + --host $HOST \ + --port $PORT \ + --num-prompts 5 diff --git a/examples/online_serving/elastic_ep/scale.py b/examples/online_serving/elastic_ep/scale.py new file mode 100644 index 00000000000..ae9c1e85eb9 --- /dev/null +++ b/examples/online_serving/elastic_ep/scale.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import json +import sys + +import requests + + +def scale(host, port, new_dp_size): + url = f"http://{host}:{port}/scale" + payload = {"new_data_parallel_size": new_dp_size} + headers = {"Content-Type": "application/json"} + + print(f"Sending scale request to {url}") + print(f"Payload: {json.dumps(payload, indent=2)}") + + try: + response = requests.post(url, json=payload, headers=headers, timeout=300) + + print(f"Status Code: {response.status_code}") + print(f"Response: {response.text}") + + if response.status_code == 200: + print("Scale up/down request successful!") + return True + else: + print("Scale up/down request failed!") + return False + + except requests.exceptions.RequestException as e: + print(f"Request failed: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Test scale up/down functionality") + parser.add_argument("--host", default="localhost", help="API server host") + parser.add_argument("--port", type=int, default=8006, help="API server port") + parser.add_argument( + "--new-dp-size", type=int, default=2, help="New data parallel size" + ) + + args = parser.parse_args() + + success = scale(args.host, args.port, args.new_dp_size) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/elastic_ep/serve_deepseek_v2.sh b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh new file mode 100644 index 00000000000..8f1d7cf6bb9 --- /dev/null +++ b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# Serve DeepSeek V2 model with vLLM +# This script demonstrates how to serve the DeepSeek V2 model using vLLM's V1 engine + +# MODEL_NAME="gaunernst/DeepSeek-V2-Lite-Chat-FP8" +MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite-Chat" +HOST="0.0.0.0" +PORT=8006 + +DATA_PARALLEL_SIZE=3 +DATA_PARALLEL_SIZE_LOCAL=$DATA_PARALLEL_SIZE + +export VLLM_USE_V1=1 +export VLLM_ALL2ALL_BACKEND="pplx" +export VLLM_USE_DEEP_GEMM=1 + +# Launch the vLLM server +vllm serve $MODEL_NAME --trust-remote-code \ + --disable-log-requests \ + --host $HOST \ + --port $PORT \ + --tensor-parallel-size 1 \ + --enable-expert-parallel \ + --enable-eplb \ + --num-redundant-experts 32 \ + --enforce-eager \ + --data-parallel-backend ray \ + --data-parallel-size $DATA_PARALLEL_SIZE \ + --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \ + --data-parallel-start-rank 0 \ No newline at end of file diff --git a/tools/ep_kernels/elastic_ep/nvshmem.patch b/tools/ep_kernels/elastic_ep/nvshmem.patch new file mode 100644 index 00000000000..5ebdaea58dd --- /dev/null +++ b/tools/ep_kernels/elastic_ep/nvshmem.patch @@ -0,0 +1,92 @@ +From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001 +From: 
Yongji Wu +Date: Tue, 20 May 2025 13:41:12 -0700 +Subject: [PATCH] fix reinit issues due to states not cleaned up + +fix double free +--- + src/host/init/init.cu | 10 ++++++++++ + .../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++ + src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++ + 3 files changed, 30 insertions(+) + +diff --git a/src/host/init/init.cu b/src/host/init/init.cu +index b1c5dbf..1fecb4b 100644 +--- a/src/host/init/init.cu ++++ b/src/host/init/init.cu +@@ -43,6 +43,8 @@ + #include "internal/host/nvshmemi_types.h" + #include "internal/host/shared_memory.h" + #include "internal/host/nvshmemi_symmetric_heap.hpp" ++// eep-dev ++#include "internal/host/nvshmemi_mem_transport.hpp" + + extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d; + static std::map registered_device_states; +@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) { + /* Multi-init Multi-fini*/ + nvshmemi_state = NULL; + nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0; ++ ++ // eep-dev ++ nvshmemi_mem_p2p_transport::destroy_instance(); ++ nvshmemi_mem_remote_transport::destroy_instance(); ++ free(nvshmemi_default_session); ++ nvshmemi_default_session = nullptr; ++ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false; ++ + nvshmemi_is_device_state_ready = false; + } else + nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle); +diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp +index 2495844..e4f408a 100644 +--- a/src/include/internal/host/nvshmemi_mem_transport.hpp ++++ b/src/include/internal/host/nvshmemi_mem_transport.hpp +@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final { + return p2p_objref_; + } + } ++ // eep-dev ++ static void destroy_instance(void) { ++ if (p2p_objref_ != nullptr) { ++ delete p2p_objref_; ++ p2p_objref_ = nullptr; ++ } ++ } + + void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj); + +@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final { + } + } + ++ // eep-dev ++ static void destroy_instance(void) { ++ if (remote_objref_ != nullptr) { ++ delete remote_objref_; ++ remote_objref_ = nullptr; ++ } ++ } ++ + int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size); + /* On-demand registration and release of memory */ + int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx, +diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp +index a1fa748..788fa96 100644 +--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp ++++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp +@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi + // Discover the network for bootstrap, if not done previously. 
+ // This code needs to be stateful to be able to be called multiple times by the caller + BOOTSTRAP_CHECK(bootstrap_net_init()); ++ // eep-dev ++ if (handle->pre_init_ops != nullptr) { ++ BOOTSTRAP_PTR_FREE(handle->pre_init_ops); ++ handle->pre_init_ops = nullptr; ++ } + if (handle->pre_init_ops == nullptr) { + BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1); + handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id; +-- +2.43.0 + diff --git a/vllm/config.py b/vllm/config.py index 6c56ac1eec8..46408ae7ff6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2001,6 +2001,19 @@ def has_unfinished_dp(dp_group: "ProcessGroup", aggregated_has_unfinished = bool(tensor.item()) return aggregated_has_unfinished + @staticmethod + def sync_kv_cache_memory_size(dp_group: "ProcessGroup", + kv_cache_memory: int) -> int: + if kv_cache_memory == -1: + kv_cache_memory = torch.iinfo(torch.int64).max + tensor = torch.tensor([kv_cache_memory], + dtype=torch.int64, + device="cpu") + # we cannot use broadcast for stateless dp group since it depends + # on global rank + torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) + return tensor.item() + def compute_hash(self): """ Provide a hash that uniquely identifies all the configs diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 6b0a126ca9b..b88deb9a84d 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -29,12 +29,15 @@ import time from collections.abc import Sequence from dataclasses import dataclass +from typing import Optional, Union import torch -from torch.distributed import all_gather, all_reduce +from torch.distributed import ProcessGroup, all_gather, all_reduce from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_ep_group, get_node_count +from vllm.distributed.parallel_state import (get_ep_group, get_node_count, + in_the_same_node_as) +from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -172,6 +175,9 @@ def build( model: MixtureOfExperts, device: torch.device, parallel_config: ParallelConfig, + global_expert_load: Optional[torch.Tensor] = None, + old_global_expert_indices: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None, ) -> "EplbState": """ Build the initial EPLB state. 
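The `ParallelConfig.sync_kv_cache_memory_size` helper added in the `vllm/config.py` hunk above relies on a simple sentinel convention: a rank that does not yet know its KV-cache budget passes `-1`, which is mapped to the int64 maximum so that a MIN all-reduce over the DP process group returns the smallest real measurement. A minimal single-process sketch of that convention (the `simulate_*` helper below is illustrative only, not part of this patch):

```python
import torch

def simulate_sync_kv_cache_memory_size(per_rank_values: list[int]) -> int:
    # -1 means "unknown"; map it to int64 max so it can never win the MIN reduction.
    sentinel = torch.iinfo(torch.int64).max
    tensor = torch.tensor(
        [sentinel if v == -1 else v for v in per_rank_values], dtype=torch.int64)
    # In the patch this is torch.distributed.all_reduce(tensor, op=ReduceOp.MIN,
    # group=dp_group); a local min() gives the same result for illustration.
    return int(tensor.min().item())

# A newly launched engine core (which passes -1) picks up the 8 GiB budget
# already measured by the existing engines.
print(simulate_sync_kv_cache_memory_size([-1, 8 * 2**30, 9 * 2**30]))
```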
@@ -185,8 +191,12 @@ def build( physical_to_logical_map_list, device=device, ) + # TODO(yongji): hard-wired to make sure the tensor does not get resized + MAX_PHYSICAL_EXPERT_FACTOR = 2 + max_slots_per_logical_expert = (model.num_logical_experts * + MAX_PHYSICAL_EXPERT_FACTOR) logical_to_physical_map = torch.full( - (model.num_logical_experts, model.num_redundant_experts + 1), + (model.num_logical_experts, max_slots_per_logical_expert), -1, device=device, ) @@ -235,11 +245,63 @@ def build( expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) + if global_expert_load is not None: + ep_group = get_ep_group().device_group + assert global_expert_load.shape == (model.num_moe_layers, + model.num_logical_experts) + assert global_expert_load.dtype == torch.int64 + + num_replicas = model.num_physical_experts + num_groups = model.num_expert_groups + num_nodes = get_node_count() + num_gpus = ep_group.size() + + if num_gpus % num_nodes != 0: + num_nodes = 1 + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert max_physical_slots <= logical_to_physical_map.shape[-1] + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + (0, logical_to_physical_map.shape[-1] - max_physical_slots), + value=-1, + ) + physical_to_logical_map = new_physical_to_logical_map.to(device) + logical_to_physical_map.copy_(new_logical_to_physical_map) + logical_replica_count.copy_(new_logical_replica_count) + model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) + if global_expert_load is not None: + rearrange_expert_weights_inplace( + old_global_expert_indices, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + False, + rank_mapping, + ) + expert_rearrangement_step = 0 return cls( physical_to_logical_map, @@ -337,7 +399,10 @@ def step(self, def rearrange(self, model: MixtureOfExperts, - is_profile: bool = False) -> None: + is_profile: bool = False, + execute_shuffle: bool = True, + global_expert_load: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None) -> None: """ Rearrange the experts according to the current load. 
""" @@ -353,42 +418,79 @@ def rearrange(self, logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") - # This mapping is only used here, so we do not store it in the state - physical_expert_start = ep_rank * model.num_local_physical_experts - physical_expert_end = (physical_expert_start + - model.num_local_physical_experts) - # (num_moe_layers, num_local_physical_experts) - local_physical_to_logical_map = self.physical_to_logical_map[ - :, - physical_expert_start:physical_expert_end, - ] + if global_expert_load is None: + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] - # Map the local physical expert load to global logical experts - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - model.num_moe_layers, - model.num_logical_experts, - dtype=self.expert_load_window.dtype, - device=self.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=local_physical_to_logical_map.unsqueeze(0).expand_as( - self.expert_load_window).long(), - src=self.expert_load_window, - ) + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) - # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = logical_expert_load_window.sum(dim=0) - all_reduce(global_expert_load_window, group=ep_group) + if not execute_shuffle: + metadata = torch.tensor( + [ + model.num_moe_layers, model.num_logical_experts, + self.physical_to_logical_map.shape[1] + ], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.broadcast(metadata, + group=get_ep_group().cpu_group, + group_src=0) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) + + if not execute_shuffle: + # (num_moe_layers, old_num_physical_experts) + old_global_expert_indices = self.physical_to_logical_map + torch.distributed.broadcast(old_global_expert_indices, + group=ep_group, + group_src=0) + return global_expert_load_window + else: + assert execute_shuffle + global_expert_load_window = global_expert_load # TODO(bowen): Treat differently for prefill and decode nodes num_replicas = model.num_physical_experts num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() + if rank_mapping is not None and len(rank_mapping) == ep_group.size(): + # NOTE(yongji): scale down, we need to rebalance the experts on + # remaining GPUs, transfer the experts while we haven't shutdown + # the GPUs to be released. 
+ cpu_group = get_ep_group().cpu_group + num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) + num_gpus = sum(new_rank != -1 + for new_rank in rank_mapping.values()) + num_replicas = num_replicas // ep_group.size( + ) * num_gpus # handle num replicas change + else: + num_nodes = get_node_count() + num_gpus = ep_group.size() if num_gpus % num_nodes != 0: + self.num_nodes = 1 logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" @@ -414,10 +516,24 @@ def rearrange(self, model.expert_weights, ep_group, is_profile, + rank_mapping, ) if not is_profile: - self.physical_to_logical_map.copy_(new_physical_to_logical_map) + if self.physical_to_logical_map.shape[ + 1] != new_physical_to_logical_map.shape[1]: + self.physical_to_logical_map = new_physical_to_logical_map.to( + self.physical_to_logical_map.device) + else: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert max_physical_slots <= self.logical_to_physical_map.shape[-1] + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + (0, + self.logical_to_physical_map.shape[-1] - max_physical_slots), + value=-1, + ) self.logical_to_physical_map.copy_(new_logical_to_physical_map) self.logical_replica_count.copy_(new_logical_replica_count) @@ -430,3 +546,69 @@ def rearrange(self, " (profile) " if is_profile else " ", time_end - time_start, ) + + @staticmethod + def recv_state() -> tuple[torch.Tensor, torch.Tensor]: + """ + Receive the expert load and old placement from the master rank. + """ + ep_group = get_ep_group() + metadata = torch.empty(3, dtype=torch.int32, device="cpu") + torch.distributed.broadcast(metadata, + group=ep_group.cpu_group, + group_src=0) + num_moe_layers, num_logical_experts, num_old_physical_experts = ( + metadata.tolist()) + global_expert_load = torch.zeros( + (num_moe_layers, num_logical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + all_reduce(global_expert_load, group=ep_group.device_group) + old_global_expert_indices = torch.empty( + (num_moe_layers, num_old_physical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + torch.distributed.broadcast(old_global_expert_indices, + group=ep_group.device_group, + group_src=0) + + return global_expert_load, old_global_expert_indices + + +def _node_count_with_rank_mapping( + pg: Union[ProcessGroup, StatelessProcessGroup], + rank_mapping: dict[int, int], +) -> int: + if isinstance(pg, ProcessGroup): + world_size = torch.distributed.get_world_size(group=pg) + else: + world_size = pg.world_size + + if world_size == 1: + return 1 + + # Build node assignment map + node_assignment = [0] * world_size # rank -> node_id + next_node_id = 0 + + for current_rank in range(world_size): + if node_assignment[current_rank] != 0: + continue # Already assigned to a node + + assert current_rank in rank_mapping + if rank_mapping[current_rank] == -1: + continue # Pending shutdown + + # Assign current rank to a new node + next_node_id += 1 + node_assignment[current_rank] = next_node_id + + # Find all ranks on the same node as current_rank + same_node_flags = in_the_same_node_as(pg, current_rank) + for other_rank, is_same_node in enumerate(same_node_flags): + if is_same_node and node_assignment[other_rank] == 0: + node_assignment[other_rank] = next_node_id + + return next_node_id diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 
2ef8587b559..f8a7d1170bb 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -8,6 +8,7 @@ from collections.abc import Iterable, MutableSequence, Sequence from functools import partial +from typing import Optional import torch from torch.distributed import (P2POp, ProcessGroup, all_gather, @@ -127,6 +128,8 @@ def shuffle_layer( dst_global = local2global(dst) if is_received_locally[dst]: continue + if old_indices[src_global] == -1 or new_indices[dst_global] == -1: + continue if old_indices[src_global] == new_indices[dst_global]: is_received_locally[dst] = True for weight, buffer in zip(expert_weights, @@ -139,6 +142,8 @@ def shuffle_layer( experts_send_loc: dict[int, int] = {} for src in range(num_local_experts): expert = old_indices[local2global(src)] + if expert == -1: + continue if expert in experts_send_loc: continue experts_send_loc[expert] = src @@ -181,6 +186,8 @@ def shuffle_layer( if is_received_locally[dst]: continue expert = new_indices[local2global(dst)] + if expert == -1: + continue if expert in experts_recv_loc: continue experts_recv_loc[expert] = dst @@ -227,6 +234,8 @@ def shuffle_layer( weight[dst].copy_(buffer[dst]) else: expert = new_indices[local2global(dst)] + if expert == -1: + continue src = experts_recv_loc[expert] for weight, buffer in zip(expert_weights, expert_weights_buffer): weight[dst].copy_(buffer[src]) @@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace( expert_weights: Sequence[Iterable[torch.Tensor]], ep_group: ProcessGroup, is_profile: bool = False, + rank_mapping: Optional[dict[int, int]] = None, ) -> None: """ Rearranges the expert weights in place according to the new expert indices. @@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace( is_profile (bool): If `True`, do not perform any actual weight copy. This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. + rank_mapping: A dictionary mapping old rank to new rank. """ + if rank_mapping is not None: + if len(rank_mapping) == ep_group.size(): + # scale down + new_global_expert_indices = \ + _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices, + rank_mapping, + ) + else: + # scale up + old_global_expert_indices = \ + _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices, + rank_mapping, + ep_group.size(), + ) + + assert old_global_expert_indices.shape[ + 1] == new_global_expert_indices.shape[1] + num_moe_layers, num_physical_experts = old_global_expert_indices.shape assert len(expert_weights) == num_moe_layers @@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace( ) +def _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices: torch.Tensor, + rank_mapping: dict[int, int], + new_ep_size: int, +) -> torch.Tensor: + """ + Map the old global expert indices to the new global expert indices. + + Args: + old_global_expert_indices: + Shape (num_layers, old_ep_size * num_local_physical_experts). + rank_mapping: Mapping from old rank to new rank. + new_ep_size: New expert parallelism size. + + Returns: + Mapped expert indices with shape + (num_layers, new_ep_size * num_local_physical_experts). 
+ """ + num_layers, old_num_physical_experts = old_global_expert_indices.shape + assert rank_mapping, "Rank mapping is required" + + # Get sizes from parameters and rank_mapping + old_ep_size = len(rank_mapping) + num_local_physical_experts = old_num_physical_experts // old_ep_size + new_num_physical_experts = new_ep_size * num_local_physical_experts + + # Create mapped tensor with new shape, initialized to -1 + mapped_expert_indices = torch.full( + (num_layers, new_num_physical_experts), + fill_value=-1, + dtype=old_global_expert_indices.dtype, + device=old_global_expert_indices.device, + ) + + # Handle rank mapping (scale up/down with rank changes) + for old_rank in range(old_ep_size): + new_rank = rank_mapping.get(old_rank) + if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size: + # This old rank exists in the new configuration + old_start_idx = old_rank * num_local_physical_experts + old_end_idx = (old_rank + 1) * num_local_physical_experts + new_start_idx = new_rank * num_local_physical_experts + new_end_idx = (new_rank + 1) * num_local_physical_experts + + mapped_expert_indices[:, new_start_idx:new_end_idx] = \ + old_global_expert_indices[:, old_start_idx:old_end_idx] + # If new_rank is None or >= new_ep_size, the experts remain -1 + # (scale down case) + + return mapped_expert_indices + + +def _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices: torch.Tensor, + rank_mapping: dict[int, int], +) -> torch.Tensor: + num_layers, new_num_physical_experts = new_global_expert_indices.shape + assert rank_mapping, "Rank mapping is required" + + # Get sizes from parameters and rank_mapping + old_ep_size = len(rank_mapping) + new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values()) + num_local_physical_experts = new_num_physical_experts // new_ep_size + old_num_physical_experts = old_ep_size * num_local_physical_experts + + mapped_expert_indices = torch.full( + (num_layers, old_num_physical_experts), + fill_value=-1, + dtype=new_global_expert_indices.dtype, + device=new_global_expert_indices.device, + ) + + for old_rank in range(old_ep_size): + new_rank = rank_mapping[old_rank] + if new_rank >= 0 and new_rank < new_ep_size: + old_start_idx = old_rank * num_local_physical_experts + old_end_idx = (old_rank + 1) * num_local_physical_experts + new_start_idx = new_rank * num_local_physical_experts + new_end_idx = (new_rank + 1) * num_local_physical_experts + + mapped_expert_indices[:, old_start_idx:old_end_idx] = \ + new_global_expert_indices[:, new_start_idx:new_end_idx] + + return mapped_expert_indices + + __all__ = ["rearrange_expert_weights_inplace"] diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 8688fcc82cd..e7f11901c4d 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -324,3 +324,10 @@ async def is_sleeping(self) -> bool: async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" ... + + @abstractmethod + async def scale(self, + new_data_parallel_size: int, + drain_timeout: int = 300) -> None: + """Scale the engine""" + ... 
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 65ceeff8eb4..17235afadbc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1005,6 +1005,53 @@ async def is_sleeping(raw_request: Request): return JSONResponse(content={"is_sleeping": is_sleeping}) +@router.post("/scale", dependencies=[Depends(validate_json_request)]) +async def scale(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, + detail="Invalid JSON format") from e # noqa: B904 + + new_data_parallel_size = body.get("new_data_parallel_size") + drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes + + if new_data_parallel_size is None: + raise HTTPException(status_code=400, + detail="new_data_parallel_size is required") + + if not isinstance(new_data_parallel_size, + int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, + detail="new_data_parallel_size must be a positive integer") + + if not isinstance(drain_timeout, int) or drain_timeout <= 0: + raise HTTPException(status_code=400, + detail="drain_timeout must be a positive integer") + + # Set scaling flag to prevent new requests + global _scaling_state + _scaling_state = True + client = engine_client(raw_request) + try: + await client.scale(new_data_parallel_size, drain_timeout) + return JSONResponse({ + "message": + f"Scaled to {new_data_parallel_size} " + "data parallel engines", + }) + except TimeoutError as e: + raise HTTPException(status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds") from e + except Exception as e: + logger.error("Scale failed: %s", e) + raise HTTPException(status_code=500, detail="Scale failed") from e + finally: + _scaling_state = False + + # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) RequestType = Any @@ -1203,6 +1250,41 @@ async def send_with_request_id(message: Message) -> None: return self.app(scope, receive, send_with_request_id) +# Global variable to track scaling state +_scaling_state = False + + +class ScalingMiddleware: + """ + Middleware that checks if the model is currently scaling and + returns a 503 Service Unavailable response if it is. + + This middleware applies to all HTTP requests and prevents + processing when the model is in a scaling state. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] != "http": + return self.app(scope, receive, send) + + # Check global scaling state + global _scaling_state + if _scaling_state: + # Return 503 Service Unavailable response + response = JSONResponse(content={ + "error": + "The model is currently scaling. Please try again later." + }, + status_code=503) + return response(scope, receive, send) + + return self.app(scope, receive, send) + + def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: @@ -1391,6 +1473,9 @@ async def validation_exception_handler(_: Request, if args.enable_request_id_headers: app.add_middleware(XRequestIdMiddleware) + # Add scaling middleware to check for scaling state + app.add_middleware(ScalingMiddleware) + if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: logger.warning("CAUTION: Enabling log response in the API Server. 
" "This can include sensitive information and should be " @@ -1600,6 +1685,7 @@ async def init_app_state( state.enable_server_load_tracking = args.enable_server_load_tracking state.server_load_metrics = 0 + state.scaling = False def create_server_socket(addr: tuple[str, int]) -> socket.socket: diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7ebeb4a2255..4401db38178 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, run_method) +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -62,6 +63,13 @@ def check_health(self) -> None: # it's running. return + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self.driver_worker.reinitialize_distributed(reconfig_request) + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + return + UniProcExecutorAsync = UniProcExecutor diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index da772c11155..6e5d6de6351 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -252,9 +252,6 @@ def select_gemm_impl( prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: - - assert self.fused_experts == fused_experts - if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) @@ -362,8 +359,10 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) return self.forward( x=x, @@ -380,7 +379,12 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def forward_cuda( self, @@ -399,6 +403,10 @@ def forward_cuda( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( @@ -412,7 +420,12 @@ def forward_cuda( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts( @@ -753,7 +766,8 @@ def __init__( if self.enable_eplb: from 
vllm.model_executor.layers.quantization.fp8 import ( Fp8MoEMethod) - if not isinstance(quant_method, Fp8MoEMethod): + if not isinstance(quant_method, + (Fp8MoEMethod, UnquantizedFusedMoEMethod)): # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API @@ -837,6 +851,14 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + def update_expert_map(self): + # ep_size and ep_rank should already be updated + with self.expert_map.device: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8d36dda65b5..5106b9914b5 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -776,6 +776,24 @@ def set_eplb_state( logical_replica_count=logical_replica_count, ) + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, DeepseekV2MoE): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -931,9 +949,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): + if (hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx+i}."): diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 92ecb8972d5..136c1e11e7b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -588,6 +588,13 @@ def set_eplb_state( """ ... + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + ... + def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: return isinstance(model, MixtureOfExperts) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 921ccd708cd..6021055f6e4 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -177,3 +177,12 @@ class EngineCoreRequestType(enum.Enum): UTILITY = b'\x03' # Sentinel used within EngineCoreProc. 
EXECUTOR_FAILED = b'\x04' + + +class ReconfigureDistributedRequest(msgspec.Struct): + new_data_parallel_size: int + new_data_parallel_rank: int # -1 means keep current rank + new_data_parallel_rank_local: int # -1 means keep current local rank + # for NCCL/GLOO initialization + new_data_parallel_master_ip: str + new_data_parallel_master_port: int \ No newline at end of file diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaa..7d33ff335d2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import time from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Any, Optional, Union @@ -132,6 +133,7 @@ def __init__( for stat_logger in self.stat_loggers[0]: stat_logger.log_engine_initialized() self.output_handler: Optional[asyncio.Task] = None + self.scaling = False try: # Start output handler eagerly if we are in the asyncio eventloop. asyncio.get_running_loop() @@ -608,6 +610,68 @@ async def collective_rpc(self, return await self.engine_core.collective_rpc_async( method, timeout, args, kwargs) + async def wait_for_requests_to_drain(self, drain_timeout: int = 300): + """Wait for all requests to be drained.""" + start_time = time.time() + while time.time() - start_time < drain_timeout: + if not self.engine_core.dp_engines_running(): + logger.info("Engines are idle, requests have been drained") + return + + logger.info( + "Engines are still running, waiting for requests to drain...") + await asyncio.sleep(1) # Wait 1 second before checking again + + raise TimeoutError(f"Timeout reached after {drain_timeout} seconds " + "waiting for requests to drain.") + + async def scale(self, + new_data_parallel_size: int, + drain_timeout: int = 300): + """ + Scale up or down the data parallel size by adding or removing + engine cores. + Args: + new_data_parallel_size: The new number of data parallel workers + drain_timeout: + Maximum time to wait for requests to drain (seconds) + """ + self.scaling = True + old_data_parallel_size = \ + self.vllm_config.parallel_config.data_parallel_size + try: + logger.info( + "Waiting for requests to drain before " + "scaling up to %s engines...", new_data_parallel_size) + await self.wait_for_requests_to_drain(drain_timeout) + logger.info( + "Requests have been drained, proceeding with scale " + "to %s engines", new_data_parallel_size) + if new_data_parallel_size > old_data_parallel_size: + await self.engine_core.scale_up(new_data_parallel_size) + else: + await self.engine_core.scale_down(new_data_parallel_size) + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + + # recreate stat loggers + if new_data_parallel_size > old_data_parallel_size: + stat_loggers: list[ + list[StatLoggerBase]] = setup_default_loggers( + vllm_config=self.vllm_config, + log_stats=self.log_stats, + engine_num=new_data_parallel_size, + custom_stat_loggers=None, + ) + num_new_engines = len(stat_loggers) - len(self.stat_loggers) + self.stat_loggers.extend(stat_loggers[-num_new_engines:]) + else: + for _ in range(old_data_parallel_size - + new_data_parallel_size): + self.stat_loggers.pop() + finally: + self.scaling = False + @property def is_running(self) -> bool: # Is None before the loop is started. 
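The scale path notifies the DP coordinator over the same front-end XPUB socket already used for the "request received while engines are paused" wake-up messages (see the `coordinator.py` and `core_client.py` changes below). Both payloads are two-element msgpack tuples, so they are distinguished by the type of the first element. A small standalone sketch of that framing, using the `"SCALE_DP"` marker sent by the DP client in this patch:

```python
import msgspec

# Regular wake-up notification: (engine_to_exclude, wave).
regular = msgspec.msgpack.encode((3, 7))
# Scale notification: ("SCALE_DP", new_engine_count).
scale = msgspec.msgpack.encode(("SCALE_DP", 5))

for buf in (regular, scale):
    decoded = msgspec.msgpack.decode(buf)
    is_scale = (isinstance(decoded, (list, tuple)) and len(decoded) == 2
                and decoded[0] == "SCALE_DP")
    print(decoded, "->", "scale notification" if is_scale else "request notification")
```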
diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py
index b3e7a2e85b8..b10f06bf8f2 100644
--- a/vllm/v1/engine/coordinator.py
+++ b/vllm/v1/engine/coordinator.py
@@ -200,11 +200,41 @@ def process_input_socket(self, front_publish_address: str,
                         # Ignore subscription messages.
                         continue
 
+                    decoded = msgspec.msgpack.decode(buffer)
+                    if isinstance(decoded, (list, tuple)) and len(
+                            decoded) == 2 and decoded[0] == "SCALE_DP":
+                        # Handle scale up/down notification
+                        new_engine_count = decoded[1]
+                        current_count = len(self.engines)
+                        if new_engine_count > current_count:
+                            for _ in range(new_engine_count - current_count):
+                                self.engines.append(EngineState())
+                            # NOTE(yongji): handle the case where newly
+                            # started engines have current_wave = 0.
+                            # If the existing engines just finished a wave
+                            # and engines_running isn't updated yet in
+                            # CoordinatorProc, requests routed to the newly
+                            # started engines may not wake up the existing
+                            # engines, as long as
+                            # 0 < request.wave < existing engines' current_wave.
+                            # Note that 0 is the wave number for the
+                            # new engines.
+                            self.engines_running = False
+                            logger.info(
+                                "DPCoordinator scaled up from %s to %s "
+                                "engines", current_count, new_engine_count)
+                        else:
+                            self.engines = self.engines[:new_engine_count]
+                            logger.info(
+                                "DPCoordinator scaled down from %s to %s "
+                                "engines", current_count, new_engine_count)
+                        continue  # Skip normal engine notification processing
+
                     # We received a message on the front-end XPUB socket,
                     # from an API server sending a new request while the
                     # engines are paused, so that we can wake the other
                     # engines.
-                    engine_to_exclude, wave = msgspec.msgpack.decode(buffer)
+                    engine_to_exclude, wave = decoded
                     if not self.engines_running:
                         if wave < self.current_wave:
                             # If the wave number is stale, ensure the message
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f5c59bef478..61efcc22cc8 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -32,7 +32,8 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType, UtilityOutput)
+                            EngineCoreRequestType,
+                            ReconfigureDistributedRequest, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.v1.executor.abstract import Executor
@@ -77,6 +78,8 @@ def __init__(self,
         self.model_executor.register_failure_callback(
             executor_fail_callback)
 
+        self.available_gpu_memory_for_kv_cache = -1
+
         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
             self._initialize_kv_caches(vllm_config)
@@ -137,12 +140,23 @@ def _initialize_kv_caches(
         # Get all kv cache needed by the model
         kv_cache_specs = self.model_executor.get_kv_cache_specs()
 
-        # Profiles the peak memory usage of the model to determine how much
-        # memory can be allocated for kv cache.
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) if has_kv_cache: - available_gpu_memory = \ - self.model_executor.determine_available_memory() + if os.environ.get("VLLM_EEP_SCALE_UP_LAUNCH") == "1": + dp_group = getattr(self, "dp_group", None) + assert dp_group is not None + self.available_gpu_memory_for_kv_cache = \ + ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) + available_gpu_memory = [ + self.available_gpu_memory_for_kv_cache + ] * len(kv_cache_specs) + else: + # Profiles the peak memory usage of the model to determine how + # much memory can be allocated for kv cache. + available_gpu_memory = ( + self.model_executor.determine_available_memory()) + self.available_gpu_memory_for_kv_cache = \ + available_gpu_memory[0] else: # Attention free models don't need memory for kv cache available_gpu_memory = [0] * len(kv_cache_specs) @@ -983,6 +997,48 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + stateless_destroy_torch_distributed_process_group(self.dp_group) + self.shutdown() + + parallel_config = self.vllm_config.parallel_config + old_dp_size = parallel_config.data_parallel_size + parallel_config.data_parallel_size = \ + reconfig_request.new_data_parallel_size + if reconfig_request.new_data_parallel_rank != -1: + parallel_config.data_parallel_rank = \ + reconfig_request.new_data_parallel_rank + # local rank specifies device visibility, it should not be changed + assert reconfig_request.new_data_parallel_rank_local == -1 + parallel_config.data_parallel_master_ip = \ + reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port + if reconfig_request.new_data_parallel_rank != -2: + self.dp_rank = parallel_config.data_parallel_rank + self.dp_group = parallel_config.stateless_init_dp_group() + reconfig_request.new_data_parallel_master_port = \ + parallel_config.data_parallel_master_port + + self.model_executor.reinitialize_distributed(reconfig_request) + if reconfig_request.new_data_parallel_size > old_dp_size: + assert self.available_gpu_memory_for_kv_cache > 0 + # pass available_gpu_memory_for_kv_cache from existing + # engine-cores to new engine-cores so they can directly + # use it in _initialize_kv_caches() rather than profiling. 
+ ParallelConfig.sync_kv_cache_memory_size( + self.dp_group, self.available_gpu_memory_for_kv_cache) + # NOTE(yongji): newly joined workers require dummy_run even + # CUDA graph is not used + self.model_executor.collective_rpc("compile_or_warm_up_model") + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) + else: + logger.info("Distributed environment reinitialized for DP rank %s", + self.dp_rank) + class DPEngineCoreActor(DPEngineCoreProc): """ diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index dafaa15f777..f4c75708e8e 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -23,7 +23,8 @@ from vllm.lora.request import LoRARequest from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, UtilityOutput) + EngineCoreRequestType, + ReconfigureDistributedRequest, UtilityOutput) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError @@ -162,6 +163,12 @@ def dp_engines_running(self) -> bool: running state.""" raise NotImplementedError + async def scale_up(self, new_data_parallel_size: int) -> None: + raise NotImplementedError + + async def scale_down(self, new_data_parallel_size: int) -> None: + raise NotImplementedError + async def get_output_async(self) -> EngineCoreOutputs: raise NotImplementedError @@ -910,13 +917,32 @@ async def run_engine_stats_update_task(): events = await poller.poll() if not self.engines_running and len(events) == 2 or ( events[0][0] == first_req_rcv_socket): - # Send a message to notify the coordinator that + # Check if this is a regular request notification or + # scale up notification + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + + # Check if this is a scale up notification + decoded = msgspec.msgpack.decode(buf) + if isinstance(decoded, (list, tuple)) and len( + decoded) == 2 and decoded[0] == "SCALE_DP": + # Extract new engine count from the decoded message + new_engine_count = decoded[1] + # Send scale up notification to coordinator + scale_msg = msgspec.msgpack.encode( + ("SCALE_DP", new_engine_count)) + await socket.send(scale_msg) + continue + logger.error( + "Received invalid scale up notification: %s", + decoded) + + # Regular request notification - send a message to + # notify the coordinator that # we're sending a request while the engines are # paused, so that it can wake the others up # (to run dummy EP loop). 
self.engines_running = True - buf = first_req_rcv_socket.recv( - flags=zmq.NOBLOCK).result() target_eng_index = int.from_bytes(buf, "little") msg = msgspec.msgpack.encode( (target_eng_index, self.current_wave)) @@ -1047,3 +1073,143 @@ async def _abort_requests(self, request_ids: list[str], engine: EngineIdentity) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) + + async def _send_reconfig_message( + self, reconfig_request: ReconfigureDistributedRequest, + engine: EngineIdentity) -> asyncio.Future: + """Send reconfiguration message and return the result future without + waiting for completion.""" + call_id = uuid.uuid1().int >> 64 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( + (self.client_index, call_id, "reinitialize_distributed", + (reconfig_request, )))) + await self._send_input_message(message, engine, reconfig_request) + self._ensure_output_queue_task() + return future + + async def scale_up(self, new_data_parallel_size: int) -> None: + """Scale up the data parallel size by creating new engine cores + and reconfiguring existing ones.""" + cur_data_parallel_size = len(self.core_engines) + + assert new_data_parallel_size > cur_data_parallel_size, ( + f"new_data_parallel_size {new_data_parallel_size} must be greater " + f"than cur_data_parallel_size {cur_data_parallel_size} " + "for scale up") + + # Phase 1: Send reconfigure messages to all existing engines and wait + # for them to be sent + reconfig_futures = [] + # one for stateless group in EngineCore, one for worker's distributed + # world group + self.vllm_config.parallel_config.data_parallel_master_port += 2 + for engine in self.core_engines: + reconfig_request = ReconfigureDistributedRequest( + new_data_parallel_size=new_data_parallel_size, + new_data_parallel_rank=-1, # Keep original rank + new_data_parallel_rank_local=-1, # Keep original local rank + new_data_parallel_master_ip=self.vllm_config.parallel_config. + data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config. 
+ data_parallel_master_port) + future = await self._send_reconfig_message(reconfig_request, + engine) + reconfig_futures.append(future) + + logger.info("All reconfigure messages sent, starting engine creation") + + # Phase 2: Create new engines now that reconfig messages have been sent + # self.resources.engine_manager is guaranteed to be + # CoreEngineActorManager for RayDPClient + assert isinstance(self.resources.engine_manager, + CoreEngineActorManager) + self.resources.engine_manager.scale_up(self.vllm_config, + new_data_parallel_size) + + # Create new CoreEngine objects for the new engines + new_engine_identities = set() + for i in range(cur_data_parallel_size, new_data_parallel_size): + new_engine = i.to_bytes(2, "little") + self.core_engines.append(new_engine) + new_engine_identities.add(new_engine) + + # Wait for ready messages from new engines on the input socket + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while new_engine_identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError( + "Timed out waiting for new engines to send initial " + "message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + new_engine_identities.discard(identity) + + # Phase 3: Wait for all existing engines to complete reconfiguration + logger.info("Waiting for existing engines to complete reconfiguration") + await asyncio.gather(*reconfig_futures) + + # Notify coordinator about scale up through existing + # stats_update_task connection + self._ensure_stats_update_task() + scale_up_marker = msgspec.msgpack.encode( + ("SCALE_DP", new_data_parallel_size)) + await self.first_req_send_socket.send(scale_up_marker) + + # Update the parallel config + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + logger.info( + "[Elastic EP] Scale up completed, new data parallel size: %s", + new_data_parallel_size) + + async def scale_down(self, new_data_parallel_size: int) -> None: + """Scale down the data parallel size by shutting down and + reconfiguring existing engine cores.""" + cur_data_parallel_size = len(self.core_engines) + + assert new_data_parallel_size <= cur_data_parallel_size, ( + f"new_data_parallel_size {new_data_parallel_size} must be less " + f"than cur_data_parallel_size {cur_data_parallel_size} " + "for scale down") + + # one for stateless group in EngineCore, one for worker's distributed + # world group + self.vllm_config.parallel_config.data_parallel_master_port += 2 + + reconfig_futures = [] + for cur_dp_rank, engine in enumerate(self.core_engines): + reconfig_request = ReconfigureDistributedRequest( + new_data_parallel_size=new_data_parallel_size, + new_data_parallel_rank=-1, # Keep original rank + new_data_parallel_rank_local=-1, # Keep original local rank + new_data_parallel_master_ip=self.vllm_config.parallel_config. + data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config. 
+ data_parallel_master_port) + if cur_dp_rank >= new_data_parallel_size: + reconfig_request.new_data_parallel_rank = -2 + future = await self._send_reconfig_message(reconfig_request, + engine) + reconfig_futures.append(future) + + for _ in range(new_data_parallel_size, cur_data_parallel_size): + self.core_engines.pop() + + await asyncio.gather(*reconfig_futures) + + assert isinstance(self.resources.engine_manager, + CoreEngineActorManager) + self.resources.engine_manager.scale_down(cur_data_parallel_size, + new_data_parallel_size) + + self._ensure_stats_update_task() + scale_up_marker = msgspec.msgpack.encode( + ("SCALE_DP", new_data_parallel_size)) + await self.first_req_send_socket.send(scale_up_marker) + + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + logger.info( + "[Elastic EP] Scale down completed, new data parallel size: %s", + new_data_parallel_size) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index ae104bd6eb9..cc5a5116b4f 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -174,16 +174,21 @@ def __init__( self.local_engine_actors: list[ray.ActorHandle] = [] self.remote_engine_actors: list[ray.ActorHandle] = [] + + env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") + self.env_vars_dict = { + name: os.environ[name] + for name in env_vars_list if name in os.environ + } + runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) + + self.addresses = addresses + self.executor_class = executor_class + self.log_stats = log_stats dp_size = vllm_config.parallel_config.data_parallel_size local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size - env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor") - env_vars_dict = { - name: os.environ[name] - for name in env_vars_set if name in os.environ - } - runtime_env = RuntimeEnv(env_vars=env_vars_dict) if ray.is_initialized(): logger.info( @@ -208,6 +213,7 @@ def __init__( assert len(placement_groups) == dp_size, ( "Number of placement groups must match data parallel size") + self.placement_group_is_local = [] refs = [] for index in range(dp_size): local_index = local_dp_ranks[index] @@ -231,6 +237,7 @@ def __init__( self.local_engine_actors.append(actor) else: self.remote_engine_actors.append(actor) + self.placement_group_is_local.append(local_client) refs.append(actor.wait_for_init.remote()) ray.get(refs) @@ -242,6 +249,9 @@ def __init__( def create_dp_placement_groups( vllm_config: VllmConfig ) -> tuple[list["PlacementGroup"], list[int]]: + """ + Create placement groups for data parallel. 
+ """ import ray from ray._private.state import available_resources_per_node @@ -250,10 +260,11 @@ def create_dp_placement_groups( logger.info("Creating placement groups for data parallel") dp_master_ip = \ vllm_config.parallel_config.data_parallel_master_ip - dp_size = vllm_config.parallel_config.data_parallel_size + num_pg_to_create = vllm_config.parallel_config.data_parallel_size local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local + nodes = list_nodes() nodes = sorted(list_nodes(), key=lambda node: node.node_ip != dp_master_ip) assert nodes[0].node_ip == dp_master_ip, ( @@ -293,7 +304,7 @@ def create_dp_placement_groups( local_dp_ranks.append(i) else: for i in range(available_engine_count): - if len(placement_groups) == dp_size: + if len(placement_groups) == num_pg_to_create: break bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( @@ -305,6 +316,204 @@ def create_dp_placement_groups( local_dp_ranks.append(i) return placement_groups, local_dp_ranks + @staticmethod + def add_dp_placement_groups( + old_vllm_config: VllmConfig, new_data_parallel_size: int + ) -> tuple[list["PlacementGroup"], list[int]]: + """ + Add placement groups for new data parallel size. + """ + import ray + from ray._private.state import (available_resources_per_node, + total_resources_per_node) + from ray.util.state import list_nodes + + old_dp_size = old_vllm_config.parallel_config.data_parallel_size + num_pg_to_create = new_data_parallel_size - old_dp_size + + if num_pg_to_create <= 0: + return [], [] + + dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip + world_size = old_vllm_config.parallel_config.world_size + + nodes = list_nodes() + nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip) + assert nodes[0].node_ip == dp_master_ip, ( + "The first node must be the head node") + assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( + "There can only be one head node") + + available_resources = available_resources_per_node() + total_resources = total_resources_per_node() + + placement_groups = [] + local_dp_ranks = [] + num_pg_created = 0 + + for node in nodes: + if num_pg_created >= num_pg_to_create: + break + + node_ip = node.node_ip + node_id = node.node_id + available_gpus = int(available_resources[node_id]["GPU"]) + + # Get total GPUs on this node from the node's resources + # Ray stores node resources with node ID as key + total_gpus = int(total_resources[node_id]["GPU"]) + + # Calculate used GPUs and used engines on this node + used_gpus = max(0, total_gpus - available_gpus) + used_engines_on_node = used_gpus // world_size + + # Calculate how many new engines this node can accommodate + available_engine_count = available_gpus // world_size + + # Create placement groups for new engines on this node + for i in range(available_engine_count): + if num_pg_created >= num_pg_to_create: + break + + rank = old_dp_size + num_pg_created + + # Create bundles with node constraint for master node + if node_ip == dp_master_ip: + bundles = [{ + "GPU": 1.0, + "node:" + dp_master_ip: 0.001 + }] * world_size + [{ + "CPU": 1.0 + }] + else: + bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + + pg = ray.util.placement_group( + name=f"dp_rank_{rank}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + + # Local rank starts from the number of engines already used + # on this node + local_rank = used_engines_on_node + i + local_dp_ranks.append(local_rank) + num_pg_created += 1 + + return 
placement_groups, local_dp_ranks + + def scale_up(self, cur_vllm_config: VllmConfig, + new_data_parallel_size: int) -> None: + import copy + + import ray + from ray.runtime_env import RuntimeEnv + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + + from vllm.v1.engine.core import DPEngineCoreActor + + cur_data_parallel_size = len(self.local_engine_actors) + \ + len(self.remote_engine_actors) + + assert new_data_parallel_size > cur_data_parallel_size, ( + f"New data parallel size {new_data_parallel_size} must be greater " + f"than current data parallel size {cur_data_parallel_size} " + "for scale up") + + placement_groups, local_dp_ranks = \ + self.add_dp_placement_groups( + cur_vllm_config, new_data_parallel_size) + + world_size = cur_vllm_config.parallel_config.world_size + dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip + new_local_engines = 0 + + runtime_env = RuntimeEnv(env_vars=self.env_vars_dict + | {"VLLM_EEP_SCALE_UP_LAUNCH": "1"}) + for i, (pg, + local_rank) in enumerate(zip(placement_groups, + local_dp_ranks)): + rank = cur_data_parallel_size + i + dp_vllm_config = copy.deepcopy(cur_vllm_config) + dp_vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + dp_vllm_config.parallel_config.placement_group = pg + + # Check if this placement group is on the head node + local_client = any( + bundle.get("node:" + dp_master_ip, 0) > 0 + for bundle in pg.bundle_specs) + + if local_client: + new_local_engines += 1 + # Update data_parallel_size_local + dp_vllm_config.parallel_config.data_parallel_size_local = ( + cur_vllm_config.parallel_config.data_parallel_size_local + + new_local_engines) + + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env).remote( + vllm_config=dp_vllm_config, + executor_class=self.executor_class, + log_stats=self.log_stats, + local_client=local_client, + addresses=self.addresses, + dp_rank=rank, + local_dp_rank=local_rank) + + if local_client: + self.local_engine_actors.append(actor) + else: + self.remote_engine_actors.append(actor) + self.created_placement_groups.append(pg) + self.placement_group_is_local.append(local_client) + + ray.get([ + actor.wait_for_init.remote() + for actor in (self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 else []) + + self.remote_engine_actors[-(len(placement_groups) - + new_local_engines):] + ]) + + actors = (self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 else []) + \ + self.remote_engine_actors[-(len(placement_groups) - + new_local_engines):] + + for actor in actors: + self.run_refs.append(actor.run.remote()) + + cur_vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + # Update old_vllm_config with new data_parallel_size_local if any new + # local engines were added + if new_local_engines > 0: + cur_vllm_config.parallel_config.data_parallel_size_local += \ + new_local_engines + + def scale_down(self, cur_data_parallel_size: int, + new_data_parallel_size: int) -> None: + import ray + assert cur_data_parallel_size > new_data_parallel_size, ( + f"cur_data_parallel_size {cur_data_parallel_size} must be greater " + f"than new_data_parallel_size {new_data_parallel_size} " + "for scale down") + for _ in range(cur_data_parallel_size - new_data_parallel_size): + pg = self.created_placement_groups.pop() + is_local = 
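
# Illustrative sketch (separate from the patch above): scale_up decides whether a
# new engine is a "local" (head-node) engine by looking for the Ray custom resource
# "node:<master_ip>" in the placement group's bundles. Plain dicts stand in for
# Ray's bundle_specs here; the helper name is ours.
def is_head_node_pg(bundle_specs, dp_master_ip):
    return any(bundle.get("node:" + dp_master_ip, 0) > 0 for bundle in bundle_specs)

head_bundles = [{"GPU": 1.0, "node:10.0.0.1": 0.001}, {"CPU": 1.0}]
remote_bundles = [{"GPU": 1.0}, {"CPU": 1.0}]
assert is_head_node_pg(head_bundles, "10.0.0.1")
assert not is_head_node_pg(remote_bundles, "10.0.0.1")
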
self.placement_group_is_local.pop() + if is_local: + self.local_engine_actors.pop() + else: + self.remote_engine_actors.pop() + ray.util.remove_placement_group(pg) + def get_run_refs(self): return self.run_refs diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py index daca7c0faf6..be8edf3e3cb 100644 --- a/vllm/v1/executor/ray_distributed_executor.py +++ b/vllm/v1/executor/ray_distributed_executor.py @@ -6,6 +6,7 @@ from vllm.executor.ray_distributed_executor import ( # noqa RayDistributedExecutor as RayDistributedExecutorV0) +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.executor.abstract import Executor from vllm.v1.outputs import ModelRunnerOutput @@ -62,3 +63,10 @@ def execute_model( # When PP is used, we return a FutureWrapper immediately so that # the scheduler can yield to the next batch. return FutureWrapper(refs[0]) + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self._run_workers("reinitialize_distributed", reconfig_request) + if reconfig_request.new_data_parallel_rank == -2: + self.shutdown() + return diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 410a54e7466..87f28661adc 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -50,7 +50,7 @@ def replace_tensor(obj: Any, cpu_attr_name: str, if k.endswith("_cpu") and isinstance(v, torch.Tensor): replace_tensor(self.input_batch.block_table, k, k[:-4]) - def load_model(self) -> None: + def load_model(self, reconfigure: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) self.model = get_model(vllm_config=self.vllm_config) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af216539c90..fde815cca9c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1738,8 +1738,40 @@ def update_config(self, overrides: dict[str, Any]) -> None: new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) - def load_model(self) -> None: + def load_model(self, eep_scale_up: bool = False) -> None: + """ + Args: + eep_scale_up: the model loading is for elastic EP scale up. 
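
# Illustrative sketch (separate from the patch above): the reconfigure request
# overloads new_data_parallel_rank as a sentinel, as the surrounding code shows:
# -1 keeps the worker's current rank and -2 marks a worker that falls outside the
# shrunken DP group and must shut down after resharding. The enum below is purely
# illustrative; vLLM passes the raw integers around.
from enum import Enum

class RankAction(Enum):
    KEEP_CURRENT = "keep the current data-parallel rank"
    SHUTDOWN = "scaled out of the new DP group; shut down"
    REASSIGN = "adopt the explicitly provided new rank"

def interpret_new_dp_rank(new_dp_rank: int) -> RankAction:
    if new_dp_rank == -1:
        return RankAction.KEEP_CURRENT
    if new_dp_rank == -2:
        return RankAction.SHUTDOWN
    return RankAction.REASSIGN

assert interpret_new_dp_rank(-2) is RankAction.SHUTDOWN
assert interpret_new_dp_rank(1) is RankAction.REASSIGN
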
+ """ logger.info("Starting to load model %s...", self.model_config.model) + if eep_scale_up: + from vllm.distributed.parallel_state import get_ep_group + num_local_physical_experts = torch.empty(1, + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = int(num_local_physical_experts.item()) + new_ep_size = get_ep_group().world_size + global_expert_load, old_global_expert_indices = ( + EplbState.recv_state()) + num_logical_experts = global_expert_load.shape[1] + self.parallel_config.num_redundant_experts = ( + num_local_physical_experts * new_ep_size - num_logical_experts) + assert old_global_expert_indices.shape[ + 1] % num_local_physical_experts == 0 + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + else: + global_expert_load = None + old_global_expert_indices = None + rank_mapping = None + with DeviceMemoryProfiler() as m: # noqa: SIM117 time_before_load = time.perf_counter() model_loader = get_model_loader(self.load_config) @@ -1783,6 +1815,9 @@ def load_model(self) -> None: self.model, self.device, self.parallel_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, ) def save_tensorized_model( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6458b55777a..a02c6a0cef1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -25,6 +25,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling +from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import report_usage_stats @@ -190,8 +191,9 @@ def load_model(self) -> None: else: from contextlib import nullcontext context = nullcontext() + eep_scale_up = os.environ.get("VLLM_EEP_SCALE_UP_LAUNCH") == "1" with context: - self.model_runner.load_model() + self.model_runner.load_model(eep_scale_up=eep_scale_up) def update_config(self, overrides: dict[str, Any]) -> None: self.model_runner.update_config(overrides) @@ -380,6 +382,125 @@ def check_health(self) -> None: # worker will always be healthy as long as it's running. 
return + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + from vllm.config import set_current_vllm_config + from vllm.distributed.parallel_state import ( + cleanup_dist_env_and_memory, get_dp_group, get_ep_group, + prepare_communication_buffer_for_model) + from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoEParallelConfig) + + old_ep_size = get_ep_group().world_size + old_ep_rank = get_ep_group().rank + new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( + ).world_size * get_pp_group().world_size + if new_ep_size < old_ep_size: + # scale down + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding " + "before scaling down...") + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange(self.model_runner.model, + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping) + torch.cuda.synchronize() + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + cleanup_dist_env_and_memory() + + if reconfig_request.new_data_parallel_rank == -2: + assert old_ep_rank >= new_ep_size + # shutdown + return + + # Update parallel config with provided reconfig_request + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = \ + reconfig_request.new_data_parallel_size + # Only update rank if new value is provided (-1 means keep current) + if reconfig_request.new_data_parallel_rank != -1: + parallel_config.data_parallel_rank = \ + reconfig_request.new_data_parallel_rank + if reconfig_request.new_data_parallel_rank_local != -1: + parallel_config.data_parallel_rank_local = \ + reconfig_request.new_data_parallel_rank_local + parallel_config.data_parallel_master_ip = \ + reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port + + with set_current_vllm_config(self.vllm_config): + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank) + moe_modules = [ + module for module in self.model_runner.model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all(module.moe_config.num_local_experts == num_local_experts + for module in moe_modules), ( + "All MoE modules must have the same number of experts") + new_ep_size = get_ep_group().world_size + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + if new_ep_size < old_ep_size: + num_local_physical_experts = num_local_experts + assert self.model_runner.eplb_state is not None + new_physical_experts = \ + self.model_runner.eplb_state.physical_to_logical_map.shape[1] + parallel_config.num_redundant_experts = ( + new_physical_experts - + self.model_runner.eplb_state.logical_replica_count.shape[1]) + else: + num_local_physical_experts = torch.tensor([num_local_experts], + dtype=torch.int32, + device="cpu") + 
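
# Illustrative sketch (separate from the patch above): the two rank_mapping shapes
# built in reinitialize_distributed. When scaling down, surviving EP ranks map to
# themselves and removed ranks map to -1 so their experts are reshuffled away
# before teardown; when scaling up, every old rank keeps its index. Standalone
# illustration only.
def build_rank_mapping(old_ep_size: int, new_ep_size: int) -> dict:
    if new_ep_size < old_ep_size:
        return {r: (r if r < new_ep_size else -1) for r in range(old_ep_size)}
    return {r: r for r in range(old_ep_size)}

assert build_rank_mapping(4, 2) == {0: 0, 1: 1, 2: -1, 3: -1}
assert build_rank_mapping(2, 4) == {0: 0, 1: 1}
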
torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = num_local_physical_experts.item() + new_physical_experts = num_local_physical_experts * new_ep_size + assert self.model_runner.eplb_state is not None + global_expert_load = self.model_runner.eplb_state.rearrange( + self.model_runner.model, execute_shuffle=False) + parallel_config.num_redundant_experts = ( + new_physical_experts - global_expert_load.shape[1]) + prepare_communication_buffer_for_model(self.model_runner.model) + self.model_runner.model.update_physical_experts_metadata( + num_physical_experts=new_physical_experts, + num_local_physical_experts=num_local_physical_experts) + if new_ep_size > old_ep_size: + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding " + "after scaling up...") + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange( + self.model_runner.model, + execute_shuffle=True, + global_expert_load=global_expert_load, + rank_mapping=rank_mapping) + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + def save_sharded_state( self, path: str,
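
# Illustrative sketch (separate from the patch above): the scalar broadcast pattern
# used to share num_local_physical_experts across the expert-parallel CPU group.
# This standalone version spins up a one-process gloo group so it can run anywhere;
# in vLLM the group is get_ep_group().cpu_group with rank 0 as the source, and the
# patch passes the group_src= keyword rather than the classic src= shown here.
import os
import torch
import torch.distributed as dist

def broadcast_scalar(value_on_src: int, src: int = 0) -> int:
    payload = value_on_src if dist.get_rank() == src else 0
    t = torch.tensor([payload], dtype=torch.int32)
    dist.broadcast(t, src=src)
    return int(t.item())

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29555")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(broadcast_scalar(24))  # every rank ends up with the source's value
    dist.destroy_process_group()
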