Commit 92b1733

FA2 and FlashInfer Full cuda graph support
Signed-off-by: fhl <2410591650@qq.com>
1 parent 3443aaf · commit 92b1733

10 files changed: +460 −64 lines

vllm/compilation/backends.py

Lines changed: 12 additions & 4 deletions
@@ -563,10 +563,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
         self._called = True
 
-        if not self.compilation_config.use_cudagraph or \
-            not self.compilation_config.cudagraph_copy_inputs:
-            return self.split_gm
-
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
         fake_mode = detect_fake_mode()
@@ -585,6 +581,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             any(is_symbolic(d) for d in x.size())
         ]
 
+        if self.compilation_config.full_cuda_graph:
+            assert self.compilation_config.use_cudagraph, \
+                "full_cuda_graph mode requires use_cudagraph to be True"
+            fullgraph_wrapper = resolve_obj_by_qualname(
+                current_platform.get_fullgraph_wrapper_cls())
+            self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config,
+                                              self.graph_pool, self.sym_tensor_indices)
+
+        if not self.compilation_config.use_cudagraph or \
+            not self.compilation_config.cudagraph_copy_inputs:
+            return self.split_gm
+
         # compiler managed cudagraph input buffers
         # we assume the first run with symbolic shapes
         # has the maximum size among all the tensors
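The wrapper is looked up by its fully qualified name: the platform only returns a string, and resolve_obj_by_qualname turns it into a class object. A rough, hedged sketch of that resolution step (an importlib-based stand-in; vLLM's own helper may differ in details):

import importlib
from typing import Any


def resolve_obj_by_qualname_sketch(qualname: str) -> Any:
    # illustrative stand-in for vllm's resolve_obj_by_qualname:
    # split "package.module.ClassName" and import the named object
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# e.g. resolve_obj_by_qualname_sketch(
#     "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper")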

vllm/compilation/base_piecewise_backend.py

Lines changed: 43 additions & 0 deletions
@@ -70,3 +70,46 @@ def __call__(self, *args) -> Any:
         or a replayed static graph.
         """
         raise NotImplementedError
+
+
+class AbstractFullgraphWrapper(Protocol):
+    """
+    FullgraphWrapper interface that allows platforms to wrap the piecewise
+    graph so that it can be viewed or captured as a full graph.
+    """
+
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int], **kwargs):
+        """
+        Initializes the FullgraphWrapper class with compilation and
+        execution-related configurations.
+
+        Args:
+            graph (fx.GraphModule): The graph represented in fx.
+            vllm_config (VllmConfig): Global configuration for vLLM.
+            graph_pool (Any):
+                Graph memory pool handle, e.g.,
+                `torch.cuda.graph_pool_handle()`.
+            sym_shape_indices (list[int]):
+                Indices of symbolic shapes.
+
+        Keyword Args:
+            kwargs: Additional keyword arguments reserved for future
+                extensions or custom platforms.
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args) -> Any:
+        """
+        Executes the wrapped graph for the given input args.
+
+        Args:
+            *args: Variable length input arguments to be passed into the
+                graph. The symbolic shape is expected to be in position
+                `sym_shape_indices[0]`.
+
+        Returns:
+            Any: Output of the executed wrapped graph.
+        """
+        raise NotImplementedError
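For orientation, a class only needs the constructor and `__call__` shown above to satisfy this protocol. The following eager wrapper is an illustrative sketch, not part of the commit; it matches the protocol's shape but never captures a cudagraph:

from typing import Any

import torch.fx as fx


class EagerFullgraphWrapper:
    """Illustrative no-op wrapper: runs the fx graph eagerly on every call."""

    def __init__(self, graph: fx.GraphModule, vllm_config: Any,
                 graph_pool: Any, sym_shape_indices: list[int], **kwargs):
        self.graph = graph
        self.sym_shape_indices = sym_shape_indices

    def __call__(self, *args) -> Any:
        # no capture or replay; simply forward to the wrapped graph
        return self.graph(*args)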

vllm/compilation/cuda_piecewise_backend.py

Lines changed: 147 additions & 5 deletions
@@ -37,6 +37,8 @@ class ConcreteSizeEntry:
     # during capture, and check if they are the same during replay
     input_addresses: Optional[list[int]] = None
 
+    usage_type: Optional[str] = None
+
 
 class CUDAPiecewiseBackend:
 
@@ -96,6 +98,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
                 use_cudagraph=shape in self.cudagraph_capture_sizes,
+                usage_type="piecewise(general)",  # for logging only
             )
 
     def check_for_ending_compilation(self):
@@ -139,27 +142,32 @@ def __call__(self, *args) -> Any:
                 self.check_for_ending_compilation()
 
         # Skip CUDA graphs if this entry doesn't use them OR
-        # if we're supposed to skip them globally
-        skip_cuda_graphs = get_forward_context().skip_cuda_graphs
-        if not entry.use_cudagraph or skip_cuda_graphs:
+        # if we're supposed to treat the piecewise graphs as a whole,
+        # which implies forward_context.skip_attention_cuda_graphs is False.
+        # In the latter case, we rely on a wrapper class to capture
+        # the full cudagraph outside the fx graph.
+        skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs
+        if not entry.use_cudagraph or not skip_attention_cuda_graphs:
             return entry.runnable(*args)
 
         if entry.cudagraph is None:
             if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                 entry.num_finished_warmup += 1
                 if self.is_first_graph:
                     logger.debug(
-                        "Warming up %s/%s for shape %s",
+                        "Warming up %s/%s of %s usage for shape %s",
                         entry.num_finished_warmup,
                         self.compilation_config.cudagraph_num_of_warmups,
+                        entry.usage_type,
                         runtime_shape)
                 return entry.runnable(*args)
 
             if self.is_first_graph:
                 # Since we capture cudagraph for many different shapes and
                 # capturing is fast, we don't need to log it for every shape.
                 # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
+                logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                             entry.usage_type,
                              runtime_shape)
 
             input_addresses = [
@@ -216,3 +224,137 @@ def __call__(self, *args) -> Any:
 
         entry.cudagraph.replay()
         return entry.output
+
+
+class FullCudagraphWrapper:
+
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int],
+                 ):
+        self.graph = graph
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
+        self.graph_pool = graph_pool
+        self.sym_shape_indices = sym_shape_indices
+
+        self.separate_attention_routine = \
+            vllm_config.compilation_config.separate_attention_routine
+
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
+        self.first_run_finished = False
+
+        self.cudagraph_capture_sizes: set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
+
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {}
+
+        for shape in self.cudagraph_capture_sizes:
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=False,
+                use_cudagraph=True,
+                usage_type="general",
+            )
+            if self.separate_attention_routine:
+                self.concrete_size_entries_decode[shape] = ConcreteSizeEntry(
+                    runtime_shape=shape,
+                    need_to_compile=False,
+                    use_cudagraph=True,
+                    usage_type="decode",
+                )
+
+    def __call__(self, *args) -> Any:
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            return self.graph(*args)
+        list_args = list(args)
+        runtime_shape = list_args[self.sym_shape_indices[0]].shape[0]
+        forward_context = get_forward_context()
+
+        if forward_context.skip_attention_cuda_graphs:
+            # fall back to the piecewise cudagraph backend, which is
+            # responsible for capturing and running the piecewise cudagraphs.
+            return self.graph(*args)
+
+        # If not skipped, the fx graph and its sub-graphs are only supposed
+        # to eagerly run the compiled graphs, which should be cudagraph
+        # capturable as a whole.
+
+        concrete_size_entries = self.concrete_size_entries  # default: general usage
+        if self.separate_attention_routine and forward_context.is_pure_decoding:
+            concrete_size_entries = self.concrete_size_entries_decode
+
+        if runtime_shape not in concrete_size_entries:
+            # we don't need to do anything for this shape.
+            return self.graph(*args)
+
+        entry = concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.graph
+
+        if not entry.use_cudagraph:
+            return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
+                entry.num_finished_warmup += 1
+                logger.debug(
+                    "Warming up %s/%s of %s usage for shape %s",
+                    entry.num_finished_warmup,
+                    self.compilation_config.cudagraph_num_of_warmups,
+                    entry.usage_type,
+                    runtime_shape)
+                return entry.runnable(*args)
+
+            # Since we capture cudagraph for many different shapes and
+            # capturing is fast, we don't need to log it for every shape.
+            # We only log it in the debug mode.
+            logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                         entry.usage_type,
+                         runtime_shape)
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            cudagraph = torch.cuda.CUDAGraph()
+
+            with ExitStack() as stack:
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's cudagraph pool
+                    output = entry.runnable(*args)
+                    # by converting it to weak ref,
+                    # the original `output` will immediately be released
+                    # to save memory.
+                    output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.cudagraph = cudagraph
+
+            compilation_counter.num_cudagraph_captured += 1
+
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
+
+        entry.cudagraph.replay()
+        return entry.output
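FullCudagraphWrapper follows PyTorch's standard capture-once, replay-many discipline: inputs must stay at fixed addresses, the first eligible call for a shape records a CUDA graph into the shared pool, and later calls with that shape only replay it. A standalone sketch of that underlying pattern, independent of vLLM and assuming a CUDA device is available:

import torch

# capture-once / replay-many: the same pattern the wrapper applies per shape
model = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(8, 16, device="cuda")   # input buffer with a fixed address

# warm up on a side stream before capture, as recommended by the PyTorch docs
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):              # pass pool=... to share a memory pool
    static_out = model(static_x)       # output tensor owned by the graph's pool

# replay: refill the captured input buffer in place, then replay the graph
static_x.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
result = static_out.clone()            # copy out before the next replay overwrites it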

vllm/config.py

Lines changed: 17 additions & 6 deletions
@@ -3974,6 +3974,14 @@ class CompilationConfig:
     splitting certain operations such as attention into subgraphs. Thus this
     flag cannot be used together with splitting_ops. This may provide
     performance benefits for smaller models."""
+    separate_attention_routine: bool = False
+    """
+    Enable a distinct attention routine under an attention backend for full
+    cuda graph capturing. Some attention backends, such as FlashMLA,
+    FlashInfer, and FA2, implement different branches for the mixed
+    prefill-decode and pure decode cases. This flag allows the cudagraph to
+    be captured separately for each branch.
+    """
 
     pass_config: PassConfig = field(default_factory=PassConfig)
     """Custom inductor passes, see PassConfig for more details"""
@@ -4172,13 +4180,16 @@ def init_with_cudagraph_sizes(self,
 
     def set_splitting_ops_for_v1(self):
         # NOTE: this function needs to be called
-        if self.splitting_ops and self.full_cuda_graph:
-            raise ValueError("full_cuda_graph cannot be used together with "
-                             "splitting_ops, as Full CUDA graph will override "
-                             f"the splitting_ops: {self.splitting_ops}")
-
+        # NOTE: When full_cuda_graph is True, instead of setting an empty
+        # splitting_ops list and capturing the full cudagraph inside the
+        # flattened fx graph, we keep the piecewise fx graph structure but
+        # capture the full cudagraph outside the fx graph. This reduces some
+        # cpu overhead when the runtime batch_size is not cudagraph captured.
+        # This is only supported together with separate_attention_routine.
+        if self.separate_attention_routine:
+            assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True"
         if not self.splitting_ops:
-            self.splitting_ops = [] if self.full_cuda_graph else [
+            self.splitting_ops = [
                 "vllm.unified_attention",
                 "vllm.unified_attention_with_output",
             ]
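As a usage sketch, the new flag is meant to be set together with the existing cudagraph options. The field names below come from this diff, but how the config object is built in practice (CLI flag, engine argument, etc.) is an assumption:

# Hypothetical sketch: request full cudagraph capture with a separate
# pure-decode routine. Only the field names are taken from this diff.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    use_cudagraph=True,
    full_cuda_graph=True,
    separate_attention_routine=True,   # must be paired with full_cuda_graph
    cudagraph_capture_sizes=[1, 2, 4, 8],
)
compilation_config.set_splitting_ops_for_v1()
# splitting_ops keeps the piecewise attention splits; the full cudagraph is
# captured by FullCudagraphWrapper outside the fx graph.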

vllm/forward_context.py

Lines changed: 9 additions & 3 deletions
@@ -94,7 +94,11 @@ class ForwardContext:
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
-    skip_cuda_graphs: bool = False
+    # Determines whether to use a full cudagraph for attention or piecewise
+    # cudagraphs that skip the attention part. True by default, i.e. use
+    # piecewise cudagraphs.
+    skip_attention_cuda_graphs: bool = True
+    is_pure_decoding: bool = False
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -115,7 +119,8 @@ def set_forward_context(
         virtual_engine: int = 0,
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
-        skip_cuda_graphs: bool = False,
+        skip_attention_cuda_graphs: bool = True,
+        is_pure_decoding: bool = False,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -140,7 +145,8 @@ def set_forward_context(
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
-        skip_cuda_graphs=skip_cuda_graphs,
+        skip_attention_cuda_graphs=skip_attention_cuda_graphs,
+        is_pure_decoding=is_pure_decoding,
     )
 
     try:
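The diff does not show how callers compute these flags; as an illustration only, a runner might derive them roughly like this (the function name and inputs are hypothetical, and the exact eligibility condition used in vLLM may differ):

# Hypothetical helper, not part of the commit: derive the two forward-context
# flags from batch properties before calling set_forward_context(...).
def derive_cudagraph_flags(num_prefill_tokens: int,
                           num_decode_tokens: int,
                           batch_size: int,
                           capture_sizes: set[int],
                           backend_supports_full_cudagraph: bool) -> tuple[bool, bool]:
    is_pure_decoding = num_prefill_tokens == 0 and num_decode_tokens > 0
    # skip_attention_cuda_graphs=True means "use piecewise cudagraphs";
    # False hands control to FullCudagraphWrapper for this forward pass.
    full_cudagraph_eligible = (backend_supports_full_cudagraph
                               and batch_size in capture_sizes)
    skip_attention_cuda_graphs = not full_cudagraph_eligible
    return skip_attention_cuda_graphs, is_pure_decoding


# pure-decode batch of 8 with a backend that supports full cudagraphs
print(derive_cudagraph_flags(0, 8, 8, {1, 2, 4, 8}, True))  # (False, True)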

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -370,6 +370,10 @@ def use_custom_allreduce(cls) -> bool:
     @classmethod
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
+    @classmethod
+    def get_fullgraph_wrapper_cls(cls) -> str:
+        return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper"  # noqa
 
     @classmethod
     def stateless_init_device_torch_dist_pg(

vllm/platforms/interface.py

Lines changed: 8 additions & 0 deletions
@@ -531,6 +531,14 @@ def get_piecewise_backend_cls(cls) -> str:
         Get piecewise backend class for piecewise graph.
         """
         return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa
+
+    @classmethod
+    def get_fullgraph_wrapper_cls(cls) -> str:
+        """
+        Get the fullgraph wrapper class for full cudagraph capture.
+        """
+        return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper"  # noqa
+
 
     @classmethod
     def stateless_init_device_torch_dist_pg(
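Out-of-tree platforms can then override this hook the same way the CUDA platform does above. The snippet below is a hypothetical sketch: only the get_fullgraph_wrapper_cls() hook itself comes from this diff, and the plugin module and class names are invented for illustration.

# Hypothetical out-of-tree platform; every name other than the hook is made up.
from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):

    @classmethod
    def get_fullgraph_wrapper_cls(cls) -> str:
        # qualified name of a class implementing AbstractFullgraphWrapper
        return "my_plugin.compilation.MyFullgraphWrapper"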
