Commit c8e31e6

Spec decode with probs
Signed-off-by: Andy Lo <andy@mistral.ai>
1 parent 14601f5 commit c8e31e6

6 files changed: +348, -36 lines changed

tests/v1/spec_decode/test_eagle.py

Lines changed: 30 additions & 15 deletions
@@ -294,34 +294,49 @@ def create_deterministic_logits(token_ids):
     block_table = torch.randint(0, 10, (batch_size, 10), device=device)

     sampling_metadata = mock.MagicMock()
+    # Simulate mixed greedy and non-greedy requests
+    sampling_metadata.all_greedy = False
+    sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device)

     # Call the method under test
-    result = proposer.propose(target_token_ids=target_token_ids,
-                              target_positions=target_positions,
-                              target_hidden_states=target_hidden_states,
-                              target_slot_mapping=target_slot_mapping,
-                              next_token_ids=next_token_ids,
-                              cu_num_tokens=cu_num_tokens,
-                              block_table=block_table,
-                              sampling_metadata=sampling_metadata)
-
-    assert result.shape == (batch_size, num_speculative_tokens)
+    result, result_probs = proposer.propose(
+        target_token_ids=target_token_ids,
+        target_positions=target_positions,
+        target_hidden_states=target_hidden_states,
+        target_slot_mapping=target_slot_mapping,
+        next_token_ids=next_token_ids,
+        cu_num_tokens=cu_num_tokens,
+        block_table=block_table,
+        sampling_metadata=sampling_metadata)
+
+    assert len(result) == batch_size
+    assert len(result_probs) == batch_size
+    assert all(len(tokens) == num_speculative_tokens for tokens in result)
+    assert all(r.shape == (num_speculative_tokens, vocab_size)
+               for r in result_probs)

     # Create expected tokens based on our token pattern
     if num_speculative_tokens == 1:
         # Example for num_speculative_tokens=1:
         # [[42], [60]]
-        expected_tokens = torch.tensor(
-            [[base_token_ids[0]], [base_token_ids[1]]], device=device)
+        expected_tokens = torch.tensor([[base_token_ids[0]],
+                                        [base_token_ids[1]]])
+        expected_probs = torch.zeros((batch_size, 1, vocab_size),
+                                     device=device)
+        for i, token_id in enumerate(base_token_ids):
+            expected_probs[i, 0, token_id] = 1.0
     else:
         # Example for num_speculative_tokens=3:
         # [[42, 43, 44], [60, 61, 62]]
         expected_tokens = torch.zeros((batch_size, num_speculative_tokens),
-                                      dtype=torch.int64,
-                                      device=device)
+                                      dtype=torch.int64)
+        expected_probs = torch.zeros(
+            (batch_size, num_speculative_tokens, vocab_size), device=device)
         for i in range(batch_size):
             for j in range(num_speculative_tokens):
                 expected_tokens[i, j] = base_token_ids[i] + j
+                expected_probs[i, j, base_token_ids[i] + j] = 1.0

     # Verify all tokens match our expectations
-    assert torch.equal(result, expected_tokens)
+    assert torch.equal(torch.tensor(result), expected_tokens)
+    torch.testing.assert_close(torch.stack(result_probs), expected_probs)
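Note on the contract exercised above: propose() now returns per-request draft tokens as list[list[int]] together with per-request probability tensors of shape [num_speculative_tokens, vocab_size]. A minimal standalone sketch of that shape contract (hypothetical draft_tokens/draft_probs values, not taken from this commit):

import torch

# Hypothetical outputs mirroring the shapes the test asserts.
vocab_size, num_speculative_tokens = 8, 3
draft_tokens = [[1, 2, 3], [5, 6, 7]]  # list[list[int]], one list per request
draft_probs = [  # list[torch.Tensor], one [num_speculative_tokens, vocab_size] tensor per request
    torch.softmax(torch.randn(num_speculative_tokens, vocab_size), dim=-1)
    for _ in draft_tokens
]

assert len(draft_tokens) == len(draft_probs)
for tokens, probs in zip(draft_tokens, draft_probs):
    assert len(tokens) == num_speculative_tokens
    assert probs.shape == (num_speculative_tokens, vocab_size)
    # Each row of probs is a distribution over the vocabulary.
    torch.testing.assert_close(probs.sum(dim=-1),
                               torch.ones(num_speculative_tokens))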
Lines changed: 253 additions & 0 deletions
@@ -0,0 +1,253 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import tempfile
+
+import pytest
+import torch
+
+from tests.v1.worker.test_gpu_model_runner import _schedule_new_request
+from vllm.config import VllmConfig
+from vllm.distributed import (cleanup_dist_env_and_memory,
+                              init_distributed_environment,
+                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
+from vllm.v1.engine.core import get_kv_cache_config
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+model_dir = "meta-llama/Llama-3.1-8B-Instruct"
+eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+
+
+@pytest.fixture()
+def should_do_global_cleanup_after_test(request) -> bool:
+    # So we can share the DraftModelProposer between tests
+    return False
+
+
+@pytest.fixture(scope="class")
+def monkeyclass():
+    with pytest.MonkeyPatch.context() as mp:
+        yield mp
+
+
+@pytest.fixture(scope="class")
+def spec_decode_vllm_config_and_env_setup(monkeyclass: pytest.MonkeyPatch):
+    with monkeyclass.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        vllm_config = EngineArgs(model=model_dir,
+                                 max_model_len=256,
+                                 cuda_graph_sizes=[1, 2, 4],
+                                 gpu_memory_utilization=0.8,
+                                 speculative_config={
+                                     "model": eagle_dir,
+                                     "method": "eagle",
+                                     "num_speculative_tokens": 2,
+                                 }).create_engine_config()
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            world_size=1,
+            rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            local_rank=0,
+            backend="nccl",
+        )
+        initialize_model_parallel(1, 1)
+        yield vllm_config
+        cleanup_dist_env_and_memory()
+
+
+@pytest.fixture(scope="class")
+def mock_spec_decode_model_runner(
+        spec_decode_vllm_config_and_env_setup: VllmConfig):
+    model_runner = GPUModelRunner(spec_decode_vllm_config_and_env_setup,
+                                  torch.device("cuda"))
+    model_runner.load_model()
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+
+    kv_cache_config = get_kv_cache_config(
+        spec_decode_vllm_config_and_env_setup, kv_cache_spec, 1024**3)  # 1GB
+    model_runner.initialize_kv_cache(kv_cache_config)
+    yield model_runner
+
+
+class TestSpecDecodeScheduling:
+
+    def test_spec_decode_partial_scheduling(
+            self, mock_spec_decode_model_runner: GPUModelRunner):
+        """Make sure we don't crash when the scheduler schedules only a subset
+        of the requests.
+
+        Four iterations:
+        1. Schedule both req1 (w/ 0 draft) and req2 (w/ 0 draft)
+        2. Schedule only req1 (w/ 1 draft)
+        3. Schedule both req1 (w/ 1 draft) and req2 (w/ 2 draft)
+        4. Terminate req1 and req2
+        """
+        # Schedule both req1 and req2 on the first iteration
+        scheduler_output = _schedule_new_request("req1", "req2")
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+        # Only schedule req1 on the second iteration
+        cached_req_data = CachedRequestData(
+            req_ids=["req1"],
+            resumed_from_preemption=[False],
+            new_token_ids=[[3]],
+            new_block_ids=[([], )],
+            num_computed_tokens=[3],
+        )
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_req_data,
+            num_scheduled_tokens={"req1": 2},
+            total_num_scheduled_tokens=2,
+            scheduled_spec_decode_tokens={"req1": [1001]},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+        # Schedule both req1 and req2 on the third iteration
+        cached_req_data = CachedRequestData(
+            req_ids=["req1", "req2"],
+            resumed_from_preemption=[False, False],
+            new_token_ids=[[10], [11]],
+            new_block_ids=[([], ), ([], )],
+            num_computed_tokens=[4, 3],
+        )
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_req_data,
+            num_scheduled_tokens={
+                "req1": 2,
+                "req2": 3
+            },
+            total_num_scheduled_tokens=5,
+            scheduled_spec_decode_tokens={
+                "req1": [1001],
+                "req2": [2001, 2002]
+            },
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+        # Terminate both req1 and req2
+        cached_req_data = CachedRequestData(
+            req_ids=[],
+            resumed_from_preemption=[],
+            new_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
+        )
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_req_data,
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids={"req1", "req2"},
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+    def test_spec_decode_preemption_scheduling(
+            self, mock_spec_decode_model_runner: GPUModelRunner):
+        """Make sure we don't crash when the scheduler preempts a request.
+
+        Four iterations:
+        1. Schedule req1 (w/ 0 draft) and req2 (w/ 0 draft)
+        2. Schedule req1 (w/ 1 draft) and preempt req2
+        3. Schedule req1 (w/ 1 draft) and resume req2 (w/ 2 draft)
+        4. Terminate req1 and req2
+        """
+        # Schedule both req1 and req2 on the first iteration
+        scheduler_output = _schedule_new_request("req1", "req2")
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+        # Only schedule req1 on the second iteration
+        cached_req_data = CachedRequestData(
+            req_ids=["req1"],
+            resumed_from_preemption=[False],
+            new_token_ids=[[3]],
+            new_block_ids=[([], )],
+            num_computed_tokens=[3],
+        )
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_req_data,
+            num_scheduled_tokens={"req1": 2},
+            total_num_scheduled_tokens=2,
+            scheduled_spec_decode_tokens={"req1": [1001]},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+        # Schedule both req1 and req2 on the third iteration
+        cached_req_data = CachedRequestData(
+            req_ids=["req1", "req2"],
+            resumed_from_preemption=[False, True],
+            new_token_ids=[[10], [11]],
+            new_block_ids=[([], ), ([0], )],
+            num_computed_tokens=[4, 0],
+        )
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_req_data,
+            num_scheduled_tokens={
+                "req1": 2,
+                "req2": 6
+            },
+            total_num_scheduled_tokens=8,
+            scheduled_spec_decode_tokens={
+                "req1": [1001],
+                "req2": [2001, 2002]
+            },
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
+
+        # Terminate both req1 and req2
+        cached_req_data = CachedRequestData(
+            req_ids=[],
+            resumed_from_preemption=[],
+            new_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
+        )
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=cached_req_data,
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids={"req1", "req2"},
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+        )
+        mock_spec_decode_model_runner.execute_model(scheduler_output)
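The two tests above rebuild nearly identical SchedulerOutput objects at each step. A possible helper capturing that repetition (hypothetical, not part of this commit; it only reuses the constructor fields shown above):

def _decode_step_output(cached_req_data: CachedRequestData,
                        num_scheduled_tokens: dict[str, int],
                        spec_tokens: dict[str, list[int]],
                        finished_req_ids: set[str] | None = None) -> SchedulerOutput:
    # Same field values as the hand-written SchedulerOutput instances above,
    # with the total derived from num_scheduled_tokens.
    return SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=cached_req_data,
        num_scheduled_tokens=num_scheduled_tokens,
        total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
        scheduled_spec_decode_tokens=spec_tokens,
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[0],
        finished_req_ids=finished_req_ids or set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

With such a helper, the second iteration of the partial-scheduling test would reduce to _decode_step_output(cached_req_data, {"req1": 2}, {"req1": [1001]}).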

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
-        num_common_prefix_blocks=0,
+        num_common_prefix_blocks=[0],
         finished_req_ids=set(),
         free_encoder_input_ids=[],
         structured_output_request_ids={},

vllm/v1/spec_decode/eagle.py

Lines changed: 19 additions & 12 deletions
@@ -92,7 +92,7 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
+    ) -> tuple[list[list[int]], list[torch.Tensor]]:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -183,19 +183,26 @@ def propose(
         last_hidden_states, hidden_states = ret_hidden_states
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids = logits.argmax(dim=-1)
+        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
+            logits, sampling_metadata)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1]
-            return draft_token_ids.view(-1, 1)
+            # [batch_size, 1] and [batch_size, 1, vocab_size]
+            return (
+                draft_token_ids.view(-1, 1).tolist(),
+                draft_probs.unsqueeze(1).unbind(0),
+            )

         # TODO: Currently, MTP module released by deepseek only has
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.

         # Generate the remaining draft tokens.
-        draft_token_ids_list = [draft_token_ids]
+        # [num_speculative_tokens, batch_size]
+        draft_token_ids_list: list[torch.Tensor] = [draft_token_ids]
+        # [num_speculative_tokens, batch_size, vocab_size]
+        draft_probs_list: list[torch.Tensor] = [draft_probs]

         positions = target_positions[last_token_indices]
         hidden_states = hidden_states[last_token_indices]
@@ -268,12 +275,16 @@ def propose(
                                                None)

             # TODO(wenlong): get more than one token for tree attention
-            draft_token_ids = logits.argmax(dim=-1)
+            draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
+                logits, sampling_metadata)
             draft_token_ids_list.append(draft_token_ids)
+            draft_probs_list.append(draft_probs)

         # [batch_size, num_speculative_tokens]
-        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        return draft_token_ids
+        draft_token_ids = torch.stack(draft_token_ids_list, dim=1).tolist()
+        # [batch_size, num_speculative_tokens, vocab_size]
+        draft_probs_list = torch.stack(draft_probs_list, dim=1).unbind(0)
+        return draft_token_ids, draft_probs_list

     @staticmethod
     def prepare_inputs(
@@ -398,10 +409,6 @@ def validate_same_kv_cache_group(self,
         ) == 1, "All eagle layers should belong to the same kv cache group"


-# NOTE(woosuk): Currently, the below code is not used and we always use argmax
-# to sample the draft tokens. We will use this after we find a way to manage
-# the draft prob tensor.
-# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
 # We should refactor this to reuse the same sampling implementation.
 def compute_probs_and_sample_next_token(
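compute_probs_and_sample_next_token is only referenced in this diff; its body lives further down in eagle.py and is not shown here. As a rough, illustrative sketch only (not the vLLM implementation), mixed greedy/non-greedy sampling of the kind the test sets up, where a temperature of -1 marks a greedy request, could look like:

import torch

def sample_with_probs_sketch(logits: torch.Tensor,
                             temperature: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative only: temperature == -1 marks a greedy request.
    is_greedy = temperature == -1
    # Avoid dividing by -1; greedy rows are overridden below anyway.
    safe_temp = torch.where(is_greedy, torch.ones_like(temperature), temperature)
    probs = torch.softmax(logits / safe_temp.unsqueeze(-1), dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    greedy = logits.argmax(dim=-1)
    next_token_ids = torch.where(is_greedy, greedy, sampled)
    return next_token_ids, probs

# e.g. a batch of 2: first request greedy (-1), second sampled at temperature 0.7
logits = torch.randn(2, 32)
token_ids, probs = sample_with_probs_sketch(logits, torch.tensor([-1.0, 0.7]))

The real helper also has to match the main sampler's behavior (hence the FIXME about duplicated sampling logic), so treat this strictly as an outline of the idea.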
