Skip to content

Commit a6d795d

Browse files
authored
[DP] Copy environment variables to Ray DPEngineCoreActors (#20344)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
1 parent a37d75b commit a6d795d

File tree

3 files changed

+93
-35
lines changed

3 files changed

+93
-35
lines changed

vllm/executor/ray_distributed_executor.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import asyncio
5-
import json
65
import os
76
from collections import defaultdict
87
from dataclasses import dataclass
@@ -20,6 +19,7 @@
2019
from vllm.logger import init_logger
2120
from vllm.model_executor.layers.sampler import SamplerOutput
2221
from vllm.platforms import current_platform
22+
from vllm.ray.ray_env import get_env_vars_to_copy
2323
from vllm.sequence import ExecuteModelRequest
2424
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
2525
get_ip, get_open_port, make_async)
@@ -58,17 +58,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
5858
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
5959
}
6060

61-
config_home = envs.VLLM_CONFIG_ROOT
62-
# This file contains a list of env vars that should not be copied
63-
# from the driver to the Ray workers.
64-
non_carry_over_env_vars_file = os.path.join(
65-
config_home, "ray_non_carry_over_env_vars.json")
66-
if os.path.exists(non_carry_over_env_vars_file):
67-
with open(non_carry_over_env_vars_file) as f:
68-
non_carry_over_env_vars = set(json.load(f))
69-
else:
70-
non_carry_over_env_vars = set()
71-
7261
uses_ray: bool = True
7362

7463
def _init_executor(self) -> None:
@@ -335,13 +324,10 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
335324
} for (node_id, _) in worker_node_and_gpu_ids]
336325

337326
# Environment variables to copy from driver to workers
338-
env_vars_to_copy = [
339-
v for v in envs.environment_variables
340-
if v not in self.WORKER_SPECIFIC_ENV_VARS
341-
and v not in self.non_carry_over_env_vars
342-
]
343-
344-
env_vars_to_copy.extend(current_platform.additional_env_vars)
327+
env_vars_to_copy = get_env_vars_to_copy(
328+
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
329+
additional_vars=set(current_platform.additional_env_vars),
330+
destination="workers")
345331

346332
# Copy existing env vars to each worker's args
347333
for args in all_args_to_update_environment_variables:
@@ -350,15 +336,6 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
350336
if name in os.environ:
351337
args[name] = os.environ[name]
352338

353-
logger.info("non_carry_over_env_vars from config: %s",
354-
self.non_carry_over_env_vars)
355-
logger.info(
356-
"Copying the following environment variables to workers: %s",
357-
[v for v in env_vars_to_copy if v in os.environ])
358-
logger.info(
359-
"If certain env vars should NOT be copied to workers, add them to "
360-
"%s file", self.non_carry_over_env_vars_file)
361-
362339
self._env_vars_for_all_workers = (
363340
all_args_to_update_environment_variables)
364341

vllm/ray/ray_env.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import json
4+
import os
5+
from typing import Optional
6+
7+
import vllm.envs as envs
8+
from vllm.logger import init_logger
9+
10+
logger = init_logger(__name__)
11+
12+
CONFIG_HOME = envs.VLLM_CONFIG_ROOT
13+
14+
# This file contains a list of env vars that should not be copied
15+
# from the driver to the Ray workers.
16+
RAY_NON_CARRY_OVER_ENV_VARS_FILE = os.path.join(
17+
CONFIG_HOME, "ray_non_carry_over_env_vars.json")
18+
19+
try:
20+
if os.path.exists(RAY_NON_CARRY_OVER_ENV_VARS_FILE):
21+
with open(RAY_NON_CARRY_OVER_ENV_VARS_FILE) as f:
22+
RAY_NON_CARRY_OVER_ENV_VARS = set(json.load(f))
23+
else:
24+
RAY_NON_CARRY_OVER_ENV_VARS = set()
25+
except json.JSONDecodeError:
26+
logger.warning(
27+
"Failed to parse %s. Using an empty set for non-carry-over env vars.",
28+
RAY_NON_CARRY_OVER_ENV_VARS_FILE)
29+
RAY_NON_CARRY_OVER_ENV_VARS = set()
30+
31+
32+
def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None,
33+
additional_vars: Optional[set[str]] = None,
34+
destination: Optional[str] = None) -> set[str]:
35+
"""
36+
Get the environment variables to copy to downstream Ray actors.
37+
38+
Example use cases:
39+
- Copy environment variables from RayDistributedExecutor to Ray workers.
40+
- Copy environment variables from RayDPClient to Ray DPEngineCoreActor.
41+
42+
Args:
43+
exclude_vars: A set of vllm defined environment variables to exclude
44+
from copying.
45+
additional_vars: A set of additional environment variables to copy.
46+
destination: The destination of the environment variables.
47+
Returns:
48+
A set of environment variables to copy.
49+
"""
50+
exclude_vars = exclude_vars or set()
51+
additional_vars = additional_vars or set()
52+
53+
env_vars_to_copy = {
54+
v
55+
for v in envs.environment_variables
56+
if v not in exclude_vars and v not in RAY_NON_CARRY_OVER_ENV_VARS
57+
}
58+
env_vars_to_copy.update(additional_vars)
59+
60+
to_destination = " to " + destination if destination is not None else ""
61+
62+
logger.info("RAY_NON_CARRY_OVER_ENV_VARS from config: %s",
63+
RAY_NON_CARRY_OVER_ENV_VARS)
64+
logger.info("Copying the following environment variables%s: %s",
65+
to_destination,
66+
[v for v in env_vars_to_copy if v in os.environ])
67+
logger.info(
68+
"If certain env vars should NOT be copied, add them to "
69+
"%s file", RAY_NON_CARRY_OVER_ENV_VARS_FILE)
70+
71+
return env_vars_to_copy

vllm/v1/engine/utils.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import contextlib
5+
import os
56
import weakref
67
from collections.abc import Iterator
78
from dataclasses import dataclass
@@ -15,6 +16,7 @@
1516

1617
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
1718
from vllm.logger import init_logger
19+
from vllm.ray.ray_env import get_env_vars_to_copy
1820
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
1921
from vllm.v1.engine.coordinator import DPCoordinator
2022
from vllm.v1.executor.abstract import Executor
@@ -164,6 +166,7 @@ def __init__(
164166
import copy
165167

166168
import ray
169+
from ray.runtime_env import RuntimeEnv
167170
from ray.util.scheduling_strategies import (
168171
PlacementGroupSchedulingStrategy)
169172

@@ -175,6 +178,12 @@ def __init__(
175178
local_engine_count = \
176179
vllm_config.parallel_config.data_parallel_size_local
177180
world_size = vllm_config.parallel_config.world_size
181+
env_vars_set = get_env_vars_to_copy(destination="DPEngineCoreActor")
182+
env_vars_dict = {
183+
name: os.environ[name]
184+
for name in env_vars_set if name in os.environ
185+
}
186+
runtime_env = RuntimeEnv(env_vars=env_vars_dict)
178187

179188
if ray.is_initialized():
180189
logger.info(
@@ -210,13 +219,14 @@ def __init__(
210219
scheduling_strategy=PlacementGroupSchedulingStrategy(
211220
placement_group=pg,
212221
placement_group_bundle_index=world_size,
213-
)).remote(vllm_config=dp_vllm_config,
214-
executor_class=executor_class,
215-
log_stats=log_stats,
216-
local_client=local_client,
217-
addresses=addresses,
218-
dp_rank=index,
219-
local_dp_rank=local_index)
222+
),
223+
runtime_env=runtime_env).remote(vllm_config=dp_vllm_config,
224+
executor_class=executor_class,
225+
log_stats=log_stats,
226+
local_client=local_client,
227+
addresses=addresses,
228+
dp_rank=index,
229+
local_dp_rank=local_index)
220230
if local_client:
221231
self.local_engine_actors.append(actor)
222232
else:

0 commit comments

Comments
 (0)