Elastic Expert Parallel Initial Support #20775

Open · wants to merge 13 commits into main
11 changes: 11 additions & 0 deletions examples/online_serving/elastic_ep/bench.sh
@@ -0,0 +1,11 @@
#!/bin/bash

MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite-Chat"
HOST="localhost"
PORT=8006

vllm bench serve \
    --model $MODEL_NAME \
    --host $HOST \
    --port $PORT \
    --num-prompts 5
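
Before benchmarking, it can help to confirm the server is actually answering requests. A minimal smoke-test sketch, assuming the server started by serve_deepseek_v2.sh is listening on localhost:8006 and exposes vLLM's OpenAI-compatible /v1/chat/completions route:

import requests

# Send one short chat completion to the model served by the scripts above.
resp = requests.post(
    "http://localhost:8006/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-V2-Lite-Chat",
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 32,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])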
53 changes: 53 additions & 0 deletions examples/online_serving/elastic_ep/scale.py
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import sys

import requests


def scale(host, port, new_dp_size):
    url = f"http://{host}:{port}/scale"
    payload = {"new_data_parallel_size": new_dp_size}
    headers = {"Content-Type": "application/json"}

    print(f"Sending scale request to {url}")
    print(f"Payload: {json.dumps(payload, indent=2)}")

    try:
        response = requests.post(url, json=payload, headers=headers, timeout=300)

        print(f"Status Code: {response.status_code}")
        print(f"Response: {response.text}")

        if response.status_code == 200:
            print("Scale up/down request successful!")
            return True
        else:
            print("Scale up/down request failed!")
            return False

    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        return False


def main():
    parser = argparse.ArgumentParser(description="Test scale up/down functionality")
    parser.add_argument("--host", default="localhost", help="API server host")
    parser.add_argument("--port", type=int, default=8006, help="API server port")
    parser.add_argument(
        "--new-dp-size", type=int, default=2, help="New data parallel size"
    )

    args = parser.parse_args()

    success = scale(args.host, args.port, args.new_dp_size)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
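
For quick interactive testing, the same request the script sends can be issued directly. A sketch using the script's defaults (localhost:8006 and a target data-parallel size of 2; adjust both to your deployment):

import requests

# POST the new data-parallel size to the /scale endpoint, as scale.py does.
resp = requests.post(
    "http://localhost:8006/scale",
    json={"new_data_parallel_size": 2},
    timeout=300,
)
print(resp.status_code, resp.text)  # 200 indicates the scale request was accepted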
31 changes: 31 additions & 0 deletions examples/online_serving/elastic_ep/serve_deepseek_v2.sh
@@ -0,0 +1,31 @@
#!/bin/bash

# Serve DeepSeek V2 model with vLLM
# This script demonstrates how to serve the DeepSeek V2 model using vLLM's V1 engine

# MODEL_NAME="gaunernst/DeepSeek-V2-Lite-Chat-FP8"
MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite-Chat"
HOST="0.0.0.0"
PORT=8006

DATA_PARALLEL_SIZE=3
DATA_PARALLEL_SIZE_LOCAL=$DATA_PARALLEL_SIZE

export VLLM_USE_V1=1
export VLLM_ALL2ALL_BACKEND="pplx"
export VLLM_USE_DEEP_GEMM=1

# Launch the vLLM server
vllm serve $MODEL_NAME --trust-remote-code \
    --disable-log-requests \
    --host $HOST \
    --port $PORT \
    --tensor-parallel-size 1 \
    --enable-expert-parallel \
    --enable-eplb \
    --num-redundant-experts 32 \
    --enforce-eager \
    --data-parallel-backend ray \
    --data-parallel-size $DATA_PARALLEL_SIZE \
    --data-parallel-size-local $DATA_PARALLEL_SIZE_LOCAL \
    --data-parallel-start-rank 0
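
Once the server above is up, scale requests can be driven programmatically. A sketch of one possible flow, assuming the server is reachable on localhost:8006, vLLM's /health endpoint reports readiness, and the /scale endpoint added by this PR accepts the target sizes used here:

import time

import requests

BASE = "http://localhost:8006"

def wait_until_healthy(timeout_s: float = 600.0) -> None:
    # Poll /health until the server reports ready or the deadline passes.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(f"{BASE}/health", timeout=5).status_code == 200:
                return
        except requests.exceptions.RequestException:
            pass
        time.sleep(5)
    raise TimeoutError("server did not become healthy in time")

wait_until_healthy()
# Scale from the initial DATA_PARALLEL_SIZE=3 down to 2, then back up to 3.
for target in (2, 3):
    requests.post(f"{BASE}/scale",
                  json={"new_data_parallel_size": target},
                  timeout=300).raise_for_status()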
92 changes: 92 additions & 0 deletions tools/ep_kernels/elastic_ep/nvshmem.patch
@@ -0,0 +1,92 @@
From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001
From: Yongji Wu <wuyongji317@gmail.com>
Date: Tue, 20 May 2025 13:41:12 -0700
Subject: [PATCH] fix reinit issues due to states not cleaned up

fix double free
---
src/host/init/init.cu | 10 ++++++++++
.../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++
src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++
3 files changed, 30 insertions(+)

diff --git a/src/host/init/init.cu b/src/host/init/init.cu
index b1c5dbf..1fecb4b 100644
--- a/src/host/init/init.cu
+++ b/src/host/init/init.cu
@@ -43,6 +43,8 @@
 #include "internal/host/nvshmemi_types.h"
 #include "internal/host/shared_memory.h"
 #include "internal/host/nvshmemi_symmetric_heap.hpp"
+// eep-dev
+#include "internal/host/nvshmemi_mem_transport.hpp"
 
 extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
 static std::map<void *, int> registered_device_states;
@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) {
         /* Multi-init Multi-fini*/
         nvshmemi_state = NULL;
         nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0;
+
+        // eep-dev
+        nvshmemi_mem_p2p_transport::destroy_instance();
+        nvshmemi_mem_remote_transport::destroy_instance();
+        free(nvshmemi_default_session);
+        nvshmemi_default_session = nullptr;
+        nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false;
+
         nvshmemi_is_device_state_ready = false;
     } else
         nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle);
diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp
index 2495844..e4f408a 100644
--- a/src/include/internal/host/nvshmemi_mem_transport.hpp
+++ b/src/include/internal/host/nvshmemi_mem_transport.hpp
@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final {
             return p2p_objref_;
         }
     }
+    // eep-dev
+    static void destroy_instance(void) {
+        if (p2p_objref_ != nullptr) {
+            delete p2p_objref_;
+            p2p_objref_ = nullptr;
+        }
+    }
 
     void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj);
 
@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final {
         }
     }
 
+    // eep-dev
+    static void destroy_instance(void) {
+        if (remote_objref_ != nullptr) {
+            delete remote_objref_;
+            remote_objref_ = nullptr;
+        }
+    }
+
     int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size);
     /* On-demand registration and release of memory */
     int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx,
diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp
index a1fa748..788fa96 100644
--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp
+++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp
@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi
     // Discover the network for bootstrap, if not done previously.
     // This code needs to be stateful to be able to be called multiple times by the caller
     BOOTSTRAP_CHECK(bootstrap_net_init());
+    // eep-dev
+    if (handle->pre_init_ops != nullptr) {
+        BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
+        handle->pre_init_ops = nullptr;
+    }
     if (handle->pre_init_ops == nullptr) {
         BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1);
         handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id;
--
2.43.0
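
The patch above is meant to be applied to an NVSHMEM source tree before building the EP kernels. A hedged sketch of applying it; the NVSHMEM checkout path is an assumption, and the patch path should match wherever this repository is checked out:

import subprocess
from pathlib import Path

nvshmem_src = Path("/path/to/nvshmem_src")  # assumption: your NVSHMEM source checkout
patch_file = Path("tools/ep_kernels/elastic_ep/nvshmem.patch").resolve()

# Apply the git-style patch with -p1 so the a/ and b/ path prefixes are stripped.
subprocess.run(["patch", "-p1", "-i", str(patch_file)], cwd=nvshmem_src, check=True)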

13 changes: 13 additions & 0 deletions vllm/config.py
@@ -2025,6 +2025,19 @@ def has_unfinished_dp(dp_group: "ProcessGroup",
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: "ProcessGroup",
                                  kv_cache_memory: int) -> int:
        if kv_cache_memory == -1:
            kv_cache_memory = torch.iinfo(torch.int64).max
        tensor = torch.tensor([kv_cache_memory],
                              dtype=torch.int64,
                              device="cpu")
        # we cannot use broadcast for stateless dp group since it depends
        # on global rank
        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
        return tensor.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
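
To make the reduction semantics of sync_kv_cache_memory_size concrete: ranks that pass -1 (no measurement yet) are mapped to the int64 maximum, so the MIN all-reduce returns the smallest real measurement across the data-parallel group. A standalone sketch of that behavior (illustration only, not vLLM code; no process group involved):

import torch

def simulated_sync(per_rank_kv_cache_memory: list[int]) -> int:
    # Mirror the -1 -> int64-max mapping, then take the group-wide minimum
    # (the all-reduce with ReduceOp.MIN does this across DP ranks).
    sentinel = torch.iinfo(torch.int64).max
    return min(sentinel if v == -1 else v for v in per_rank_kv_cache_memory)

print(simulated_sync([8_000_000_000, -1, 6_000_000_000]))  # -> 6000000000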