
Commit b0a2595

[V1] Perf optimization for layers with KV reuse
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent: a0389e0

11 files changed (+436, -50 lines)

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
from collections.abc import Iterable
from typing import Optional, Union

import pytest
import torch
from torch import nn
from transformers import Qwen2Config

from vllm import LLM, SamplingParams
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
                                              Qwen2Model)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              extract_layer_index,
                                              maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from ...utils import fork_new_process_for_each_test

START_KV_SHARING_LAYER = 10


class Qwen2DecoderLayerWithKVSharing(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        attn_prefix = f"{prefix}.self_attn"
        layer_idx = extract_layer_index(prefix)
        kv_sharing_target_layer_name = None

        if layer_idx >= START_KV_SHARING_LAYER:
            # re-use KV cache from first 5 layers
            target_layer_idx = layer_idx % 5
            kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace(
                str(layer_idx), str(target_layer_idx))
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=attn_prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
        )

        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen2ModelWithKVSharing(Qwen2Model):

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        decode_indices = get_forward_context().decode_indices
        if decode_indices is None:
            decode_indices = torch.arange(positions.size(0),
                                          device=positions.device)

        # Forward with full inputs up to the first layer that shares KV cache
        for layer in self.layers[self.start_layer:START_KV_SHARING_LAYER]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        if decode_indices is not None:
            decode_hidden_states = hidden_states[decode_indices]
            decode_positions = positions[decode_indices]
            decode_residual = (residual[decode_indices]
                               if residual is not None else None)
        else:
            decode_hidden_states = hidden_states
            decode_positions = positions
            decode_residual = residual

        # Optimization: forward with partial inputs only for last N layers
        for layer in self.layers[START_KV_SHARING_LAYER:self.end_layer]:
            decode_hidden_states, decode_residual = layer(
                decode_positions,
                decode_hidden_states,
                decode_residual,
            )

        # Merge results back
        if decode_hidden_states is not None:
            hidden_states[decode_indices] = decode_hidden_states
        if residual is not None:
            residual[decode_indices] = decode_residual

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class TestQwen2ForCausalLM(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = Qwen2ModelWithKVSharing(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            decoder_layer_type=Qwen2DecoderLayerWithKVSharing)
        self.lm_head = self.model.embed_tokens
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)


# TODO: make it work with torch.compile
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True])
def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager):
    prompt = "What is the capital of France?"
    ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=40)
    single_prompt = [prompt]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  enforce_eager=enforce_eager)
        responses = llm.generate(single_prompt, sampling_params)
        ref_output = responses[0].outputs[0].text

        del llm
        gc.collect()
        torch.cuda.empty_cache()

        m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1")

        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
                  enforce_eager=enforce_eager)
        responses = llm.generate(single_prompt, sampling_params)
        output = responses[0].outputs[0].text
        assert output == ref_output
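
The pattern exercised above reduces to a gather/compute/scatter along the token dimension: slice out the decode rows, run the KV-sharing layers only on that slice, then write the results back before the final norm. Below is a minimal standalone sketch of that tensor mechanics in plain PyTorch; the shapes, the shared_kv_layer stand-in, and the chosen indices are illustrative, not code from this commit.

import torch

num_tokens, hidden_size = 8, 4
hidden_states = torch.randn(num_tokens, hidden_size)

# Suppose tokens 3 and 7 are the decode tokens (one per request).
decode_indices = torch.tensor([3, 7])


def shared_kv_layer(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a decoder layer that reuses another layer's KV cache.
    return torch.tanh(x)


# Gather: run the KV-sharing layers only on the decode rows.
decode_hidden = shared_kv_layer(hidden_states[decode_indices])

# Scatter: merge the partial result back so downstream ops (final norm,
# logits) still see one row per scheduled token.
hidden_states[decode_indices] = decode_hidden
assert hidden_states.shape == (num_tokens, hidden_size)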

vllm/envs.py

Lines changed: 5 additions & 1 deletion
@@ -139,6 +139,7 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False
 
 
 def get_default_cache_root():
@@ -960,7 +961,10 @@ def get_vllm_port() -> Optional[int]:
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+
+    "VLLM_V1_KV_SHARING_SKIP_PREFILL":
+    lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1",
 }
 
 # --8<-- [end:env-vars-definition]
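
Like the neighbouring entries, the new flag is parsed lazily by its lambda, so it only has to be set before the value is first read. A minimal usage sketch, assuming the attribute-style access that vllm.envs already exposes for its other flags:

import os

# Opt in to the decode-only forward pass for layers that reuse KV cache.
os.environ["VLLM_V1_KV_SHARING_SKIP_PREFILL"] = "1"

import vllm.envs as envs

# The lambda above maps the literal string "1" to True; any other value
# (or an unset variable) yields False.
assert envs.VLLM_V1_KV_SHARING_SKIP_PREFILL is True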

vllm/forward_context.py

Lines changed: 3 additions & 0 deletions
@@ -95,6 +95,7 @@ class ForwardContext:
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
     skip_cuda_graphs: bool = False
+    decode_indices: Optional[torch.Tensor] = None
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -116,6 +117,7 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
     skip_cuda_graphs: bool = False,
+    decode_indices: Optional[torch.Tensor] = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -141,6 +143,7 @@
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         skip_cuda_graphs=skip_cuda_graphs,
+        decode_indices=decode_indices,
     )
 
     try:
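
This commit only adds the decode_indices plumbing; what the model runner actually stores there is decided elsewhere. One plausible derivation, sketched with a hypothetical helper rather than code from this commit, is to take the last scheduled token of each request from query_start_loc:

import torch


def last_token_indices(query_start_loc: torch.Tensor) -> torch.Tensor:
    # query_start_loc has shape [num_reqs + 1]; request i owns the flat
    # token range [query_start_loc[i], query_start_loc[i + 1]), so its
    # last (decode) token sits at index query_start_loc[i + 1] - 1.
    return query_start_loc[1:] - 1


query_start_loc = torch.tensor([0, 4, 9, 10])  # 3 requests: 4, 5, 1 tokens
print(last_token_indices(query_start_loc))     # tensor([3, 8, 9])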

vllm/model_executor/models/qwen2.py

Lines changed: 3 additions & 1 deletion
@@ -109,6 +109,7 @@ def __init__(
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
         dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+        **attn_kwargs,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -170,7 +171,8 @@ def __init__(
             **{
                 "layer_idx": extract_layer_index(prefix),
                 "dual_chunk_attention_config": dual_chunk_attention_config,
-            } if dual_chunk_attention_config else {})
+            } if dual_chunk_attention_config else {},
+            **attn_kwargs)
 
     def forward(
         self,

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 17 additions & 4 deletions
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
 import numpy as np
 import torch
 
@@ -119,11 +121,22 @@ def reorder_batch(self, input_batch: InputBatch,
 
         return True
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        decode_only_common_attn_metadata: Optional[
+            CommonAttentionMetadata] = None,
+    ):
+        if decode_only_common_attn_metadata is not None:
+            raise NotImplementedError(
+                "CPU backend does not support decode-only attention yet.")
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
+        query_start_loc_np = (common_attn_metadata.query_start_loc_np
+                              if common_attn_metadata.query_start_loc_np
+                              is not None else self.runner.query_start_loc_np)
 
         runner = self.runner
         block_table = self.block_table
@@ -135,8 +148,8 @@ def build(self, common_prefix_len: int,
         ) if num_prompt_req < num_reqs else 0
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
-        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
-        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
+        num_prefill_tokens = query_start_loc_np[num_prompt_req].item()
+        num_decode_tokens = query_start_loc_np[num_reqs].item(
         ) - num_prefill_tokens
         slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
         block_table_tensor = block_table.get_device_tensor()