
Commit c2c5fea (parent 58ce477)

Refactors

Signed-off-by: fhl <2410591650@qq.com>

5 files changed: +36 -25 lines changed


vllm/compilation/cuda_piecewise_backend.py

Lines changed: 13 additions & 10 deletions
@@ -236,7 +236,9 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
         self.graph_pool = graph_pool
         self.sym_shape_indices = sym_shape_indices
 
-        self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine
+        self.separate_attention_routine = (
+            vllm_config.compilation_config.separate_attention_routine
+        )
 
         self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
 

@@ -282,7 +284,7 @@ def __call__(self, *args) -> Any:
         # eagerly run the compiled graphs, which should be cudagraph capturable
         # as a whole.
 
-        concrete_size_entries = self.concrete_size_entries  # default as general usage
+        concrete_size_entries = self.concrete_size_entries
         if self.separate_attention_routine and forward_context.is_pure_decoding:
             concrete_size_entries = self.concrete_size_entries_decode
 

@@ -324,15 +326,16 @@ def __call__(self, *args) -> Any:
             entry.input_addresses = input_addresses
             cudagraph = torch.cuda.CUDAGraph()
 
-            with ExitStack() as stack:
+            with ExitStack(), \
+                torch.cuda.graph(cudagraph, pool=self.graph_pool):
                 # mind-exploding: carefully manage the reference and memory.
-                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                    # `output` is managed by pytorch's cudagraph pool
-                    output = entry.runnable(*args)
-                    # by converting it to weak ref,
-                    # the original `output` will immediately be released
-                    # to save memory.
-                    output = weak_ref_tensors(output)
+
+                # `output` is managed by pytorch's cudagraph pool
+                output = entry.runnable(*args)
+                # by converting it to weak ref,
+                # the original `output` will immediately be released
+                # to save memory.
+                output = weak_ref_tensors(output)
 
             # here we always use weak ref for the output
             # to save memory
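The refactor above folds the two nested `with` statements into one and keeps the existing trick of converting the captured output into a weak reference, so the tensor memory stays owned by the shared cudagraph pool rather than being pinned by a strong Python reference. Below is a minimal standalone sketch of that capture/replay pattern in plain PyTorch; it is not vLLM code, `run_gemm` is a stand-in for `entry.runnable`, and it assumes a CUDA device is available.

import torch

def run_gemm(x: torch.Tensor) -> torch.Tensor:
    # stand-in for entry.runnable: any static-shape computation
    return x @ x

pool = torch.cuda.graph_pool_handle()          # shared memory pool, like self.graph_pool
static_x = torch.randn(64, 64, device="cuda")  # static input buffer

# Warm up on a side stream before capture, as CUDA graphs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    run_gemm(static_x)
torch.cuda.current_stream().wait_stream(s)

cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cudagraph, pool=pool):
    # the output tensor is owned by the cudagraph pool during capture
    static_out = run_gemm(static_x)

# Replay: refill the static input buffer, then re-run the captured kernels.
static_x.copy_(torch.randn(64, 64, device="cuda"))
cudagraph.replay()
print(static_out.sum())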

vllm/config.py

Lines changed: 9 additions & 6 deletions
@@ -3976,11 +3976,11 @@ class CompilationConfig:
     performance benefits for smaller models."""
     separate_attention_routine: bool = False
     """
-    Enable a distinct attention calls routine under an attention backend for full
-    cuda graph capturing. This is because some attention backends like FlashMLA,
-    FlashInfer, FA2, etc. implement different branches for mix prefill-decode and
-    pure decode cases. This flag enables us to potentially capture the cudagraph
-    separately for each branch.
+    Enable a distinct attention calls routine under an attention backend for
+    full cuda graph capturing. This is because some attention backends like
+    FlashMLA, FlashInfer, FA2, etc. implement different branches for mix
+    prefill-decode and pure decode cases. This flag enables us to potentially
+    capture the cudagraph separately for each branch.
     """
 
     pass_config: PassConfig = field(default_factory=PassConfig)

@@ -4187,7 +4187,10 @@ def set_splitting_ops_for_v1(self):
         # the runtime batch_size is not cudagraph captured. This is only
         # supported for separate_attention_routine.
         if self.separate_attention_routine:
-            assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True"
+            assert self.full_cuda_graph, (
+                "separate_attention_routine requires "
+                "full_cuda_graph to be True"
+            )
         if not self.splitting_ops:
             self.splitting_ops = [
                 "vllm.unified_attention",

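The reflowed assertion also documents the dependency between the two flags: capturing a separate decode routine only makes sense when attention itself is inside the captured graph. The following is a hypothetical, self-contained mirror of that validation, not the real CompilationConfig dataclass; only the two field names are taken from the diff above.

from dataclasses import dataclass

@dataclass
class _CompilationFlags:  # hypothetical stand-in, not vllm.config.CompilationConfig
    full_cuda_graph: bool = False
    separate_attention_routine: bool = False

    def validate(self) -> None:
        # mirrors the assert added in set_splitting_ops_for_v1
        if self.separate_attention_routine:
            assert self.full_cuda_graph, (
                "separate_attention_routine requires "
                "full_cuda_graph to be True")

# valid combination
_CompilationFlags(full_cuda_graph=True, separate_attention_routine=True).validate()
# this combination would raise AssertionError:
# _CompilationFlags(separate_attention_routine=True).validate()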
vllm/forward_context.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ class ForwardContext:
     # determine whether to use a full cudagraph for attention or piecewise
     # cudagraphs that skip the attention part. By default true, we use piecewise
     # cudagraphs.
-    skip_attention_cuda_graphs: bool = True,
+    skip_attention_cuda_graphs: bool = True
     is_pure_decoding: bool = False
 
 
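Dropping the trailing comma is a real fix, not just style: inside a dataclass body, `True,` parses as the one-element tuple `(True,)`, so the old default was a truthy tuple rather than a bool. A small illustration with a hypothetical dataclass (not the actual ForwardContext):

from dataclasses import dataclass

@dataclass
class _Ctx:  # hypothetical stand-in for ForwardContext
    with_comma: bool = True,    # trailing comma: the default is actually (True,)
    without_comma: bool = True  # the intended bool default

ctx = _Ctx()
print(type(ctx.with_comma))     # <class 'tuple'>
print(type(ctx.without_comma))  # <class 'bool'>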
vllm/v1/attention/backends/flashinfer.py

Lines changed: 11 additions & 7 deletions
@@ -227,14 +227,17 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode
         self._decode_wrapper = None  # Wrapper for decode (general shape)
-        self.enable_cuda_graph = self.vllm_config.compilation_config.full_cuda_graph
+        self.enable_cuda_graph = (
+            self.vllm_config.compilation_config.full_cuda_graph
+        )
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
-            self._decode_wrappers_cudagraph: dict[int, BatchDecodeWithPagedKVCacheWrapper] = {}
-            self._decode_cudagraph_max_bs = min(runner.max_num_reqs,
-                                                runner.cudagraph_batch_sizes[-1])
-
+            self._decode_wrappers_cudagraph: dict[int,
+                BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(
+                runner.max_num_reqs, runner.cudagraph_batch_sizes[-1])
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers

@@ -446,8 +449,9 @@ def _plan(self, attn_metadata: FlashInferMetadata):
         use_cudagraph = (self.enable_cuda_graph and pure_decode and \
             self._num_decodes <= self._decode_cudagraph_max_bs)
         if use_cudagraph:
-            num_input_tokens_decode = self.vllm_config.pad_for_cudagraph(
-                self._num_decodes)
+            num_input_tokens_decode = (
+                self.vllm_config.pad_for_cudagraph(self._num_decodes)
+            )
         else:
             num_input_tokens_decode = self._num_decodes
 
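The per-batch-size decode wrappers are keyed by the padded batch size, so `pad_for_cudagraph` has to map the live decode count onto one of the captured sizes. The sketch below is a hypothetical mirror of that padding step, under the assumption that padding means "round up to the smallest captured batch size"; the real method lives on `VllmConfig` and consults the compiled `cudagraph_batch_sizes`.

import bisect

def pad_for_cudagraph(num_decodes: int, cudagraph_batch_sizes: list[int]) -> int:
    # cudagraph_batch_sizes is assumed sorted ascending, e.g. [1, 2, 4, 8, 16]
    assert num_decodes <= cudagraph_batch_sizes[-1], "caller already checks the max bound"
    idx = bisect.bisect_left(cudagraph_batch_sizes, num_decodes)
    return cudagraph_batch_sizes[idx]

sizes = [1, 2, 4, 8, 16]
print(pad_for_cudagraph(3, sizes))   # 4: the batch is padded up to a captured size
print(pad_for_cudagraph(16, sizes))  # 16: already a captured size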

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -1339,7 +1339,8 @@ def execute_model(
             if self.full_cuda_graph else True
         # Note: When skip_attention_cuda_graphs is always False and
         # compilition_config.separate_attention_routine is True, as in FA2,
-        # this flag helps to determine the correct routine to run for the full cudagraph.
+        # this flag helps to determine the correct routine for the full
+        # cudagraph.
         is_pure_decoding = num_scheduled_tokens == self.input_batch.num_reqs
 
         # Run the model.
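The `is_pure_decoding` check works because a decode step schedules exactly one new token per running request, so the scheduled-token total equals the request count only when no request is still prefilling. A toy illustration with made-up numbers:

# each entry is the number of tokens scheduled for one request in this step
num_reqs = 4
tokens_decode_step = [1, 1, 1, 1]   # pure decode: one token per request
tokens_mixed_step = [1, 1, 37, 1]   # one request still prefilling 37 tokens

print(sum(tokens_decode_step) == num_reqs)  # True  -> is_pure_decoding
print(sum(tokens_mixed_step) == num_reqs)   # False -> mixed prefill/decode batch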
