
Commit 226edcf

Fix cudagraph issue with padding
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent 123d52b commit 226edcf

File tree: 4 files changed (+131, -54 lines)

tests/v1/e2e/conftest.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/v1/e2e/test_kv_sharing_skip_prefill.py

Lines changed: 96 additions & 17 deletions
@@ -2,8 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import gc
+import random
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import pytest
 import torch
@@ -105,7 +106,7 @@ def forward(


 @support_torch_compile
-class DecoderLayerGroup(nn.Module):
+class FirstLayerGroup(nn.Module):

     def __init__(
         self,
@@ -121,7 +122,35 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+    ):
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        return hidden_states, residual
+
+
+@support_torch_compile
+class SecondLayerGroup(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layers: list[nn.Module],
+    ):
+        super().__init__()
+        self.layers = layers
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
     ):
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -147,15 +176,17 @@ def __init__(self,
             decoder_layer_type=decoder_layer_type,
         )

+        self.vllm_config = vllm_config
+
         with set_model_tag("first_layer_group"):
-            self.first_layer_group = DecoderLayerGroup(
+            self.first_layer_group = FirstLayerGroup(
                 vllm_config=vllm_config,
                 prefix=f"{prefix}.first_layer_group",
                 layers=self.layers[self.start_layer:START_KV_SHARING_LAYER],
             )

         with set_model_tag("second_layer_group"):
-            self.second_layer_group = DecoderLayerGroup(
+            self.second_layer_group = SecondLayerGroup(
                 vllm_config=vllm_config,
                 prefix=f"{prefix}.second_layer_group",
                 layers=self.layers[START_KV_SHARING_LAYER:self.end_layer],
@@ -170,6 +201,10 @@ def __init__(self,
         self.residual = torch.zeros((self.max_num_tokens, self.hidden_size),
                                     dtype=self.dtype,
                                     device=self.device)
+        self.hidden_states = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device)

     def forward(
         self,
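The preallocated residual and hidden_states buffers above exist because, as the diff notes later, CUDA graph replay expects static tensor addresses: per-step inputs are copied into fixed buffers rather than passed as freshly allocated tensors. A minimal, self-contained sketch of that copy-into-static-buffer pattern, with made-up sizes and names rather than vLLM's actual ones:

import torch

MAX_TOKENS, HIDDEN = 8, 4
hidden_states_buf = torch.zeros(MAX_TOKENS, HIDDEN)

def stage_inputs(new_hidden_states: torch.Tensor) -> torch.Tensor:
    # copy_() writes in place, so the buffer's storage address stays fixed
    # across steps -- the property a captured CUDA graph relies on
    n = new_hidden_states.size(0)
    hidden_states_buf[:n].copy_(new_hidden_states)
    return hidden_states_buf[:n]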
@@ -183,11 +218,12 @@ def forward(
         else:
             hidden_states = self.get_input_embeddings(input_ids)

-        residual = None
+        num_input_tokens = input_ids.size(0)
+        self.hidden_states[:num_input_tokens].copy_(hidden_states)
+
         first_hidden_states, first_residual = self.first_layer_group(
             positions,
-            hidden_states,
-            residual,  # no residual, assume no pipeline parallel
+            self.hidden_states[:num_input_tokens],
         )

         decode_indices = get_forward_context().decode_indices
@@ -202,15 +238,24 @@ def forward(
             # CUDA graph expects static tensor addresses
             # Copy output of first layer group to second layer group
             self.residual[:num_decodes].copy_(first_residual[decode_indices])
-            hidden_states[:num_decodes].copy_(first_hidden_states[decode_indices])
+            self.hidden_states[:num_decodes].copy_(
+                first_hidden_states[decode_indices])
             positions[:num_decodes].copy_(positions[decode_indices])

             second_hidden_states, second_residual = self.second_layer_group(
                 positions[:num_decodes],
-                hidden_states[:num_decodes],
+                self.hidden_states[:num_decodes],
                 self.residual[:num_decodes],
             )

+            # NOTE(sarckk): Due to cudagraph padding, decode_indices may have
+            # trailing repeated indices. Attention output is only valid at the
+            # last index in this case.
+            last_index_mask = decode_indices == decode_indices[-1]
+            second_hidden_states[last_index_mask] = second_hidden_states[-1].clone(
+            )
+            second_residual[last_index_mask] = second_residual[-1].clone()
+
             # Merge results back
             first_hidden_states[decode_indices] = second_hidden_states
             if first_residual is not None:
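To make the masking step above concrete, here is a small, self-contained toy (random values and made-up shapes, not vLLM code) showing decode_indices padded with a trailing repeated index and why broadcasting the last row keeps the later scatter correct:

import torch

# decode_indices padded by repeating the last valid index (5) up to the
# captured cudagraph batch size; only the final row's output is valid for it
decode_indices = torch.tensor([0, 2, 5, 5, 5])
outputs = torch.randn(5, 4)

last_index_mask = decode_indices == decode_indices[-1]
outputs[last_index_mask] = outputs[-1].clone()

# the scatter now writes the valid value for index 5 regardless of which
# aliased row the assignment keeps
merged = torch.zeros(8, 4)
merged[decode_indices] = outputs
assert torch.equal(merged[5], outputs[-1])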
@@ -270,16 +315,43 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)


+@pytest.fixture
+def test_prompts():
+    prompt_types = ["repeat", "sentence"]
+    # Setting higher num prompts increases the chance of numerics mismatch
+    # due to matrix multiplication numerics depending on batch dimension
+    num_prompts = 10
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""please repeat the word '{word}' 10 times."""
+        elif kind == "sentence":
+            prompt = f"""please give a ten-word sentence that
+            uses the word {word} at least once."""
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append(prompt)
+
+    return prompts
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("enforce_eager", [True, False])
 def test_kv_sharing_skip_prefill(
     monkeypatch: pytest.MonkeyPatch,
     enforce_eager: bool,
-    test_prompts: list[list[dict[str, Any]]],
+    test_prompts: list[str],
 ):
     ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-    prompts = [prompt[0]['content'] for prompt in test_prompts]
     compilation_config = CompilationConfig(
         level=CompilationLevel.PIECEWISE
         if not enforce_eager else CompilationLevel.NO_COMPILATION,
@@ -293,8 +365,7 @@ def test_kv_sharing_skip_prefill(
         enforce_eager=enforce_eager,
         compilation_config=compilation_config,
     )
-    responses = llm.generate(prompts, sampling_params)
-    ref_output = responses[0].outputs[0].text
+    ref_responses = llm.generate(test_prompts, sampling_params)

     del llm
     gc.collect()
@@ -304,6 +375,14 @@ def test_kv_sharing_skip_prefill(
         enforce_eager=enforce_eager,
         compilation_config=compilation_config,
         kv_sharing_skip_prefill=True)
-    responses = llm.generate(prompts, sampling_params)
-    output = responses[0].outputs[0].text
-    assert output == ref_output
+    optimized_responses = llm.generate(test_prompts, sampling_params)
+
+    misses = 0
+
+    for ref_response, optimized_response in zip(ref_responses,
+                                                optimized_responses):
+        if ref_response.outputs[0].text != optimized_response.outputs[
+                0].text:
+            misses += 1
+
+    assert misses == 0

tests/v1/e2e/test_spec_decode.py

Lines changed: 34 additions & 0 deletions
@@ -2,13 +2,47 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations

+import random
 from typing import Any

 import pytest

 from vllm import LLM, SamplingParams


+@pytest.fixture
+def test_prompts():
+    prompt_types = ["repeat", "sentence"]
+    num_prompts = 100
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    # Generate a mixed batch of prompts, some of which can be easily
+    # predicted by n-gram matching and some which likely cannot.
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""
+            please repeat the word '{word}' 10 times.
+            give no other output than the word at least ten times in a row,
+            in lowercase with spaces between each word and without quotes.
+            """
+        elif kind == "sentence":
+            prompt = f"""
+            please give a ten-word sentence that
+            uses the word {word} at least once.
+            give no other output than that simple sentence without quotes.
+            """
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append([{"role": "user", "content": prompt}])
+
+    return prompts
+
+
 @pytest.fixture
 def sampling_config():
     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)

vllm/compilation/decorators.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 def skip_torch_compile(cls: _T) -> _T:
     cls._skip_compile_vllm = True
     for base in cls.__bases__:
-        base._skip_compile_vllm = True
+        setattr(base,"_skip_compile_vllm",True)
     return cls


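The decorators.py change swaps a direct attribute assignment for setattr; at runtime the two are equivalent. A small, self-contained toy (class names are made up, only skip_torch_compile mirrors the diff) showing the effect the decorator has on base classes:

from typing import TypeVar

_T = TypeVar("_T", bound=type)

def skip_torch_compile(cls: _T) -> _T:
    # mark the decorated class and propagate the flag to its bases
    cls._skip_compile_vllm = True
    for base in cls.__bases__:
        setattr(base, "_skip_compile_vllm", True)
    return cls

class Base:
    pass

@skip_torch_compile
class Child(Base):
    pass

# both the decorated class and its base now carry the skip flag
assert Child._skip_compile_vllm and Base._skip_compile_vllm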