
Commit ffbcc9e

[BugFix] Fix VllmConfig() construction on all platforms (#20695)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 59389c9 commit ffbcc9e
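For context, a minimal sketch of what this commit fixes (assuming vLLM is importable at this commit; the pre-fix behavior is inferred from the diffs below):

# Minimal sketch, assuming vLLM is installed at or after this commit.
from vllm.config import VllmConfig

# A bare VllmConfig() leaves model_config as None (see the comment kept in
# vllm/platforms/xpu.py below), and construction runs the current platform's
# check_and_update_config() hook. Before this commit, several hooks
# dereferenced model_config unconditionally and raised AttributeError.
config = VllmConfig()
print(config.model_config)  # None for a default-constructed config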

5 files changed: +19 -16 lines changed

vllm/config.py

Lines changed: 0 additions & 1 deletion
@@ -4722,7 +4722,6 @@ def _set_cudagraph_sizes(self):
         # calculate the default `batch_size_capture_list`
         if not envs.VLLM_USE_V1:
             batch_size_capture_list = []
-            max_batchsize_to_capture = 0
             if self.scheduler_config is not None and \
                 self.model_config is not None and \
                 not self.model_config.enforce_eager:

vllm/platforms/cpu.py

Lines changed: 4 additions & 3 deletions
@@ -96,7 +96,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.utils import GiB_bytes
         model_config = vllm_config.model_config
 
-        model_config.disable_cascade_attn = True
+        if model_config is not None:
+            model_config.disable_cascade_attn = True
 
         cache_config = vllm_config.cache_config
 
@@ -123,7 +124,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "CPU backend doesn't support fp8_e4m3 KV cache type, "
                 "cast to fp8_e5m2.")
 
-        if (cache_config.cache_dtype != "auto"
+        if (cache_config.cache_dtype != "auto" and model_config is not None
                 and model_config.dtype == torch.half):
             logger.warning("FP8 KV cache on the CPU backend only does not"
                            " support fp16 for now, cast to bf16.")
@@ -229,7 +230,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             os.environ["LOCAL_WORLD_SIZE"] = str(
                 vllm_config.parallel_config.tensor_parallel_size)
 
-        if vllm_config.model_config and vllm_config.model_config.use_mla:
+        if model_config is not None and model_config.use_mla:
             logger.info(
                 "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")

vllm/platforms/cuda.py

Lines changed: 5 additions & 3 deletions
@@ -166,17 +166,19 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")
 
+        compilation_config = vllm_config.compilation_config
         if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
                 and parallel_config.data_parallel_size > 1
-                and vllm_config.compilation_config.use_cudagraph):
+                and compilation_config.use_cudagraph):
             logger.info(
                 "Data Parallel: Forcing enforce eager to be True since DP "
                 "with DeepEP high-throughput kernels are not CUDA Graph "
                 "compatible. The DeepEP low-latency kernels are CUDA Graph "
                 "compatible. Set the all_to_all backend to deepep_low_latency "
                 "to use those kernels instead.")
-            vllm_config.compilation_config.use_cudagraph = False
-            vllm_config.model_config.enforce_eager = True
+            compilation_config.use_cudagraph = False
+            if model_config is not None:
+                model_config.enforce_eager = True
             # TODO (varun): Turning this ON gives incorrect results for the
             # Deepseek-V2-lite model.
             vllm_config.compilation_config.use_inductor = False
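Note that in the cuda.py hunk only the model_config write gains a None guard; compilation_config is hoisted into a local but still written unconditionally, which suggests it is always populated even on a default-constructed config. A quick check under that assumption (inferred from the diff, not verified here):

# Assumption inferred from the diff: a bare VllmConfig() populates
# compilation_config but leaves model_config as None.
from vllm.config import VllmConfig

cfg = VllmConfig()
print(cfg.compilation_config is not None)  # expected True
print(cfg.model_config is None)            # expected True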

vllm/platforms/tpu.py

Lines changed: 6 additions & 4 deletions
@@ -116,11 +116,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"
 
-        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
+        model_config = vllm_config.model_config
+        if model_config is not None and model_config.dtype in (torch.float16,
+                                                               torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
-                "Using bfloat16 instead.", vllm_config.model_config.dtype)
-            vllm_config.model_config.dtype = torch.bfloat16
+                "Using bfloat16 instead.", model_config.dtype)
+            model_config.dtype = torch.bfloat16
 
         from vllm.v1.attention.backends.pallas import PallasAttentionBackend
         cache_config.block_size = PallasAttentionBackend.get_page_size(
@@ -146,7 +148,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True
 
-        if vllm_config.model_config and vllm_config.model_config.use_mla:
+        if model_config and model_config.use_mla:
             logger.info(
                 "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")

vllm/platforms/xpu.py

Lines changed: 4 additions & 5 deletions
@@ -85,23 +85,22 @@ def inference_mode(cls):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
         # in V1(or with ipex chunked prefill) block_size is 64
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 64
 
         # FIXME: Temporarily forcing eager mode
         # remove after t.compile support stabilizes.
-
-        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
+        if (envs.VLLM_USE_V1 and model_config is not None
                 and not vllm_config.model_config.enforce_eager):
             from vllm.config import CompilationLevel
             vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
 
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
         # potential null exceptions check and update model config.
-        if vllm_config.model_config is not None:
-            model_config = vllm_config.model_config
+        if model_config is not None:
             if model_config.dtype == torch.bfloat16:
                 bf16_supported = cls.device_support_bf16()
                 if not bf16_supported:
@@ -139,7 +138,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                                parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "ray"
 
-        if vllm_config.model_config and vllm_config.model_config.use_mla:
+        if model_config and model_config.use_mla:
             logger.info(
                 "MLA is enabled on a non-GPU platform; forcing chunked "
                 "prefill and prefix caching to be disabled.")
