
Commit a9783c3

Add piecewise cudagraph support + refactor
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent b0a2595 commit a9783c3

13 files changed: +267 additions, -195 deletions


tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 106 additions & 38 deletions

@@ -3,15 +3,19 @@

 import gc
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import List, Optional, Union

 import pytest
 import torch
 from torch import nn
 from transformers import Qwen2Config

 from vllm import LLM, SamplingParams
-from vllm.config import CacheConfig, VllmConfig
+from vllm.compilation.backends import set_model_tag
+from vllm.compilation.decorators import (skip_torch_compile,
+                                         support_torch_compile)
+from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
+                         VllmConfig)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -52,6 +56,7 @@ def __init__(
         target_layer_idx = layer_idx % 5
         kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace(
             str(layer_idx), str(target_layer_idx))
+
         self.self_attn = Qwen2Attention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -99,8 +104,72 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
+class DecoderLayerGroup(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layers: List[nn.Module],
+    ):
+        super().__init__()
+        self.layers = layers
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        return hidden_states, residual
+
+
+@skip_torch_compile
 class Qwen2ModelWithKVSharing(Qwen2Model):

+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 decoder_layer_type: type[
+                     nn.Module] = Qwen2DecoderLayerWithKVSharing):
+        super().__init__(
+            vllm_config=vllm_config,
+            prefix=prefix,
+            decoder_layer_type=decoder_layer_type,
+        )
+
+        with set_model_tag("first_layer_group"):
+            self.first_layer_group = DecoderLayerGroup(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.first_layer_group",
+                layers=self.layers[self.start_layer:START_KV_SHARING_LAYER],
+            )
+
+        with set_model_tag("second_layer_group"):
+            self.second_layer_group = DecoderLayerGroup(
+                vllm_config=vllm_config,
+                prefix=f"{prefix}.second_layer_group",
+                layers=self.layers[START_KV_SHARING_LAYER:self.end_layer],
+            )
+
+        # Pre-allocate static buffers for CUDA graph
+        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.dtype = vllm_config.model_config.dtype
+        self.device = next(self.parameters()).device
+        self.hidden_size = vllm_config.model_config.get_hidden_size()
+        self.residual = torch.zeros((self.max_num_tokens, self.hidden_size),
+                                    dtype=self.dtype,
+                                    device=self.device)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -112,46 +181,40 @@ def forward(
             hidden_states = inputs_embeds
         else:
             hidden_states = self.get_input_embeddings(input_ids)
+
         residual = None
+        first_hidden_states, first_residual = self.first_layer_group(
+            positions,
+            hidden_states,
+            residual,  # no residual, assume no pipeline parallel
+        )

         decode_indices = get_forward_context().decode_indices
         if decode_indices is None:
             decode_indices = torch.arange(positions.size(0),
                                           device=positions.device)
-
-        # Forward with full inputs up to the first layer that shares KV cache
-        for layer in self.layers[self.start_layer:START_KV_SHARING_LAYER]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
-
-        if decode_indices is not None:
-            decode_hidden_states = hidden_states[decode_indices]
-            decode_positions = positions[decode_indices]
-            decode_residual = (residual[decode_indices]
-                               if residual is not None else None)
-        else:
-            decode_hidden_states = hidden_states
-            decode_positions = positions
-            decode_residual = residual
-
-        # Optimization: forward with partial inputs only for last N layers
-        for layer in self.layers[START_KV_SHARING_LAYER:self.end_layer]:
-            decode_hidden_states, decode_residual = layer(
-                decode_positions,
-                decode_hidden_states,
-                decode_residual,
-            )
+        num_decodes = decode_indices.shape[0]
+        assert num_decodes >= 1
+        assert first_residual is not None
+
+        # CUDA graph expects static tensor addresses
+        # Copy output of first layer group to second layer group
+        self.residual[:num_decodes].copy_(first_residual[decode_indices])
+        hidden_states[:num_decodes].copy_(first_hidden_states[decode_indices])
+        positions[:num_decodes].copy_(positions[decode_indices])
+
+        second_hidden_states, second_residual = self.second_layer_group(
+            positions[:num_decodes],
+            hidden_states[:num_decodes],
+            self.residual[:num_decodes],
+        )

         # Merge results back
-        if decode_hidden_states is not None:
-            hidden_states[decode_indices] = decode_hidden_states
-            if residual is not None:
-                residual[decode_indices] = decode_residual
+        first_hidden_states[decode_indices] = second_hidden_states
+        if first_residual is not None:
+            first_residual[decode_indices] = second_residual

-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states, _ = self.norm(first_hidden_states, first_residual)
         return hidden_states


@@ -205,20 +268,24 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)


-# TODO: make it work with torch.compile
 @fork_new_process_for_each_test
-@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("enforce_eager", [False, True])
 def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
     prompt = "What is the capital of France?"
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=40)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
     single_prompt = [prompt]
+    compilation_config = CompilationConfig(
+        level=CompilationLevel.PIECEWISE
+        if not enforce_eager else CompilationLevel.NO_COMPILATION,
+        cudagraph_share_memory_pool=False)

     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

         llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-                  enforce_eager=enforce_eager)
+                  enforce_eager=enforce_eager,
+                  compilation_config=compilation_config)
         responses = llm.generate(single_prompt, sampling_params)
         ref_output = responses[0].outputs[0].text

@@ -229,7 +296,8 @@ def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
         m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1")

         llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
-                  enforce_eager=enforce_eager)
+                  enforce_eager=enforce_eager,
+                  compilation_config=compilation_config)
         responses = llm.generate(single_prompt, sampling_params)
         output = responses[0].outputs[0].text
         assert output == ref_output
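
The copy_() calls into pre-allocated tensors above are the heart of the piecewise-cudagraph pattern: a replayed CUDA graph reads and writes fixed memory addresses, so each step's variable-sized inputs must be staged into static buffers before replay. Below is a minimal, self-contained sketch of that pattern in plain PyTorch (not vLLM internals; names like static_in/static_out and the single Linear layer are illustrative only).

    import torch

    def main():
        assert torch.cuda.is_available(), "sketch requires a CUDA device"
        device = torch.device("cuda")
        max_tokens, hidden = 64, 16
        # Static buffers with fixed addresses, sized for the largest batch.
        static_in = torch.zeros(max_tokens, hidden, device=device)
        static_out = torch.empty_like(static_in)
        layer = torch.nn.Linear(hidden, hidden, device=device)

        # Warm up on a side stream, then capture one graph over the full buffer.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            static_out.copy_(layer(static_in))
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out.copy_(layer(static_in))

        # Per step: copy the (smaller) live batch into the static buffer,
        # replay, and read back only the valid slice -- mirroring the
        # `self.residual[:num_decodes].copy_(...)` lines in the test model.
        num_tokens = 5
        new_batch = torch.randn(num_tokens, hidden, device=device)
        static_in[:num_tokens].copy_(new_batch)
        graph.replay()
        result = static_out[:num_tokens].clone()
        print(result.shape)

    if __name__ == "__main__":
        main()

Rows beyond num_tokens still get computed during replay; they are simply never read, which is why the buffers are sized for max_num_batched_tokens up front.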

vllm/compilation/backends.py

Lines changed: 5 additions & 2 deletions

@@ -391,8 +391,11 @@ def __init__(
         # them, e.g. backbone (default), eagle_head, etc.
         self.prefix = prefix or model_tag

-        global global_graph_pool
-        if global_graph_pool is None:
+        if vllm_config.compilation_config.cudagraph_share_memory_pool:
+            global global_graph_pool
+            if global_graph_pool is None:
+                global_graph_pool = current_platform.graph_pool_handle()
+        else:
             global_graph_pool = current_platform.graph_pool_handle()

         # TODO: in the future, if we want to use multiple
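
For context on what this branch controls: CUDA graphs captured with the same pool handle share one memory pool, while separate handles keep each graph's allocations apart (useful when outputs of independently captured graphs must stay live at the same time, as with the two layer groups in the test). A minimal sketch in plain PyTorch; the capture() helper and the lambdas are illustrative, not vLLM code.

    import torch

    def capture(fn, static_input, pool):
        # Warm up on a side stream, then capture fn(static_input) using `pool`.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            fn(static_input)
        torch.cuda.current_stream().wait_stream(s)
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, pool=pool):
            out = fn(static_input)
        return g, out

    if __name__ == "__main__":
        assert torch.cuda.is_available(), "sketch requires a CUDA device"
        x1 = torch.randn(8, 8, device="cuda")
        x2 = torch.randn(8, 8, device="cuda")

        # Shared pool: both graphs allocate from one pool (smaller footprint,
        # but their intermediate memory may be reused across graphs).
        shared_pool = torch.cuda.graph_pool_handle()
        g1, _ = capture(lambda t: t @ t, x1, shared_pool)
        g2, _ = capture(lambda t: t.relu(), x2, shared_pool)

        # Separate pools: each graph owns its allocations outright.
        g3, _ = capture(lambda t: t @ t, x1, torch.cuda.graph_pool_handle())
        g4, _ = capture(lambda t: t.relu(), x2, torch.cuda.graph_pool_handle())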

vllm/compilation/decorators.py

Lines changed: 8 additions & 1 deletion

@@ -23,6 +23,13 @@
 _T = TypeVar("_T", bound=type[nn.Module])


+def skip_torch_compile(cls: _T) -> _T:
+    cls._skip_compile_vllm = True
+    for base in cls.__bases__:
+        base._skip_compile_vllm = True
+    return cls
+
+
 @overload
 def support_torch_compile(
     *,
@@ -156,7 +163,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         self.do_not_compile = \
             vllm_config.compilation_config.level in [
                 CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-            ] or not supports_dynamo()
+            ] or not supports_dynamo() or getattr(self, "_skip_compile_vllm", False)
         if self.do_not_compile:
             return
         compilation_counter.num_models_seen += 1
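
A toy sketch of the opt-out mechanism this hunk wires up: the decorator stamps a class attribute, and the compile-enabled __init__ later reads it back through getattr() to force do_not_compile. Class names below are illustrative stand-ins, not vLLM's.

    def skip_compile(cls):
        # Mark the class and its bases, as the diff's skip_torch_compile does.
        # Note: marking bases also opts out other subclasses of those bases.
        cls._skip_compile = True
        for base in cls.__bases__:
            base._skip_compile = True
        return cls

    class CompileEnabled:
        # Stand-in for a class wrapped by support_torch_compile.
        def __init__(self):
            self.do_not_compile = getattr(self, "_skip_compile", False)
            if self.do_not_compile:
                print(f"{type(self).__name__}: compilation skipped")
                return
            print(f"{type(self).__name__}: would compile forward()")

    @skip_compile
    class EagerOnlyModel(CompileEnabled):
        pass

    if __name__ == "__main__":
        EagerOnlyModel()  # prints "EagerOnlyModel: compilation skipped"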

vllm/config.py

Lines changed: 2 additions & 0 deletions

@@ -4029,6 +4029,8 @@ class CompilationConfig:
     """Sizes to capture cudagraph.
     - None (default): capture sizes are inferred from vllm config.
     - list[int]: capture sizes are specified as given."""
+    cudagraph_share_memory_pool: bool = True
+    """Whether to share a single global memory pool for each CUDA graph captured"""
     cudagraph_copy_inputs: bool = False
     """Whether to copy input tensors for
     cudagraph. If the caller can guarantee that the same input buffers
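
Mirroring the test added in this commit, the new flag can be set through CompilationConfig when constructing an LLM. This assumes a vLLM build that includes this commit; the flag defaults to True.

    from vllm import LLM
    from vllm.config import CompilationConfig, CompilationLevel

    compilation_config = CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        cudagraph_share_memory_pool=False,  # each captured graph gets its own pool
    )

    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
              enforce_eager=False,
              compilation_config=compilation_config)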

vllm/envs.py

Lines changed: 0 additions & 1 deletion

@@ -962,7 +962,6 @@ def get_vllm_port() -> Optional[int]:
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
     lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
-
     "VLLM_V1_KV_SHARING_SKIP_PREFILL":
     lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1",
 }

vllm/forward_context.py

Lines changed: 2 additions & 0 deletions

@@ -95,7 +95,9 @@ class ForwardContext:
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
     skip_cuda_graphs: bool = False
+
     decode_indices: Optional[torch.Tensor] = None
+    """indices used for decoding"""


 _forward_context: Optional[ForwardContext] = None
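
The test model in this commit reads the new field through the forward context and falls back to all token positions when it is unset. A trimmed sketch of that access pattern (the helper function name is illustrative; it must run inside an active forward context):

    import torch
    from vllm.forward_context import get_forward_context

    def select_decode_rows(hidden_states: torch.Tensor,
                           positions: torch.Tensor) -> torch.Tensor:
        decode_indices = get_forward_context().decode_indices
        if decode_indices is None:
            # Feature disabled: treat every token position as a decode token.
            decode_indices = torch.arange(positions.size(0),
                                          device=positions.device)
        return hidden_states[decode_indices]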

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 4 additions & 17 deletions

@@ -1,6 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Optional
-
 import numpy as np
 import torch

@@ -121,22 +119,11 @@ def reorder_batch(self, input_batch: InputBatch,

         return True

-    def build(
-        self,
-        common_prefix_len: int,
-        common_attn_metadata: CommonAttentionMetadata,
-        decode_only_common_attn_metadata: Optional[
-            CommonAttentionMetadata] = None,
-    ):
-        if decode_only_common_attn_metadata is not None:
-            raise NotImplementedError(
-                "CPU backend does not support decode-only attention yet.")
+    def build(self, common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
-        query_start_loc_np = (common_attn_metadata.query_start_loc_np
-                              if common_attn_metadata.query_start_loc_np
-                              is not None else self.runner.query_start_loc_np)

         runner = self.runner
         block_table = self.block_table
@@ -148,8 +135,8 @@ def build(
         ) if num_prompt_req < num_reqs else 0
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
-        num_prefill_tokens = query_start_loc_np[num_prompt_req].item()
-        num_decode_tokens = query_start_loc_np[num_reqs].item(
+        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
+        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
         ) - num_prefill_tokens
         slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
         block_table_tensor = block_table.get_device_tensor()
