Skip to content

Commit 8d1096e

Browse files
authored
[Bugfix] Register reducer even if transformers_modules not available (#19510)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
1 parent 8d775dd commit 8d1096e

File tree

2 files changed

+73
-15
lines changed

2 files changed

+73
-15
lines changed

tests/config/test_mp_reducer.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import sys
4+
from unittest.mock import patch
5+
6+
from vllm.config import VllmConfig
7+
from vllm.engine.arg_utils import AsyncEngineArgs
8+
from vllm.v1.engine.async_llm import AsyncLLM
9+
10+
11+
def test_mp_reducer(monkeypatch):
12+
"""
13+
Test that _reduce_config reducer is registered when AsyncLLM is instantiated
14+
without transformers_modules. This is a regression test for
15+
https://github.com/vllm-project/vllm/pull/18640.
16+
"""
17+
18+
# Use V1 AsyncLLM which calls maybe_register_config_serialize_by_value
19+
monkeypatch.setenv('VLLM_USE_V1', '1')
20+
21+
# Ensure transformers_modules is not in sys.modules
22+
if 'transformers_modules' in sys.modules:
23+
del sys.modules['transformers_modules']
24+
25+
with patch('multiprocessing.reducer.register') as mock_register:
26+
engine_args = AsyncEngineArgs(
27+
model="facebook/opt-125m",
28+
max_model_len=32,
29+
gpu_memory_utilization=0.1,
30+
disable_log_stats=True,
31+
disable_log_requests=True,
32+
)
33+
34+
async_llm = AsyncLLM.from_engine_args(
35+
engine_args,
36+
start_engine_loop=False,
37+
)
38+
39+
assert mock_register.called, (
40+
"multiprocessing.reducer.register should have been called")
41+
42+
vllm_config_registered = False
43+
for call_args in mock_register.call_args_list:
44+
# Verify that a reducer for VllmConfig was registered
45+
if len(call_args[0]) >= 2 and call_args[0][0] == VllmConfig:
46+
vllm_config_registered = True
47+
48+
reducer_func = call_args[0][1]
49+
assert callable(
50+
reducer_func), "Reducer function should be callable"
51+
break
52+
53+
assert vllm_config_registered, (
54+
"VllmConfig should have been registered to multiprocessing.reducer"
55+
)
56+
57+
async_llm.shutdown()

vllm/transformers_utils/config.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -655,34 +655,35 @@ class module does not need to be importable on the receiving end.
655655
""" # noqa
656656
try:
657657
import transformers_modules
658+
transformers_modules_available = True
658659
except ImportError:
659-
# the config does not need trust_remote_code
660-
return
660+
transformers_modules_available = False
661661

662662
try:
663-
import cloudpickle
664-
cloudpickle.register_pickle_by_value(transformers_modules)
665-
666-
# ray vendors its own version of cloudpickle
667-
from vllm.executor.ray_utils import ray
668-
if ray:
669-
ray.cloudpickle.register_pickle_by_value(transformers_modules)
670-
671-
# multiprocessing uses pickle to serialize arguments when using spawn
672-
# Here we get pickle to use cloudpickle to serialize config objects
673-
# that contain instances of the custom config class to avoid
674-
# serialization problems if the generated module (and model) has a `.`
675-
# in its name
676663
import multiprocessing
677664
import pickle
678665

666+
import cloudpickle
667+
679668
from vllm.config import VllmConfig
680669

670+
# Register multiprocessing reducers to handle cross-process
671+
# serialization of VllmConfig objects that may contain custom configs
672+
# from transformers_modules
681673
def _reduce_config(config: VllmConfig):
682674
return (pickle.loads, (cloudpickle.dumps(config), ))
683675

684676
multiprocessing.reducer.register(VllmConfig, _reduce_config)
685677

678+
# Register transformers_modules with cloudpickle if available
679+
if transformers_modules_available:
680+
cloudpickle.register_pickle_by_value(transformers_modules)
681+
682+
# ray vendors its own version of cloudpickle
683+
from vllm.executor.ray_utils import ray
684+
if ray:
685+
ray.cloudpickle.register_pickle_by_value(transformers_modules)
686+
686687
except Exception as e:
687688
logger.warning(
688689
"Unable to register remote classes used by"

0 commit comments

Comments
 (0)