
Commit 9d02ece

zou3519 authored and wwl2755-google committed
[BugFix] Fix use_cudagraph=False (vllm-project#19612)
Signed-off-by: Richard Zou <zou3519@gmail.com>
1 parent 4ed39ce commit 9d02ece
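
For context: use_cudagraph is a CompilationConfig field, and before this fix the V1 GPU model runner ignored it, capturing CUDA graphs whenever piecewise compilation was active. A minimal sketch of the configuration this commit makes effective (the model name is illustrative, and the dict form of compilation_config is assumed from the updated test below):

    from vllm import LLM

    # Sketch: keep piecewise compilation but opt out of CUDA graph capture.
    # Before this commit, use_cudagraph=False was silently ignored by the
    # V1 gpu_model_runner, which captured CUDA graphs anyway.
    llm = LLM(
        model="facebook/opt-125m",
        compilation_config={"use_cudagraph": False},
    )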

File tree

3 files changed: +35 −29 lines changed


tests/compile/test_config.py

Lines changed: 21 additions & 24 deletions
@@ -1,14 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import pytest
-import torch
 
 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-
-from .piecewise.test_simple import SillyModel
+from vllm.config import VllmConfig
 
 
 def test_use_cudagraphs_dynamic(monkeypatch):
@@ -22,23 +18,24 @@ def test_use_cudagraphs_dynamic(monkeypatch):
 
 
 @pytest.mark.parametrize("enabled", [True, False])
-def test_use_cudagraphs(enabled):
+def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
     assert vllm.envs.VLLM_USE_V1
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE,
-        use_cudagraph=enabled,
-        cudagraph_capture_sizes=[100],
-    ))
-    with set_current_vllm_config(vllm_config):
-        model = SillyModel(vllm_config=vllm_config, prefix='')
-
-    inputs = torch.randn(100, device="cuda")
-
-    with compilation_counter.expect(
-            num_graphs_seen=1,  # one graph for the model
-            num_cudagraph_captured=1 if enabled else 0,
-    ):
-        # first run is warmup
-        model(inputs)
-        # second run does CUDAGraphs recording (if enabled)
-        model(inputs)
+
+    # Disable multiprocessing so that the counter is in the same process
+    monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
+
+    compilation_config = {
+        "cudagraph_capture_sizes": [100],
+        "use_cudagraph": enabled,
+    }
+    with (
+            compilation_counter.expect(
+                num_graphs_seen=1,
+                num_gpu_runner_capture_triggers=1 if enabled else 0,
+                num_cudagraph_captured=13 if enabled else 0,
+            ),
+            # loading the model causes compilation (if enabled) to happen
+            vllm_runner('facebook/opt-125m',
+                        compilation_config=compilation_config,
+                        gpu_memory_utilization=0.4) as _):
+        pass
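
The VLLM_ENABLE_V1_MULTIPROCESSING=0 monkeypatch is what lets this test observe the counters at all: compilation_counter is module-level state, so if the engine ran in a child process the increments would happen in that process's copy of the module and the test would always see zero deltas. A small vLLM-independent sketch of that pitfall:

    import multiprocessing

    counter = {"captures": 0}  # stand-in for a module-level counter

    def work():
        counter["captures"] += 1  # increments the child process's copy

    if __name__ == "__main__":
        p = multiprocessing.Process(target=work)
        p.start()
        p.join()
        print(counter["captures"])  # prints 0: the parent never sees it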

vllm/compilation/counter.py

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ class CompilationCounter:
     # not including the splitting ops
     num_piecewise_capturable_graphs_seen: int = 0
     num_backend_compilations: int = 0
+    # Number of gpu_model_runner attempts to trigger CUDAGraphs capture
+    num_gpu_runner_capture_triggers: int = 0
+    # Number of CUDAGraphs captured
     num_cudagraph_captured: int = 0
     # InductorAdapter.compile calls
     num_inductor_compiles: int = 0
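
The expect() helper the test uses is not part of this diff, but the dataclass above suggests the usual snapshot-and-compare pattern. A hedged sketch of that pattern (field names are from this file; the expect body is an illustration, not vLLM's actual implementation):

    import copy
    from contextlib import contextmanager
    from dataclasses import dataclass

    @dataclass
    class CompilationCounter:
        num_graphs_seen: int = 0
        num_gpu_runner_capture_triggers: int = 0
        num_cudagraph_captured: int = 0

        @contextmanager
        def expect(self, **deltas):
            # Snapshot all counters, run the body, then check that each
            # named counter advanced by exactly the expected amount.
            before = copy.deepcopy(self)
            yield
            for name, want in deltas.items():
                got = getattr(self, name) - getattr(before, name)
                assert got == want, f"{name}: expected +{want}, got +{got}"

    compilation_counter = CompilationCounter()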

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 5 deletions
@@ -18,6 +18,7 @@
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
+from vllm.compilation.counter import compilation_counter
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -200,9 +201,11 @@ def __init__(
             block_sizes=[self.cache_config.block_size],
         )
 
-        self.use_cuda_graph = (self.compilation_config.level
-                               == CompilationLevel.PIECEWISE
-                               and not self.model_config.enforce_eager)
+        self.use_cuda_graph = (
+            self.vllm_config.compilation_config.level
+            == CompilationLevel.PIECEWISE
+            and self.vllm_config.compilation_config.use_cudagraph
+            and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -2058,10 +2061,13 @@ def profile_run(self) -> None:
     def capture_model(self) -> None:
         if not self.use_cuda_graph:
             logger.warning(
-                "Skipping CUDA graph capture. Please add "
-                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
+                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
+                "set -O %s and ensure `use_cudagraph` was not manually set to "
+                "False", CompilationLevel.PIECEWISE)
             return
 
+        compilation_counter.num_gpu_runner_capture_triggers += 1
+
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
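The second hunk above is the fix itself: use_cuda_graph previously depended only on the compilation level and enforce_eager, so use_cudagraph=False was never consulted. Restated as a standalone predicate for clarity (a sketch, not code from this commit; CompilationLevel.PIECEWISE is the level the -O flag in the warning refers to):

    from vllm.config import CompilationLevel, VllmConfig

    def cudagraphs_enabled(vllm_config: VllmConfig) -> bool:
        # Sketch of the fixed predicate in GPUModelRunner.__init__:
        # all three conditions must now hold for capture_model() to do work.
        compilation_config = vllm_config.compilation_config
        return (compilation_config.level == CompilationLevel.PIECEWISE
                and compilation_config.use_cudagraph
                and not vllm_config.model_config.enforce_eager)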
