Skip to content

Commit cbab40d

Browse files
committed
fixes
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
1 parent 5c78497 commit cbab40d

File tree

10 files changed

+42
-36
lines changed

10 files changed

+42
-36
lines changed

vllm/entrypoints/openai/api_server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,16 +1080,16 @@ async def scale(raw_request: Request):
10801080
await client.scale(new_data_parallel_size, drain_timeout)
10811081
return JSONResponse({
10821082
"message":
1083-
f"Scaled up to {new_data_parallel_size} "
1083+
f"Scaled to {new_data_parallel_size} "
10841084
"data parallel engines",
10851085
})
10861086
except TimeoutError as e:
1087-
raise HTTPException(
1088-
status_code=408,
1089-
detail="Scale up failed due to request drain timeout "
1090-
f"after {drain_timeout} seconds") from e
1087+
raise HTTPException(status_code=408,
1088+
detail="Scale failed due to request drain timeout "
1089+
f"after {drain_timeout} seconds") from e
10911090
except Exception as e:
1092-
raise HTTPException(status_code=500, detail="Scale up failed") from e
1091+
logger.error("Scale failed: %s", e)
1092+
raise HTTPException(status_code=500, detail="Scale failed") from e
10931093
finally:
10941094
raw_request.app.state.scaling = False
10951095

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
223223
def __init__(self, moe: FusedMoEConfig):
224224
super().__init__()
225225
self.fused_experts = fused_experts # type: ignore
226-
self.topk_indices_dtype = None
226+
self.topk_indices_dtype = torch.uint32
227227
self.moe = moe
228228

229229
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def max_num_tokens_per_rank(self) -> Optional[int]:
7878
return self.max_num_tokens
7979

8080
def topk_indices_dtype(self) -> Optional[torch.dtype]:
81-
return torch.int32
81+
# FIXME(rui): this needs to be int32,
82+
# see https://github.com/vllm-project/vllm/pull/20166
83+
return torch.uint32
8284

8385
def num_dispatchers(self) -> int:
8486
return self.num_dispatchers_
@@ -100,9 +102,10 @@ def prepare(
100102
hidden_dim = a1.size(-1) # K
101103

102104
assert topk_ids.size(0) == num_tokens
103-
assert expert_map is None, """with expert map, -1 id is used for
104-
non-local token; this causes error when casting ids to the
105-
topk_indices_dtype() uint32"""
105+
# FIXME(rui)
106+
# assert expert_map is None, """with expert map, -1 id is used for
107+
# non-local token; this causes error when casting ids to the
108+
# topk_indices_dtype() uint32"""
106109

107110
# Is this always going to be a1.device?
108111
device = a1.device

vllm/v1/engine/async_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ async def scale(self,
638638
Maximum time to wait for requests to drain (seconds)
639639
"""
640640
from vllm.v1.engine.core_client import RayDPClient
641-
641+
642642
if not isinstance(self.engine_core, RayDPClient):
643643
raise NotImplementedError(
644644
"Scale up/down only supported by RayDPClient")

vllm/v1/engine/coordinator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,12 @@ def process_input_socket(self, front_publish_address: str,
223223
self.engines_running = False
224224
logger.info(
225225
"DPCoordinator scaled up from %s to %s "
226-
"engines",
227-
current_count, new_engine_count)
226+
"engines", current_count, new_engine_count)
228227
else:
229228
self.engines = self.engines[:new_engine_count]
230229
logger.info(
231230
"DPCoordinator scaled down from %s to %s "
232-
"engines",
233-
current_count, new_engine_count)
231+
"engines", current_count, new_engine_count)
234232
continue # Skip normal engine notification processing
235233

236234
# We received a message on the front-end XPUB socket,

vllm/v1/engine/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ def _initialize_kv_caches(
146146
if os.environ.get("VLLM_EEP_RECONFIGURE_LAUNCH") == "1":
147147
dp_group = getattr(self, "dp_group", None)
148148
assert dp_group is not None
149-
kv_cache_memory = ParallelConfig.sync_kv_cache_memory(
150-
dp_group, -1)
149+
kv_cache_memory = ParallelConfig.sync_kv_cache_memory(dp_group, -1)
151150
available_gpu_memory = [kv_cache_memory] * len(kv_cache_specs)
152151
else:
153152
# Profiles the peak memory usage of the model to determine how much

vllm/v1/engine/core_client.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from vllm.v1.engine.coordinator import DPCoordinator
2929
from vllm.v1.engine.core import EngineCore, EngineCoreProc
3030
from vllm.v1.engine.exceptions import EngineDeadError
31-
from vllm.v1.engine.utils import (CoreEngine, CoreEngineActorManager,
31+
from vllm.v1.engine.utils import (CoreEngineActorManager,
3232
CoreEngineProcManager, EngineZmqAddresses,
3333
launch_core_engines)
3434
from vllm.v1.executor.abstract import Executor
@@ -94,6 +94,8 @@ def make_async_mp_client(
9494
# External load balancer - client per DP rank.
9595
return DPAsyncMPClient(*client_args)
9696
# Internal load balancer - client balances to all DP ranks.
97+
if parallel_config.data_parallel_backend == "ray":
98+
return RayDPClient(*client_args)
9799
return DPLBAsyncMPClient(*client_args)
98100
return AsyncMPClient(*client_args)
99101

@@ -1115,7 +1117,7 @@ def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
11151117

11161118
async def _send_reconfig_message(
11171119
self, reconfig_request: ReconfigureDistributedRequest,
1118-
engine: CoreEngine) -> asyncio.Future:
1120+
engine: EngineIdentity) -> asyncio.Future:
11191121
"""Send reconfiguration message and return the result future without
11201122
waiting for completion."""
11211123
call_id = uuid.uuid1().int >> 64
@@ -1160,17 +1162,17 @@ async def scale_up(self, new_data_parallel_size: int) -> None:
11601162
# Phase 2: Create new engines now that reconfig messages have been sent
11611163
# self.resources.engine_manager is guaranteed to be
11621164
# CoreEngineActorManager for RayDPClient
1163-
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
1165+
assert isinstance(self.resources.engine_manager,
1166+
CoreEngineActorManager)
11641167
self.resources.engine_manager.scale_up(self.vllm_config,
11651168
new_data_parallel_size)
11661169

11671170
# Create new CoreEngine objects for the new engines
11681171
new_engine_identities = set()
11691172
for i in range(current_dp_size, new_data_parallel_size):
1170-
# TODO(yongji): check if the engine is local
1171-
new_engine = CoreEngine(index=i, local=False)
1173+
new_engine = i.to_bytes(2, "little")
11721174
self.core_engines.append(new_engine)
1173-
new_engine_identities.add(new_engine.identity)
1175+
new_engine_identities.add(new_engine)
11741176

11751177
# Wait for ready messages from new engines on the input socket
11761178
sync_input_socket = zmq.Socket.shadow(self.input_socket)
@@ -1233,7 +1235,8 @@ async def scale_down(self, new_data_parallel_size: int) -> None:
12331235

12341236
await asyncio.gather(*reconfig_futures)
12351237

1236-
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
1238+
assert isinstance(self.resources.engine_manager,
1239+
CoreEngineActorManager)
12371240
self.resources.engine_manager.scale_down(current_dp_size,
12381241
new_data_parallel_size)
12391242

vllm/v1/engine/utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(
221221
dp_vllm_config = copy.deepcopy(vllm_config)
222222
pg = placement_groups[index]
223223
dp_vllm_config.parallel_config.placement_group = pg
224-
on_head_node = index < local_engine_count
224+
local_client = index < local_engine_count
225225
actor = ray.remote(DPEngineCoreActor).options(
226226
scheduling_strategy=PlacementGroupSchedulingStrategy(
227227
placement_group=pg,
@@ -230,15 +230,15 @@ def __init__(
230230
runtime_env=runtime_env).remote(vllm_config=dp_vllm_config,
231231
executor_class=executor_class,
232232
log_stats=log_stats,
233-
on_head_node=on_head_node,
233+
local_client=local_client,
234234
addresses=addresses,
235235
dp_rank=index,
236236
local_dp_rank=local_index)
237-
if on_head_node:
237+
if local_client:
238238
self.local_engine_actors.append(actor)
239239
else:
240240
self.remote_engine_actors.append(actor)
241-
self.placement_group_is_local.append(on_head_node)
241+
self.placement_group_is_local.append(local_client)
242242
refs.append(actor.wait_for_init.remote())
243243

244244
ray.get(refs)
@@ -435,11 +435,11 @@ def scale_up(self, old_vllm_config: VllmConfig,
435435
dp_vllm_config.parallel_config.placement_group = pg
436436

437437
# Check if this placement group is on the head node
438-
on_head_node = any(
438+
local_client = any(
439439
bundle.get("node:" + dp_master_ip, 0) > 0
440440
for bundle in pg.bundle_specs)
441441

442-
if on_head_node:
442+
if local_client:
443443
new_local_engines += 1
444444
# Update data_parallel_size_local
445445
dp_vllm_config.parallel_config.data_parallel_size_local = (
@@ -455,17 +455,17 @@ def scale_up(self, old_vllm_config: VllmConfig,
455455
vllm_config=dp_vllm_config,
456456
executor_class=self.executor_class,
457457
log_stats=self.log_stats,
458-
on_head_node=on_head_node,
458+
local_client=local_client,
459459
addresses=self.addresses,
460460
dp_rank=rank,
461461
local_dp_rank=local_rank)
462462

463-
if on_head_node:
463+
if local_client:
464464
self.local_engine_actors.append(actor)
465465
else:
466466
self.remote_engine_actors.append(actor)
467467
self.created_placement_groups.append(pg)
468-
self.placement_group_is_local.append(on_head_node)
468+
self.placement_group_is_local.append(local_client)
469469

470470
ray.get([
471471
actor.wait_for_init.remote()

vllm/v1/worker/cpu_model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def replace_tensor(obj: Any, cpu_attr_name: str,
5050
if k.endswith("_cpu") and isinstance(v, torch.Tensor):
5151
replace_tensor(self.input_batch.block_table, k, k[:-4])
5252

53-
def load_model(self) -> None:
53+
def load_model(self, reconfigure: bool = False) -> None:
5454
logger.info("Starting to load model %s...", self.model_config.model)
5555
self.model = get_model(vllm_config=self.vllm_config)
5656

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ def reinitialize_distributed(
372372
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
373373
for old_ep_rank in range(old_ep_size)
374374
}
375+
assert self.model_runner.eplb_state is not None
375376
self.model_runner.eplb_state.rearrange(self.model_runner.model,
376377
execute_shuffle=True,
377378
global_expert_load=None,
@@ -427,6 +428,7 @@ def reinitialize_distributed(
427428
module.moe_config.moe_parallel_config = module.moe_parallel_config
428429
if new_ep_size < old_ep_size:
429430
num_local_physical_experts = num_local_experts
431+
assert self.model_runner.eplb_state is not None
430432
new_physical_experts = \
431433
self.model_runner.eplb_state.physical_to_logical_map.shape[1]
432434
parallel_config.num_redundant_experts = (
@@ -441,6 +443,7 @@ def reinitialize_distributed(
441443
group_src=0)
442444
num_local_physical_experts = num_local_physical_experts.item()
443445
new_physical_experts = num_local_physical_experts * new_ep_size
446+
assert self.model_runner.eplb_state is not None
444447
global_expert_load = self.model_runner.eplb_state.rearrange(
445448
self.model_runner.model, execute_shuffle=False)
446449
parallel_config.num_redundant_experts = (
@@ -457,14 +460,14 @@ def reinitialize_distributed(
457460
old_ep_rank: old_ep_rank
458461
for old_ep_rank in range(old_ep_size)
459462
}
463+
assert self.model_runner.eplb_state is not None
460464
self.model_runner.eplb_state.rearrange(
461465
self.model_runner.model,
462466
execute_shuffle=True,
463467
global_expert_load=global_expert_load,
464468
rank_mapping=rank_mapping)
465469
if get_ep_group().rank == 0:
466470
logger.info("[Elastic EP] Expert resharding completed!")
467-
self.model_runner.eplb_state.expert_rearrangement_step = 0
468471

469472
def save_sharded_state(
470473
self,

0 commit comments

Comments (0)