
[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059


Open · wants to merge 40 commits into base: main

Conversation

fhl2000

@fhl2000 fhl2000 commented Jun 25, 2025

Purpose

1. This PR introduces a new implementation for full cuda graph, and adds support for FA2 and FlashInfer.

Previous limitations

The original design in PR #16072 sets compilation_config.splitting_ops to an empty list and captures the full cudagraph inside the flattened fx graph, which supports FA3 only. The later PR #18581 adds full cudagraph support for FlashMLA, but it only captures the pure decode stage and bypasses the mixed prefill-decode stages, i.e., it runs the eager code of the compiled flattened fx graph in that stage. However, from the profiling results (see below), I found that this flattened graph has performance issues when run eagerly: it is about 2x slower on the CPU side than running the compiled piecewise fx graph (possibly a Python-level issue). This can lead to performance degradation when the prefill stage has a small batch size.

Also, attention backends like FA2, FlashInfer, and FlashMLA have two distinct attention routines for mixed prefill-decode batches and pure decode batches, which makes it difficult to contain both in a unified graph while keeping only one set of captured cudagraphs.

Solution of this PR.

So the new trick is: we keep the piecewise compiled fx graph structure overall, but capture the full cudagraph outside the fx graph via a wrapper. With this in hand, we can dispatch to two sets of cudagraphs. For the pure decode stage, we directly use full cudagraphs, since this is compatible with most attention backends. For mixed prefill-decode stages, we can either fall back to piecewise cudagraphs for incompatible routines in backends like FlashMLA and FlashInfer, or use another set of full cudagraphs for compatible backends (varlen support in FA2).
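A minimal sketch of this dispatch idea, with illustrative names rather than the exact classes added in this PR:

```python
from typing import Any, Callable, Dict, Tuple


class FullCudagraphWrapper:
    """Illustrative sketch only; not the exact wrapper class in this PR."""

    def __init__(self, piecewise_runnable: Callable[..., Any]):
        # The compiled piecewise fx graph; it manages piecewise cudagraphs itself.
        self.piecewise_runnable = piecewise_runnable
        # (batch_size, is_pure_decode) -> callable replaying a captured full cudagraph.
        self.full_graphs: Dict[Tuple[int, bool], Callable[..., Any]] = {}

    def __call__(self, *args: Any, batch_size: int, use_full_cudagraph: bool,
                 is_pure_decode: bool) -> Any:
        if use_full_cudagraph:
            replay = self.full_graphs.get((batch_size, is_pure_decode))
            if replay is not None:
                return replay(*args)
        # Incompatible attention routine or uncaptured size: fall back to the
        # piecewise path, which handles its own piecewise cudagraphs.
        return self.piecewise_runnable(*args)
```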

Note that keeping the piecewise compiled fx graph is at least better than a full but flattened one from the viewpoint of reducing CPU overhead, even if we do not capture the mixed prefill-decode stage. It also keeps the flexibility to switch between full cudagraphs and piecewise cudagraphs for future extensions, for example a seamless fallback to piecewise cudagraphs when cascade attention is needed.

The limitation is increased startup time and more GPU memory required for the additional cudagraph captures. We could optimize this by shrinking the list of batch sizes captured for the prefill-decode stage.

Profile of the compiled flattened fx graph under eager execution, mixed prefill-decode stage:

It takes roughly 56 ms to fully launch the model, plus an additional 5 ms of latency for safety checks before the first kernel is launched. It seems Python is slow at executing a large, flattened module without submodules.
[profiling trace screenshot]

Note: the only way to use the flattened fx graph in this PR is to hardcode splitting_ops = [] in set_splitting_ops_for_v1 (around line 4200 in vllm/config.py).

Profile of the compiled piecewise fx graph under eager execution, mixed prefill-decode stage:

It takes about 28 ms to fully launch, and the latency above almost disappears; in fact, it is hidden inside each submodule.
[profiling trace screenshot]

The patterns above were verified on two different machines (the GPU difference can be ignored here, as this is CPU-related), tested on Qwen2.5-7B-Instruct-GPTQ-Int4 with a profiled benchmark_serving run (sharegpt, unlimited request rate).

So, if a prefill batch size is a bit larger than the max capture size (say 512) but not too big, the lower bound of the model forward time is likely set by the CPU side: around 56 ms when running the flattened graph versus 28 ms for the piecewise one.

Details for supporting FA2:

The previous code did not recognize the two routines inside FA2: it launches a standard varlen fwd kernel for mixed prefill-decode batches, or launches another routine for pure decode batches that includes a GQA/MQA optimization and potential flash-decode kernels (split_kv > 1). By setting max_query_len = 1 or > 1 during the cudagraph capture phase, we can activate the desired attention routine so that it is captured correctly (sketched below). (To be precise, the kernel for the prefill-decode phase is of course compatible with pure decode, but it is not fully optimized for the decode phase. The actual reason PR #16072 did not support FA2 is a bug: seq_lens was a zero tensor in dummy_run in the early code, which bypassed launching any attention kernel at capture time, leading to zero-tensor outputs.)

  • FA2 runs both mixed prefill-decode and pure decode batches with full cudagraphs, but on two separate sets of cudagraphs.
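A minimal sketch of that capture-time selection, under assumed names (the real logic lives in the FA2 metadata builder and dummy run, not in these exact helpers):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DummyCaptureMetadata:
    """Illustrative stand-in for the dummy-run attention metadata."""
    num_reqs: int
    max_query_len: int   # 1 -> pure-decode routine; >1 -> varlen prefill-decode routine
    seq_lens: List[int]  # must be non-zero so a real attention kernel is launched


def build_capture_metadata(num_reqs: int, pure_decode: bool) -> DummyCaptureMetadata:
    # max_query_len decides which FA2 routine the cudagraph capture records.
    max_query_len = 1 if pure_decode else 2
    # Non-zero seq_lens avoid the old bug where zero lengths skipped the attention
    # kernel entirely during capture, producing zero-tensor outputs.
    return DummyCaptureMetadata(num_reqs=num_reqs,
                                max_query_len=max_query_len,
                                seq_lens=[max_query_len] * num_reqs)
```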

Details for supporting FlashInfer:

  • Using the persistent buffer trick.
  • Create many decode wrappers, one per cudagraph batch size, as required by the FlashInfer API (see the sketch below).
  • Run pure decode batches with full cudagraphs, and fall back to piecewise cudagraphs for mixed prefill-decode batches.
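A rough sketch of the per-batch-size wrapper bookkeeping, assuming a hypothetical make_wrapper factory (the actual FlashInfer wrapper construction and its persistent buffers follow the FlashInfer API):

```python
from typing import Callable, Dict, List


def build_decode_wrappers(capture_sizes: List[int],
                          make_wrapper: Callable[[int], object]) -> Dict[int, object]:
    # One decode wrapper per capturable batch size: FlashInfer's CUDA-graph mode
    # ties a wrapper to persistent buffers sized for a fixed number of requests,
    # so a single wrapper cannot serve every captured size.
    return {bs: make_wrapper(bs) for bs in sorted(capture_sizes)}
```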

Launching command examples:

For FA2:

VLLM_FLASH_ATTN_VERSION=2 python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --compilation-config '{"full_cuda_graph":true, "separate_attention_routine":true}'

For FlashInfer:

VLLM_ATTENTION_BACKEND=FLASHINFER python -m ... --compilation-config '{"full_cuda_graph":true,"separate_attention_routine":true}'

Others:
FlashMLA: use --compilation-config '{"full_cuda_graph":true,"separate_attention_routine":true}'
FA3: set VLLM_FLASH_ATTN_VERSION=3 and use --compilation-config '{"full_cuda_graph":true}'

Test Plan

Benchmark serving and lm_eval performance for FA2 and FlashInfer.

I have no plan to test FlashMLA and FA3, as I have no Hopper GPU at hand, but they should be fine since the current design is compatible with them. However, it would be very nice if somebody could help test them.

Test Result

Summary of results

Output token throughput is improved by 5% for FA2 and 2% for FlashInfer on Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4. TPOT is reduced by 2.9% and 3.1%, respectively. The lm_eval results are unchanged for both.

Details

Machine: A100 40G, torch 2.6, CUDA 12.4

Benchmark serving command:

python benchmarks/benchmark_serving.py --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100 --request-rate 20

FA2 benchmark serving:

piecewise cudagraph before this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.41
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 8.77
Output token throughput (tok/s): 1898.67
Total Token throughput (tok/s): 3937.88
---------------Time to First Token----------------
Mean TTFT (ms): 76.37
Median TTFT (ms): 71.08
P99 TTFT (ms): 191.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 17.08
Median TPOT (ms): 15.22
P99 TPOT (ms): 67.68
---------------Inter-token Latency----------------
Mean ITL (ms): 13.45
Median ITL (ms): 11.05
P99 ITL (ms): 72.61
==================================================

full cudagraph + piecewise fx graph in this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 10.87
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 9.20
Output token throughput (tok/s): 1992.27
Total Token throughput (tok/s): 4132.01
---------------Time to First Token----------------
Mean TTFT (ms): 78.69
Median TTFT (ms): 75.10
P99 TTFT (ms): 195.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.57
Median TPOT (ms): 14.78
P99 TPOT (ms): 78.21
---------------Inter-token Latency----------------
Mean ITL (ms): 12.83
Median ITL (ms): 10.34
P99 ITL (ms): 72.37
==================================================

FA2 lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.8074 ± 0.0109
strict-match 5 exact_match 0.7619 ± 0.0117

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.8074 ± 0.0109
strict-match 5 exact_match 0.7619 ± 0.0117

FlashInfer benchmark serving

piecewise cudagraph before this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.36
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.81
Output token throughput (tok/s): 1907.38
Total Token throughput (tok/s): 3955.65
---------------Time to First Token----------------
Mean TTFT (ms): 73.61
Median TTFT (ms): 69.59
P99 TTFT (ms): 184.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.85
Median TPOT (ms): 15.13
P99 TPOT (ms): 65.75
---------------Inter-token Latency----------------
Mean ITL (ms): 13.34
Median ITL (ms): 11.09
P99 ITL (ms): 71.82
==================================================

full cudagraph + piecewise fx graph after this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.13
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.99
Output token throughput (tok/s): 1946.35
Total Token throughput (tok/s): 4036.48
---------------Time to First Token----------------
Mean TTFT (ms): 76.03
Median TTFT (ms): 67.04
P99 TTFT (ms): 192.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.34
Median TPOT (ms): 14.96
P99 TPOT (ms): 58.86
---------------Inter-token Latency----------------
Mean ITL (ms): 13.11
Median ITL (ms): 10.71
P99 ITL (ms): 71.69
==================================================

FlashInfer lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.8105 ± 0.0108
strict-match 5 exact_match 0.7635 ± 0.0117

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.8105 ± 0.0108
strict-match 5 exact_match 0.7635 ± 0.0117

One more thing: after recently merging some code from the main branch, I ran into a potential deadlock when testing this PR. This appears to be caused by earlier merged code, and PR #19927 seems to solve the problem.

fhl2000 added 2 commits June 25, 2025 13:36
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Summary of Changes

Hello @fhl2000, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a refined approach to full CUDA graph integration within vLLM, moving from a flattened FX graph to a wrapper-based strategy that preserves the piecewise graph structure. This new method facilitates broader full CUDA graph support for attention backends like FlashAttention 2 and FlashInfer, leading to measurable performance gains. Additionally, it includes a fix for a specific kernel compilation issue, enhancing overall system stability and compatibility.

Highlights

  • Enhanced Full CUDA Graph Implementation: Introduces a new strategy for full CUDA graph capture that wraps the piecewise compiled FX graph, rather than flattening it. This aims to reduce CPU overhead for non-captured batch sizes and offers greater flexibility, allowing dispatch to different CUDA graph sets for prefill-decode and pure decode stages.
  • FA2 and FlashInfer Support: Extends full CUDA graph support to FlashAttention 2 (FA2) and FlashInfer backends. This includes specific adaptations for their distinct prefill-decode and pure decode routines, enabling performance benefits for these attention backends.
  • Performance Improvements: Benchmarking results indicate a 5% improvement in output token throughput for FA2 and a 2% improvement for FlashInfer, with corresponding reductions in Time Per Output Token (TPOT) by 2.9% and 3.1% respectively.
  • Marlin Kernel Compilation Bug Fix: Addresses a minor bug where Marlin kernels were incorrectly compiled for unsupported GPU architectures (e.g., 8.7 for RTX 4090, which is 8.9), resolving 'RuntimeError: CUDA error: no kernel image is available for execution on the device' errors.
  • Separate Attention Routine Configuration: Adds a new separate_attention_routine flag to CompilationConfig, allowing for distinct CUDA graph capturing for prefill-decode and pure decode stages within attention backends that implement different branches for these cases.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new implementation for full cuda graph, adds support for FA2 and FlashInfer, and fixes a bug for Marlin kernels on Ada architecture. The core idea is to keep the piecewise graph structure and wrap it to capture the full CUDA graph, which avoids the CPU overhead of a large flattened graph. The changes are well-motivated, and the performance improvements are clearly demonstrated.

Signed-off-by: fhl <2410591650@qq.com>
@fhl2000 fhl2000 force-pushed the full_cudagraph_FA2_FlashInfer branch from bcf7cb9 to c2c5fea Compare June 25, 2025 08:33

mergify bot commented Jun 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 25, 2025
fhl2000 and others added 2 commits June 25, 2025 16:52
@mergify mergify bot removed the needs-rebase label Jun 25, 2025
fhl2000 added 2 commits June 25, 2025 10:03
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000
Author

fhl2000 commented Jun 25, 2025

I have incorporated some checks for the new flag separate_attention_routine, so it is safe to launch now. This PR is now ready to be reviewed!

@fhl2000 fhl2000 marked this pull request as ready for review June 25, 2025 14:53
@fhl2000 fhl2000 changed the title [Core][Bugfix] new way for full cudagraph, add support for FA2 and FlashInfer; a minor bug fixed [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer; A minor bug fixed Jun 25, 2025
@fhl2000
Author

fhl2000 commented Jun 26, 2025

Here is the workflow. At torch.compile initialization, the vllm_backend wraps the split_gm into a full-cudagraph wrapper class if compilation_config.full_cuda_graph is on. This wrapper class then takes responsibility for dispatching to the cudagraph entries of the separate attention routines. At runtime, this dispatching is based on two key flags in the global forward_context: skip_attention_cuda_graphs and is_pure_decoding. When skip_attention_cuda_graphs is true, which implies using full cudagraphs, the wrapper class takes care of it: if separate_attention_routine is on, the wrapper further dispatches to the decode-only full cudagraph or the mixed prefill-decode full cudagraph according to the is_pure_decoding flag. On the other hand, if skip_attention_cuda_graphs is false, the wrapper class immediately falls back to the piecewise fx graph (the original split_gm), which relies on the CUDAPiecewiseBackend class to handle the piecewise cudagraph logic.


Please let me know if you have any questions or suggestions. I am currently planning to add some unit tests.

Signed-off-by: fhl <2410591650@qq.com>
Collaborator

@ProExpertProg ProExpertProg left a comment

I think this is a good approach overall!
My initial feedback:

  • I think we should try to consolidate CUDAGraph logic into a single class.
  • CUDAGraph logic is complex on main already, and this PR increases complexity significantly. We should add significantly more documentation. I also think we should consolidate various config flags and states.
  • There are benefits to compilation without splitting the graph (e.g. attention+quant fusion). We should add a new flag that maintains that ability (and assert the attention backend supports full cudagraph only). CUDAGraph logic can stay in the wrapper class.
  • This is a large PR, so it might help to split it. e.g. FlashInfer cg support can be added in a follow-up. But I'll let others chime in here.

Okay, this is plenty for now :D - thanks for the PR!

@@ -3984,6 +3984,14 @@ class CompilationConfig:
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
separate_attention_routine: bool = False
Collaborator

I think this should be named better. Perhaps split_attn_cudagraph? I also don't understand why this has to be a flag and we can't just ask the attention backend what it wants?

Author

I think we must leave such a flag in the global config, which tells the compiler backend to do the right thing. Otherwise, how is the attention backend supposed to communicate its requirements to the compiler? At least for now, the force_separate_routine flag of an attention backend has the ability to enforce its preference during the initialize_attn_backend phase of the gpu model runner.

Author

I think this should be named better. Perhaps split_attn_cudagraph?

I am not sure what name would be better. Btw, I'm afraid split_attn_cudagraph is not a good name: it sounds like splitting the full graph into a piecewise graph where the attention ops are the splitting ops, which is what we already do.

Collaborator

Good call on the name. Also makes sense we use this to communicate from attention backend to compiler. Let's make sure that happens inside set_splitting_ops_for_v1/somewhere inside config initialization, if we can.

Collaborator

I think we should figure out a different name for this; the current name doesn't indicate any relation to cudagraphs

Collaborator

I'm not as zoned into this PR as you folks are, but I have no clue what this flag is from the name.

Author

I think we should figure out a different name for this; the current name doesn't indicate any relation to cudagraphs

How about cudagraph_separate_routine? Dropping "attention" doesn't change its meaning. While it is basically intended for the distinct attention routines that are actually executed, in the future this may cover more than just attention ops.

Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000
Author

fhl2000 commented Jul 10, 2025

More benchmark results after refactors.

Benchmark command:

python vllm/benchmarks/benchmark_serving.py --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100 --request-rate 10

env: A100 40G, torch2.6, cuda12.4

Flash attention v2

FA2's attn_cudagraph_support is ALWAYS_SEPARATE (while FA3's is ALWAYS_UNIFIED)

(a). O3 + piecewise cg (main)
(b). O3 + piecewise cg (this PR)
(c). O3 + full cg + single routine/mixed only (this PR)
(d*). O3 + full cg + separate routine (this PR)
(e). O0 + full cg + separate routine (this PR)

Metric \Source (a) (b) (c) (d*) (e)
Benchmark duration (s) 14.40 14.37 14.40 13.63 13.94
Request throughput (req/s) 6.94 6.96 6.95 7.34 7.17
Output token throughput (tok/s) 1503.58 1507.35 1504.43 1589.10 1553.86
Median TTFT (ms) 44.48 46.59 50.03 48.89 50.26
Median TPOT (ms) 11.04 11.19 11.55 9.86 10.54
Median ITL (ms) 9.25 9.21 9.28 8.12 8.49

FlashInfer

attn_cudagraph_support is PURE_DECODE_ONLY

(a). O3 + piecewise cg (main)
(b). O3 + piecewise cg (this PR)
(c*). O3 + full cg/decode+piecewise cg/mixed (this PR)
(d). O0 + full cg/decode+no cg/mixed (this PR)

Metric \Source (a) (b) (c*) (d)
Benchmark duration (s) 14.50 14.59 13.78 14.12
Request throughput (req/s) 6.90 6.85 7.26 7.08
Output token throughput (tok/s) 1494.69 1485.18 1571.90 1534.32
Median TTFT (ms) 45.33 48.30 48.97 49.26
Median TPOT (ms) 11.14 11.38 10.05 10.85
Median ITL (ms) 9.29 9.40 8.37 8.77

Triton_attn

attn_cudagraph_support is ALWAYS_SEPARATE

(a). O3 + piecewise cg (main)
(b). O3 + piecewise cg (this PR)
(c). O3 + full cg + single routine/mixed only (main)
(d). O3 + full cg + single routine/mixed only (this PR)
(e*). O3 + full cg + separate routine (this PR)
(f). O0 + full cg + separate routine (this PR)

Metric \Source (a) (b) (c) (d) (e*) (f)
Benchmark duration (s) 14.54 14.45 14.83 14.53 13.81 14.14
Request throughput (req/s) 6.88 6.92 6.74 6.88 7.24 7.07
Output token throughput (tok/s) 1489.51 1499.07 1460.02 1490.53 1568.67 1532.15
Median TTFT (ms) 46.12 46.26 44.85 47.95 44.76 45.13
Median TPOT (ms) 11.80 11.43 11.24 10.80 10.37 11.07
Median ITL (ms) 9.89 9.14 9.70 9.35 9.28 9.49


mergify bot commented Jul 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 11, 2025
@mergify mergify bot removed the needs-rebase label Jul 11, 2025
@fhl2000
Author

fhl2000 commented Jul 11, 2025

Revisiting the unexpected (median) TTFT results:

After carefully comparing the wall duration of the mixed batches in the profiling files of piecewise cg (both main and this PR) and full cg (this PR), I found the following clues.

1. Between the piecewise cg of main and this PR.

I found that most pieces of the mixed batches from the two match almost exactly (while a few differ due to random effects, maybe unexpected latency in CPU execution). After testing repeatedly, I found the median TTFT is very volatile; fluctuation within half the mean ITL is considered reasonable due to traffic or random effects. The mean TTFT is much more stable, with turbulence within 1 ms. So I believe the refactors in this PR are safe and add very little overhead to piecewise cudagraph execution.

2. Between piecewise cg (main) and full cg (this PR).

Yes, I observed why full cg on mixed batches may be slower than piecewise cg.
The full cudagraph for mixed batches captured an unrealistic workload of around 464 blocks per SM for attention, instead of the normally much lower workload for batch sizes < 512. This is because we currently assume a capture size of 512 tokens corresponds to 256 max num requests, and each request may have up to 512 query tokens when capturing. This can be a few ms slower than eagerly running attention in the piecewise cudagraph.

On piecewise cg and mixed batch

[profiling trace screenshot]

On full cg and mixed batch

[profiling trace screenshot]

So, even though FA2 is compatible with mixed batches, I guess running piecewise cg for mixed batches is better than full cg until we find a more reasonable capture strategy for mixed batches.

@yinghai, is that a good answer for you?

Collaborator

@ProExpertProg ProExpertProg left a comment

Great work addressing comments! I really like the new dispatcher class, and the get_cudagraph_runtime_style function is really nice.

There are a few more large-scale questions I have:

  1. Could we remove the dependence of the CUDAGraphDispatcher on the GPUModelRunner? I think the flags it looks up can be recomputed, and the model can be passed in, either to after_load_model or dispatch.
  2. I still don't think the separate_routine semantics are fully clear.
  3. Could you also update the PR title to better reflect the new (improved) changes? I'd mention support for full cudagraphs without compilation and ability for full+piecewise cudagraphs.

I will take another more detailed look at the dispatcher and runner next round of review, but everything else looks really good (apart from a few minor comments)!

@@ -584,7 +604,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

self._called = True

- if not self.compilation_config.use_cudagraph or \
+ if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
Collaborator

We should make it clear in config.py that cudagraph_copy_inputs only applies to piecewise cudagraph

Author

@fhl2000 fhl2000 Jul 13, 2025

Why does cudagraph_copy_inputs have to apply to piecewise cudagraphs only? I am not convinced there are any incompatibility issues with full cudagraphs, apart from the unnecessary overhead of extra copies. @ProExpertProg

Collaborator

This code copies the input buffers for piecewise cudagraphs. But I didn't see anywhere that buffers are copied outside full cudagraphs - did I miss it?

Author

I think the input buffer copy doesn't have differences between piecewise and full cudagraphs.

Collaborator

I think this code is inside the cudagraph capture for full cudagraphs and so not sure it's achieving what we want for copying inputs. This code also wouldn't run if compilation is disabled

Author

@fhl2000 fhl2000 Jul 14, 2025

My bad! It makes sense to me now. Looks like I was trapped in the old implementation😓

# the separate_attention_routine flag, but should inform
# the user that this flag can be turned on to obtain
# better performance.
if attn_cg == AttentionCGSupport.ALWAYS_SEPARATE and \
Collaborator

I thought ALWAYS_SEPARATE meant separate routine was required, not just preferred? And with ALWAYS_UNIFIED a separate routine is not required? Separate routine should always be allowed (because worst-case we capture 2 cudagraphs with the same attention routine)

Author

@fhl2000 fhl2000 Jul 13, 2025

Post from Slack we discussed:

ALWAYS_SEPARATE basically means the backend supports full cudagraphs for both routines and prefers separate routines, but that is not a must-have: when separate_attention_routine is false, we only capture the mixed-batch routine, and it still works for pure decode batches with suboptimal performance.

It is not necessary to keep two cudagraphs for ALWAYS_UNIFIED now, until we push the extended modes FULL_AND_PIECEWISE or FULL_DECODE_ONLY in a future PR. That will allow more flexible support for, for example, cascade attention.
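For reference, a hedged sketch of the support levels being discussed, using the names that appear in this thread (the actual AttentionCGSupport in the PR may define different members or values):

```python
import enum


class AttentionCGSupport(enum.Enum):
    # Illustrative members/values only, based on the backends discussed in this PR.
    PURE_DECODE_ONLY = 1  # full cudagraphs only for pure-decode batches (e.g. FlashInfer)
    ALWAYS_UNIFIED = 2    # one unified routine covers all batches (e.g. FA3)
    ALWAYS_SEPARATE = 3   # both routines support full cudagraphs; separate capture preferred (e.g. FA2)
```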

Collaborator

Based on this description I think the warning is very confusing to the user. They won't know what a "separate attention routine" means. We should try to think of new naming here; I would prefer something like:

{attn_backend_i.__name__} generally performs better when using
full cuda-graphs for only pure decode batches (and using piecewise
cuda-graphs for prefill and mixed-prefill-decode iterations) to enable
this behavior turn on 
CompilationConfig.mixed_full_cudagraphs_and_piecewise_cudagraphs

some names to consider instead of separate_attention_routine to communicate that this flag is cudagraph related:
mixed_full_cudagraphs_and_piecewise_cudagraphs: a little long, but leaves the door open to mixing full cudagraphs with piecewise cudagraphs if we want to dispatch on something other than is_pure_decode (see comment above)
full_cudagraphs_for_pure_decode_only: a little shorter and aligns better with current behavior, but leaves little room to expand on this behavior in the future

Comment on lines 56 to 61
attn_cg = self.runner.attn_metadata_builders[0].\
attn_cudagraph_support
# create full cudagraph for mix prefill-decode/general batches
if attn_cg in [AttentionCGSupport.ALWAYS_UNIFIED,
AttentionCGSupport.ALWAYS_SEPARATE] and \
self.runner.capture_mixed_batches:
Collaborator

All this can be passed as a boolean flag to maybe_initialize_cudagraph, and take all builders into account instead of just the first one (likely all(<bool expr>))

Author

I am not sure when we might have different attention backends. But I think we can verify that all backend builders share the same attn_cudagraph_support type. If that isn't met, we raise an error saying full cudagraph mode is unsupported and ask users to turn on piecewise mode instead. Otherwise, it should be safe to just go with the first builder.


mergify bot commented Jul 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 12, 2025
@mergify mergify bot removed the needs-rebase label Jul 12, 2025
fhl2000 added 5 commits July 13, 2025 09:35
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Collaborator

@ProExpertProg ProExpertProg left a comment

Really nice work!!

Comment on lines 101 to 110
cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\
attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE

# PIECEWISE would fall back to NONE if no compilation
if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \
self.no_compilation:
cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE

#TODO: can we optimize above logic?
return cudagraph_runtime_style
Collaborator

Here's a potential simplification:

Suggested change
- cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\
-     attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE
- # PIECEWISE would fall back to NONE if no compilation
- if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \
-     self.no_compilation:
-     cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE
- #TODO: can we optimize above logic?
- return cudagraph_runtime_style
+ assert self.cudagraph_mode == CUDAGraphMode.FULL
+ if attn_cuda_graphs:
+     return CUDAGraphRuntimeStyle.FULL
+ # Need to skip attention, see if piecewise compilation available
+ if self.attention_piecewise_compilation:
+     return CUDAGraphRuntimeStyle.PIECEWISE
+ # Otherwise, fall back to NONE
+ return CUDAGraphRuntimeStyle.NONE

Comment on lines 96 to 110
# Otherwise, for modes that enable full cudagraph.

# Some attention backends only support CUDA Graphs in pure decode.
# If attention doesn't support CUDA Graphs for this batch, we skip them,
# and turn back to the piecewise CUDA graphs.
cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\
attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE

# PIECEWISE would fall back to NONE if no compilation
if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \
self.no_compilation:
cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE

#TODO: can we optimize above logic?
return cudagraph_runtime_style
Collaborator

Here's a proposed simplification, with slightly better logic to detect if we can fall back to piecewise cudagraphs:

Suggested change
- # Otherwise, for modes that enable full cudagraph.
- # Some attention backends only support CUDA Graphs in pure decode.
- # If attention doesn't support CUDA Graphs for this batch, we skip them,
- # and turn back to the piecewise CUDA graphs.
- cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\
-     attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE
- # PIECEWISE would fall back to NONE if no compilation
- if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \
-     self.no_compilation:
-     cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE
- #TODO: can we optimize above logic?
- return cudagraph_runtime_style
+ # Otherwise, for modes that enable full cudagraph.
+ assert self.cudagraph_mode == CUDAGraphMode.FULL
+ # If attention backend supports full cudagraphs for current batch,
+ # run with full cudagraphs.
+ if attn_cuda_graphs:
+     return CUDAGraphRuntimeStyle.FULL
+ # Fall back to piecewise cudagraphs if possible
+ if self.piecewise_attn_compilation:
+     return CUDAGraphRuntimeStyle.PIECEWISE
+ # Otherwise, fall back to running entirely without cudagraphs.
+ return CUDAGraphRuntimeStyle.NONE

Collaborator

self.piecewise_attn_compilation defined in __init__ as:

self.piecewise_compilation = self.compilation_config.level == CompilationLevel.PIECEWISE
self.piecewise_attn_compilation = self.piecewise_compilation and \
all(op in self.compilation_config.splitting_ops for op in ["vllm.unified_attention", "vllm.unified_attention_with_output"])

Maybe also add a warning if this ends up as False and the backend requires fallback due to PURE_DECODE_ONLY support.

Author

@fhl2000 fhl2000 Jul 14, 2025

We already have a warning inside initialize_attn_backend for this potential fallback.

Comment on lines +2365 to +2369
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
Collaborator

Really nice explanation and the logic is clear, thanks!

Comment on lines +2314 to +2316
# Skip capturing batch sizes of 1 in mix prefill-decode if
# separate_attention_routine is on. As bs=1 can treat as a
# pure decode.
Collaborator

Nice bonus mini-optimization!


mergify bot commented Jul 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 14, 2025
@ProExpertProg
Collaborator

Additionally, it would be great to get some tests in, especially for the capture/dispatch logic in CUDAGraphDispatcher and the config initialization. Perhaps we mock the model variable and check that the forward context is set correctly? And we should test config initialization for all valid input configurations (and check that flags/modes/etc. are adjusted correctly).

@mergify mergify bot removed the needs-rebase label Jul 14, 2025
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000 fhl2000 changed the title [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer [Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer Jul 14, 2025
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Thank you for the contribution! I do like the idea of making parallel support for piecewise cudagraphs and full cudagraphs a first-class citizen; I'm generally a fan of this direction. Did a first pass; will finish up soon, but please do a scrub for typos, and I do think we can dramatically simplify the dispatching logic. Right now there are a lot of different flags, which makes it a bit confusing. I think we should basically make it so that we compile full cudagraphs when we can (if enabled); then, if a compiled full cudagraph exists we use it, and if not we fall back on piecewise cudagraphs, i.e. something like:

dispatch_key = DispatchKey(
    num_reqs=...,
    num_tokens=...,
    uniform_batch=...,
)

# 1) Prefer a full CUDA graph if one exists.
if dispatch_key in self.full_cudagraphs:
    return self.full_cudagraphs[dispatch_key]

# 2) Otherwise, fall back to piecewise or direct execution.
return self.piecewise_cudagraph or self.model

the uniform_batch flag here would indicate that all of the requests in the batch have the same number of tokens; so num_reqs == num_tokens with uniform_batch would be pure decode.

Not using an is_pure_decode flag here would leave the door open for spec-decode support in the future, i.e. where "decode" steps are validating 2-4ish tokens at the same time. So if we have a speculator set up to speculate 3 tokens at a time, we could create full cudagraphs for 3*num_reqs == num_tokens and uniform_batch. Something like FlashMLA would actually support this, since the main thing it wants is a uniform batch.
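For concreteness, DispatchKey in the pseudo-code above could be a small frozen dataclass along these lines (hypothetical, not actual vLLM code):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DispatchKey:
    # Hypothetical key matching the pseudo-code above.
    num_reqs: int
    num_tokens: int
    uniform_batch: bool  # every request in the batch has the same query length
```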

cc @ProExpertProg

- [`cudagraph_capture_sizes`]
[vllm.config.CompilationConfig.cudagraph_capture_sizes]
- [`cudagraph_num_of_warmups`]
[vllm.config.CompilationConfig.cudagraph_num_of_warmups]
- [`cudagraph_copy_inputs`]
[vllm.config.CompilationConfig.cudagraph_copy_inputs]
- [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph]
Collaborator

can we deprecate these properly; i.e. keep them for now in addition to cudagraph_mode but warn that they will be deprecated and users should switch to using cudagraph_mode?

@@ -4068,14 +4083,13 @@ class CompilationConfig:
- [`custom_ops`][vllm.config.CompilationConfig.custom_ops]
- [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops]
- CudaGraph capture:
- [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph]
Collaborator

@LucasWilkinson LucasWilkinson Jul 14, 2025

can we deprecate these properly; i.e. keep them for now in addition to cudagraph_mode but warn that they will be deprecated and users should switch to using cudagraph_mode? I assume use_cudagraph would map to cudagraph_mode == PIECEWISE for the purposes of this

Collaborator

nit: maybe we can rename this file in a future PR (that just does file renaming) so we can see the diff here?

cudagraph_runtime_style: CUDAGraphRuntimeStyle
# Be aware that is_pure_decode should be default None
# for both piecewise cudagraphs and no cudagraphs.
is_pure_decode: Optional[bool] = None
Collaborator

nit: I think we should name this use_full_cudagraph. I could see cases in the future where we might want to mix piecewise + full cudagraphs but use a different heuristic than is_pure_decode, e.g. for FA3 we may want to run mixed small decodes or spec-decode using full cudagraphs but run large prefills using piecewise cudagraphs

Collaborator

is_pure_decode is not equivalent to full cudagraphs. It is the same for backends that only support pure decode batches in CG, but not otherwise. While I agree the name could be improved, I don't think full_cudagraph is better

Collaborator

ya no I agree; wrote this early in the review, sorry! I'm actually more of a fan of uniform_batch now haha, see: #20059 (review)

return CUDAGraphRuntimeStyle.NONE

def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle,
is_pure_decode: bool) -> Any:
Collaborator

ditto above; I don't think we should over-index on the is_pure_decode naming, since we may want to dispatch between full cudagraphs and piecewise on other metrics in the future

else:
assert all(op in self.compilation_config.splitting_ops
for op in ["vllm.unified_attention",
"vllm.unified_attention_with_output"]),\
Collaborator

it seems like there are multiple places where we are asserting that ["vllm.unified_attention", "vllm.unified_attention_with_output"] is in splitting_ops; I think in the near future we will also be splitting on Mamba layers (for hybrid models), in which case we will have to add that to all of these locations. I'm not quite sure why we need to assert that these ops are there; I think it would be sufficient to assert that if we are expecting piecewise compilation, we expect non-empty splitting ops

Collaborator

The reason I asked for this is because this maintains the possibility of excluding other (non-attention) ops from cudagraphs. Like maybe if we do torch.compile-based DBO and split the graph on the yield custom op or something. But I agree the repeated check is fragile: let's extract it into a property of compilation config, perhaps called is_attention_splitting or something like that


# for full cudagraph, select between mixed batches
# or pure decode batches
decode_case = self.compilation_config.separate_attention_routine\
and is_pure_decode
Collaborator

I think we should get rid of is_pure_decode here. I could see cases in the future where we might want to mix piecewise + full cudagraphs but use a different heuristic than is_pure_decode, e.g. for FA3 we may want to run mixed small decodes or spec-decode using full cudagraphs but run large prefills using piecewise cudagraphs

Honestly I find all these flags very confusing; I think a much simpler, more extensible dispatch logic would be:

dispatch_key = DispatchKey(num_reqs=..., num_tokens=..., uniform_batch=...)
if dispatch_key in self.full_cudagraphs:
     return self.full_cudagraphs[dispatch_key]
# Fall-back if a full_cudagraph isn't available 
return self.piecewise_cudagraph or self.model

the uniform_batch flag here would indicate that all of the requests in the batch have the same number of tokens; so num_reqs == num_tokens with uniform_batch would be pure decode, but this would leave the door open for spec-decode support in the future, i.e. where "decode" steps are validating 2-4ish tokens at the same time. So if the speculator is set to speculate 3 tokens at a time, we could create full cudagraphs for 3*num_reqs == num_tokens and uniform_batch. Something like FlashMLA would actually support this, since the main thing it wants is a uniform batch.

Author

Post from previous discussions.

You are right about more explicit control of full cudagraph, like all, mixed, and decode_only. But I think an extra flag may not be necessary now; after the cudagraph_mode proposed in #20283, the extra control can be achieved by extending the enum mode.

Like use NONE=0, PIECEWISE=1, FULL=2, FULL_DECODE_ONLY=3, and FULL_AND_PIECEWISE=4. Here, NONE means no cudagraph. PIECEWISE uses only piecewise cudagraphs (now the v1 default). FULL means the current strategy for maximum full-cudagraph support (with separate_attention_routine tunable to achieve mixed-only or all). FULL_DECODE_ONLY uses only one set of cudagraphs for pure decode and no cudagraph for the rest. FULL_AND_PIECEWISE means explicitly having two sets of cudagraphs, with full cudagraphs for decode-only and piecewise cudagraphs for mixed batches or anything else. In this way, separate_attention_routine is forced to true in FULL_DECODE_ONLY and FULL_AND_PIECEWISE, and cascade attention can also be supported in these two modes.
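A hedged sketch of that extended enum, purely to illustrate the proposal (not code from this PR):

```python
import enum


class CUDAGraphMode(enum.Enum):
    # Proposed extension discussed above; follow-up work, not part of this PR.
    NONE = 0                # no cudagraphs
    PIECEWISE = 1           # piecewise cudagraphs only (current v1 default)
    FULL = 2                # maximize full-cudagraph coverage, with fallbacks
    FULL_DECODE_ONLY = 3    # full cudagraphs for pure decode, no cudagraphs otherwise
    FULL_AND_PIECEWISE = 4  # full cudagraphs for decode, piecewise for everything else
```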

I think @ProExpertProg and I agree to leave this explicit control of mixed_full_cudagraphs_and_piecewise_cudagraphs to a follow-up PR to include a new cudagraph_mode like FULL_AND_PIECEWISE. Currently, the FULL mode just maximizes the support of full cudagraph with proper fallbacks, and separate_attention_routine takes effect only in this mode to tell if we want to retain a unified routine or use separate routines for different situations.

But I have to admit that separate_attention_routine seems a bit redundant now, as it would be overridden when an attention backend's cudagraph support is PURE_DECODE_ONLY or ALWAYS_UNIFIED, and not overridden only when it is ALWAYS_SEPARATE.

Author

# Fall-back if a full_cudagraph isn't available 
return self.piecewise_cudagraph or self.model

I don't think we can now fully separate the piecewise_cudagraph and the raw model easily, since they are integrated together by the torch.compile with vllm piecewise compilation currently.

the uniform_batch flag here would indicate that all of the requests in the batch have the same number of tokens; so if num_reqs == num_tokens and uniform_batch would be pure decode but this would leave a door open for spec-decode support the future i.e. where "decode" steps are validating 2-4ish tokens at the same time. So if we are speculator is set to speculate 3 tokens at a time we could create full-cudagraphs for 3*num_reqs == num_tokens and uniform_batch. Something like FlashMLA would actually support this since the main thing it wants is a uniform batch.

I do agree we should leave a door open for spec-decode, but I also think using the uniform_batch flag and num_tokens together for dispatching is somewhat confusing. The first thing is, if the speculator is set to speculate 3 tokens at a time, I guess this pattern is fixed, and we should just design a new enum representing that we are doing spec-decode, rather than checking whether 3*num_reqs == num_tokens and uniform_batch. Also, given the meaning of uniform_batch you mentioned, it is not equivalent to pure decode or spec-decode. One counterexample may be a pure prefill case where each request has 3 tokens.

Moreover, could we leave num_tokens to be handled by the cudagraph wrapper itself? Dispatching on num_tokens explicitly inside the dispatcher may be reasonable, but it is not trivial to manage all the fused cases, and we couldn't reuse the uniformly designed cudagraph wrapper here. Leaving the exact cudagraph management of each edge case to the wrapper would be good, as one cudagraph wrapper could then represent exactly one case we want to process.

Collaborator

+1 to leaving the token dispatching inside the wrapper.

Collaborator
@LucasWilkinson LucasWilkinson Jul 16, 2025

I don't think we can easily separate the piecewise_cudagraph from the raw model right now, since they are currently integrated together by torch.compile with vLLM's piecewise compilation.

That was just there as pseudocode to represent the logic; like you said, the actual code would just be the model, since the piecewise dispatching happens under that (and dispatches to eager if the batch is too large).

One counterexample is a pure prefill case where each request has 3 tokens.

That would be ok; there would be no difference between that and a spec decode from the attention perspective, so we would want it to use the full cudagraph if available in this case. We shouldn't fixate on the prefill/decode naming, since chunked prefill and spec-decode blur these lines quite a bit; those names are just colloquially used to mean large query length and near-1 query length respectively. They are useful in conversation, but I'm a bit hesitant to harden those definitions (i.e. decode being query_len == 1) into the code, especially inside this fairly core part of the code.

Moreover, could we leave num_tokens to be handled by the cudagraph wrapper itself?
+1 to leaving the token dispatching inside the wrapper.

I like the idea of having it as part of the dispatch keys because it leaves the door open to having full cudagraphs only for small batch sizes; I could see a world where we might only compile full cudagraphs for up to BS 128 and then use piecewise for everything larger.

Collaborator

But we already kind of support capturing full cudagraphs only up to a size smaller than max_num_seqs via the configurable compilation_config.cudagraph_capture_sizes. If that is the case, then here we would dispatch to the full-cudagraph wrapper only for it to fall back to eager inside the wrapper. I think this is very confusing for readers/users. I think we should use this opportunity to flatten and simplify the dispatching; as someone who is moderately familiar with cudagraphs in vLLM, I find all this dispatching very confusing.

This is where I think having a registry of the captured cudagraphs, with keys representing the workloads they support, could be far less confusing: we try to dispatch to a full cudagraph first; if one doesn't exist, dispatch to a piecewise cudagraph; and if one doesn't exist, dispatch to eager (I am aware the last 2 steps currently happen in the piecewise backend). A rough sketch of this idea follows.
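
Something like the following, where the registry class and its keys are purely illustrative:

class CudagraphRegistry:
    # Illustrative registry: captured graphs keyed by the workload they support.

    def __init__(self):
        self.full = {}       # key -> callable that replays a full cudagraph
        self.piecewise = {}  # key -> callable backed by piecewise cudagraphs

    def dispatch(self, key, eager_fn):
        # Prefer a full cudagraph, then piecewise, then eager execution.
        if key in self.full:
            return self.full[key]
        if key in self.piecewise:
            return self.piecewise[key]
        return eager_fn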

While I do see how later we might want to do more complex dispatching and size will certainly have to be involved, I think that's out of scope for this PR.

Ya, I don't think we have to go that far in this PR; I just really want to make sure we are creating the right extensible abstractions, since this PR is introducing a lot of code that would have to get refactored to enable this.

Collaborator

I just really want to make sure we are creating the right extensible abstractions since this PR

I agree with that, and thanks for driving this point. What about the following "compromise"/alternate solution: the runtime style (or mode, if we consolidate) and the dispatch key dispatch separately. The runtime style is used to select between the CUDAGraphWrapper instances (full and piecewise), and the DispatchKey dispatches between recorded cudagraphs within a wrapper, for now only covering num_tokens and uniform_batch.

The structure would look like this:

  • CUDAGraphDispatcher decides the "runtime style" and the DispatchKey (whatever new name it receives), and sets them in the forward context
  • Each CUDAGraphWrapper inspects the forward context and only captures/replays if the runtime style matches. It uses the DispatchKey to decide which cudagraph to capture/replay.

This solves the issue where we can't do cudagraph capture/replay directly in the dispatcher for piecewise cudagraphs. While piecewise cudagraphs might not need as much detailed dispatching, this would give us the flexibility to dispatch to different routines for ops that were not in splitting ops.
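
A minimal sketch of that split, assuming a forward_context object that carries the dispatcher's decision; all names here are placeholders, and warmup and static input/output buffer handling are omitted:

import torch

class CUDAGraphWrapper:
    def __init__(self, runnable, runtime_style):
        self.runnable = runnable
        self.runtime_style = runtime_style  # e.g. FULL or PIECEWISE
        self.graphs = {}  # dispatch key -> (captured graph, static outputs)

    def __call__(self, forward_context, *args):
        # Only capture/replay when the dispatcher selected this wrapper's style.
        if forward_context.cudagraph_runtime_style != self.runtime_style:
            return self.runnable(*args)
        key = forward_context.dispatch_key
        if key not in self.graphs:
            # First time we see this key: capture a new cudagraph.
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = self.runnable(*args)
            self.graphs[key] = (graph, outputs)
            return outputs
        graph, outputs = self.graphs[key]
        graph.replay()
        return outputs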

@fhl2000 no need to implement this yet, let's reach a consensus on this first

Collaborator

With this approach, the CUDAGraphWrapper instances could have even less logic and blindly trust the forward context on which cudagraph to dispatch to; if it doesn't exist yet, it gets captured. That way the dispatcher is the single source of truth on available cudagraphs.

If I'm missing something, a "single source of truth" for "available" cudagraphs a new noinit dictionary on CompilationConfig, used by both the dispatcher and the CUDAGraphWrapper instances.

Author

The structure would look like this:

  • CUDAGraphDispatcher decides the "runtime style" and the DispatchKey (whatever new name it receives), and sets them in the forward context
  • Each CUDAGraphWrapper inspects the forward context and only captures/replays if the runtime style matches. It uses the DispatchKey to decide which cudagraph to capture/replay.

Looks good to me.

With this approach, the CUDAGraphWrapper instances could have even less logic and blindly trust the forward context on which cudagraph to dispatch to; if it doesn't exist yet, it gets captured. That way the dispatcher is the single source of truth on available cudagraphs.

If I'm missing something, a "single source of truth" for "available" cudagraphs a new noinit dictionary on CompilationConfig, used by both the dispatcher and the CUDAGraphWrapper instances.

Sharing a new noinit dictionary on CompilationConfig between the dispatcher and the CUDAGraphWrapper instances seems unviable to me (or would need to be improved). While sharing a DispatchKey works for the FULL style, since one graph item is enough per dispatch key, it is not good for piecewise cudagraphs, where many graph items (roughly one per layer) correspond to one DispatchKey.

Collaborator

Yeah, sorry, I phrased that poorly/dropped a sentence. I was saying we could use that dictionary if we need a single source of truth, but hopefully we don't. I don't think the dispatch key cares about any layer-wise info, so the different piecewise backends (one per subgraph) don't need to worry about it.

cudagraph_runtime_style = self.cudagraph_dispatcher.\
    get_cudagraph_runtime_style(attention_cuda_graphs)
# Note: When cudagraph_mode is FULL and
# compilation_config.separate_attention_routine is True, as in FA2,
Collaborator

, as in -> ; in

usage_str="full/mixed"))
})
logger.debug("Full cudagraph for mixed batches initialized")
# always create full cudagraph for pure decode batches if speparate
Collaborator

speparate -> separate


mergify bot commented Jul 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 15, 2025
# runner have been done.

# Dict to store cudagraph candidates for runtime dispatching.
self.cudagraph_candidates: dict[DispatchKey, Any] = {}
Collaborator

nit: can we name DispatchKey something else? This means something very specific in PyTorch.

Comment on lines +4012 to +4024
class CUDAGraphMode(enum.Enum):
    # constants for the config of the cudagraph mode.
    NONE = 0
    PIECEWISE = 1
    FULL = 2


class CUDAGraphRuntimeStyle(enum.Enum):
    # constants for concrete cudagraph runtime style, used for
    # runtime dispatching.
    NONE = 0
    PIECEWISE = 1
    FULL = 2
Collaborator

Why do we have both?

Collaborator
@zou3519 zou3519 Jul 15, 2025

I am pretty confused about the CUDAGraphMode vs CUDAGraphRuntimeStyle. Is the reason that CUDAGraphMode=Full for FAv2 means that we do FULL for prefill but PIECEWISE for decode?

If so, would anyone want to do piecewise for prefill and piecewise for decode for FAv2?

Collaborator

+1 why do we have 2 identical enums

Author
@fhl2000 fhl2000 Jul 16, 2025

Why do we have both?
+1 why do we have 2 identical enums

Sorry for causing confusion. As explained in the code comments and in the discussion above, CUDAGraphMode is intended to configure the desired cudagraph behavior. This mode configuration would be extended in a future PR to contain more options, for example FULL_DECODE_ONLY, FULL_AND_PIECEWISE, or something like AUTO as mentioned in #20283 that automatically selects among the previous modes. See below for the meaning of these modes:

For example, use NONE=0, PIECEWISE=1, FULL=2, FULL_DECODE_ONLY=3, and FULL_AND_PIECEWISE=4. Here, NONE means no cudagraph. PIECEWISE uses only piecewise cudagraphs (the current v1 default). FULL is the current strategy of maximum full cudagraph support (with separate_attention_routine tunable to cover mixed batches only or everything). FULL_DECODE_ONLY uses a single set of cudagraphs for pure decode and no cudagraph for the rest. FULL_AND_PIECEWISE explicitly keeps two sets of cudagraphs: full cudagraphs for pure decode and piecewise cudagraphs for mixed batches and anything else.

On the other side, CUDAGraphRuntimeStyle is the actual style of cudagraph we select to run at runtime. I think there are only three styles shared among all possibilities, and they can't be extended. It is also used as a property assigned to the CUDAGraph wrapper class to correctly activate cudagraphs of the right style, because currently we can have nested CUDAGraph wrappers, i.e., the piecewise cudagraph wrappers integrated with the piecewise compiled model inside, with one wrapper wrapped outside for the full cudagraph.

They coincidentally have members with the same names right now, but this should look fine once the cudagraph mode is extended.

Collaborator

I still think a single mode enum can be enough; we just add asserts that, at runtime, we're not using any of the "mixed" modes that aren't valid as runtime styles.

Author

I am pretty confused about the CUDAGraphMode vs CUDAGraphRuntimeStyle. Is the reason that CUDAGraphMode=Full for FAv2 means that we do FULL for prefill but PIECEWISE for decode?

No. For FA2, when CUDAGraphMode is set to FULL, it means using full cudagraphs for both the mixed prefill-decode stage and the pure decode stage. However, since FA2's cudagraph support is marked as ALWAYS_SEPARATE, it prefers separate cudagraph routines for these two stages. Only when separate_attention_routine is set to False is there a single full cudagraph for mixed prefill-decode batches, which is also compatible with pure decode scenarios.
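
In other words, the decision for FA2 under FULL mode reduces to roughly the following; this is a hedged sketch with made-up names, not the PR's exact code:

def select_cudagraph_key(is_pure_decode: bool,
                         separate_attention_routine: bool) -> str:
    # With separate routines, FA2 keeps two sets of full cudagraphs:
    # one captured for pure decode, one for mixed prefill-decode batches.
    if separate_attention_routine:
        return "decode" if is_pure_decode else "mixed"
    # Otherwise a single unified set: the mixed-batch (varlen) routine
    # also serves pure decode batches.
    return "mixed"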

Author

Honestly, the current design of the FULL mode just intends to support the full cudagraph runtime style as much as it can, while falling back to the piecewise or no-cudagraph runtime style for any incompatible routine.

Author

I still think a single mode enum can be enough; we just add asserts that, at runtime, we're not using any of the "mixed" modes that aren't valid as runtime styles.

But I think fusing the cudagraph mode enum to carry both the semantics of "mode" and "runtime style" would lead to more confusion.

Collaborator
@ProExpertProg ProExpertProg Jul 16, 2025

But I think fusing the cudagraph mode enum to carry both the semantics of "mode" and "runtime style" would lead to more confusion.

I think it's a tradeoff, and I don't think the semantics are different enough. It's also easier to introduce a separate enum later if needed. To distinguish between them, all variables currently of type CUDAGraphRuntimeStyle can have runtime_mode in the name (instead of just mode) if you want the semantic meaning to be clearer.

@LucasWilkinson @zou3519 what do you think?
