
Commit e19a8ee

Add examples and algorithm for non-shifting, fixes some minor issues
Signed-off-by: morgendave <morgendave@gmail.com>
1 parent 2e9541f commit e19a8ee

File tree (4 files changed: +79 -38 lines)

tests/v1/e2e/test_spec_decode.py
vllm/config.py
vllm/v1/spec_decode/eagle.py
vllm/v1/worker/gpu_model_runner.py


tests/v1/e2e/test_spec_decode.py

Lines changed: 24 additions & 8 deletions
@@ -126,24 +126,38 @@ def test_ngram_correctness(
 @pytest.mark.parametrize(
     "model_setup,mm_enabled", [
         (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False, True),
         (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False, True),
         pytest.param(
             (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False),
+              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+             False, True),
             marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
         pytest.param(
             (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True),
+              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+             True, True),
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+             False, False),
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+              "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+             True, False),
             marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
     ],
-    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm",
+         "llama4_eagle_no_shift", "llama4_eagle_mm_no_shift"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
     mm_enabled: bool,
+    prefill_shift: bool,
 ):
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -156,8 +170,9 @@ def test_eagle_correctness(
         m.setenv("VLLM_USE_V1", "1")
         method, model_name, spec_model_name, tp_size = model_setup

+        max_model_len = 2048 if not mm_enabled else 4096
         ref_llm = LLM(model=model_name,
-                      max_model_len=2048,
+                      max_model_len=max_model_len,
                       tensor_parallel_size=tp_size)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
@@ -172,9 +187,10 @@ def test_eagle_correctness(
                 "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
-                "max_model_len": 2048,
+                "max_model_len": max_model_len,
+                "prefill_token_shift": prefill_shift,
             },
-            max_model_len=2048,
+            max_model_len=max_model_len,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0

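For reference, the new test parametrization drives the non-shifting mode through the `prefill_token_shift` key of `speculative_config`. A minimal usage sketch under that assumption (model and draft checkpoint names are taken from the test above; this is not an officially documented example):

from vllm import LLM, SamplingParams

# Enable EAGLE speculative decoding with non-shifted prefill tokens.
# Setting "prefill_token_shift": True restores the original shift-by-one behavior.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=2048,
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
        "max_model_len": 2048,
        "prefill_token_shift": False,
    },
)
outputs = llm.chat(
    [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
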
vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -2557,7 +2557,7 @@ class SpeculativeConfig:

     # Config for kv sharing, map from base model layer to draft layer
     # Key is draft layer, value is base layer
-    kv_sharing_mapping: SkipValidation[dict[str, str]] = None
+    kv_sharing_mapping: SkipValidation[dict[str, str]] = None  # type: ignore
     """KV copy mapping for prefill stage from base to draft"""

     def compute_hash(self) -> str:

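The `# type: ignore` silences the type checker's complaint about a `None` default on a `dict[str, str]` annotation. For context on the field itself: `kv_sharing_mapping` maps each draft-model layer to the base-model layer whose prefill KV it copies (key = draft layer, value = base layer). A hedged illustration, assuming the field is passed through `speculative_config` like other SpeculativeConfig fields; the layer names are hypothetical and depend on how the two models name their decoder layers:

# Hypothetical mapping: the draft's only decoder layer reuses prefill KV
# from the base model's last decoder layer (key = draft layer, value = base layer).
speculative_config = {
    "method": "eagle",
    "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
    "num_speculative_tokens": 3,
    "prefill_token_shift": False,  # kv sharing requires non-shifting, per the docstring below
    "kv_sharing_mapping": {
        "model.layers.0": "model.layers.31",
    },
}
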
vllm/v1/spec_decode/eagle.py

Lines changed: 45 additions & 26 deletions
@@ -104,6 +104,7 @@ def _prepare_adjusted_tensors(
         cu_num_tokens: torch.Tensor,
         decode_mask: torch.Tensor,
         full_prefill_mask: torch.Tensor,
+        partial_prefill_mask: torch.Tensor,
         prefill_first_hiddens: torch.Tensor,
         block_table: torch.Tensor,
         batch_size: int,
@@ -131,6 +132,34 @@ def _prepare_adjusted_tensors(
             tuple: (target_positions, target_hidden_states, target_slot_mapping,
                 cu_num_tokens, current_pos, partial_prefill_mask)

+        Algorithm design:
+        - Suppose the target tokens are [1, 2, 3, ... N] and the next token is N+1
+        - Positions are [0, 1, 2, ... N-1]
+        - Hidden states are [h1, h2, h3, ... hN]
+        - Suppose a partial prefill covers tokens [Nm, Nm+1, ... Nm+M-1]
+        - For normal shifting:
+          - draft prefill is [2, 3, ... N+1], positions are the same as the target
+          - stacked hidden states are [h1, h2, h3, ... hN]
+          - decode tokens are [N+2, N+3, ...], hidden states are [hN+1, hN+2, ...]
+          - decode positions are [N, N+1, ...]
+          - draft partial prefill is [Nm+1, Nm+2, ... Nm+M]
+        - For non-shifting:
+          - draft full prefill is [1, 2, 3, ... N+1], positions are [0, 1, 2, ... N]
+          - stacked hidden states are [hN, h1, h2, h3, ... hN]
+          - decode tokens are [N+2, N+3, ...], hidden states are [hN+1, hN+2, ...]
+          - decode positions are [N+1, N+2, ...]
+          - draft partial prefill is [Nm, Nm+1, ... Nm+M-1]
+          - draft hidden states are [hNm-1, hNm, ... hNm+M]
+            (hNm-1 is the hidden state from the previous round)
+        - For kv sharing (non-shifting required):
+          None of the target prefill tokens need to be processed in the draft
+          prefill step, since the draft's KV for them is not needed.
+          - draft full prefill is [N+1], position is [N]
+          - stacked hidden states are [hN]
+          - decode is the same as the non-shifting decode
+          - draft partial prefill is skipped entirely
+        All other metadata (slot mapping, etc.) must be regenerated from the
+        adjusted positions and tokens.
         """
         # Count total number of full prefill requests to determine the
         # size needed for adjusted tensors
@@ -184,21 +213,6 @@ def _prepare_adjusted_tensors(
         # Create updated cumulative token counts
         updated_cu_num_tokens = torch.zeros_like(cu_num_tokens)

-        # Track which requests are partial prefill (no decode tokens)
-        partial_prefill_mask = torch.zeros_like(full_prefill_mask)
-
-        # Create masks for each category
-        has_decode_mask = torch.zeros(batch_size,
-                                      dtype=torch.bool,
-                                      device=decode_mask.device)
-        for i in range(batch_size):
-            start_idx = cu_num_tokens[i].item()
-            end_idx = cu_num_tokens[i + 1].item()
-            has_decode_mask[i] = decode_mask[start_idx:end_idx].any().item()
-
-        # Category 1: Partial prefill (no decode tokens)
-        partial_prefill_mask = ~has_decode_mask
-
         # Process batched operations using masks
         current_pos = 0
         cu_num_tokens_index = 0
@@ -368,6 +382,7 @@ def propose(
         mm_embeds: Optional[list[torch.Tensor]] = None,
         decode_mask: torch.Tensor = None,
         full_prefill_mask: torch.Tensor = None,
+        partial_prefill_mask: torch.Tensor = None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -388,6 +403,17 @@ def propose(
             prefill_shift_tokens = False

         if not prefill_shift_tokens and has_prefill:
+            if (partial_prefill_mask.all()
+                    and self.draft_prefill_kv_sharing_from_base):
+                # All requests are partial prefill and
+                # KV cache sharing is enabled
+                # Skip the rest of the function
+                # and return dummy draft tokens
+                return torch.zeros(
+                    (batch_size, self.num_speculative_tokens),
+                    dtype=target_token_ids.dtype,
+                    device=target_token_ids.device,
+                )
             # Adjust the tensors for full prefill requests
             (
                 target_positions,
@@ -404,22 +430,12 @@ def propose(
                 cu_num_tokens,
                 decode_mask,
                 full_prefill_mask,
+                partial_prefill_mask,
                 prefill_first_hiddens,
                 block_table,
                 batch_size,
                 num_tokens,
             )
-            if (partial_prefill_mask.all()
-                    and self.draft_prefill_kv_sharing_from_base):
-                # All requests are partial prefill and
-                # KV cache sharing is enabled
-                # Skip the rest of the function
-                # and return dummy draft tokens
-                return torch.zeros(
-                    (batch_size, self.num_speculative_tokens),
-                    dtype=target_token_ids.dtype,
-                    device=target_token_ids.device,
-                )
             batch_size = cu_num_tokens.shape[0] - 1
         else:
             # Original behavior: shift all tokens by one
@@ -451,6 +467,9 @@ def propose(
         if not prefill_shift_tokens and has_prefill:
             # Replace the last token with the next token under non-shifting,
             # but only for non-partial prefill requests
+            # For partial prefill in non-shifting, the draft input just matches
+            # the target prefill tokens (positions and hidden states already
+            # line up), so there is no need to append the next round's token
             mask = ~partial_prefill_mask
             # if we enable copy kv then all of the partial prefills
             # are completely skipped so they won't be in last_token_indices

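To make the docstring's shifting vs. non-shifting layout concrete, here is a standalone sketch in plain Python (labels only, not vLLM code); `prev_hidden` stands in for the hidden state that the non-shifted stack is front-padded with (hN in the docstring):

def draft_prefill_layout(tokens, positions, hiddens, next_token, prev_hidden, shift):
    if shift:
        # Normal shifting: drop the first token, append the next token;
        # positions and hidden states stay aligned with the target.
        return tokens[1:] + [next_token], positions[:], hiddens[:]
    # Non-shifting: keep every target token plus the next token, extend the
    # positions by one, and stack the previous hidden state in front.
    return (tokens + [next_token],
            positions + [positions[-1] + 1],
            [prev_hidden] + hiddens)

tokens = [1, 2, 3, 4]              # target tokens 1..N with N = 4
positions = [0, 1, 2, 3]           # 0..N-1
hiddens = ["h1", "h2", "h3", "h4"]

print(draft_prefill_layout(tokens, positions, hiddens, 5, "h4", shift=True))
# ([2, 3, 4, 5], [0, 1, 2, 3], ['h1', 'h2', 'h3', 'h4'])
print(draft_prefill_layout(tokens, positions, hiddens, 5, "h4", shift=False))
# ([1, 2, 3, 4, 5], [0, 1, 2, 3, 4], ['h4', 'h1', 'h2', 'h3', 'h4'])

Under kv sharing, the docstring notes that the full-prefill input collapses further to just [N+1] with hidden [hN], and partial prefills are skipped entirely.
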
vllm/v1/worker/gpu_model_runner.py

Lines changed: 9 additions & 3 deletions
@@ -1303,8 +1303,7 @@ def execute_model(

         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
-         spec_decode_metadata,
-         num_scheduled_tokens_np,
+         spec_decode_metadata, num_scheduled_tokens_np,
          decode_mask) = (self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
@@ -1687,6 +1686,7 @@ def propose_draft_token_ids(
         # This is used in non-shifted prefill for eagle draft.
         prefill_first_hiddens = []
         full_prefill_mask = []
+        partial_prefill_mask = []
         for i, token_ids in enumerate(sampled_token_ids):
             req_id = self.input_batch.req_ids[i]
             req_state = self.requests[req_id]
@@ -1695,7 +1695,7 @@ def propose_draft_token_ids(
             # works very well for init the first prefill hidden state.
             if req_state.prefill_hidden_states is None:
                 req_state.prefill_hidden_states = target_hidden_states[
-                    cu_num_tokens[i]]
+                    cu_num_tokens[i + 1] - 1]
             prefill_first_hiddens.append(req_state.prefill_hidden_states)
             num_prompt_tokens = req_state.num_prompt_tokens
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
@@ -1707,6 +1707,7 @@ def propose_draft_token_ids(
             if token_ids:
                 # Common case.
                 next_token_id = token_ids[-1]
+                partial_prefill_mask.append(False)
             else:
                 # Partial prefill (rare case).
                 # Get the next token id from the request state.
@@ -1719,6 +1720,7 @@ def propose_draft_token_ids(
                 # as the first prefill hidden for the next round
                 req_state.prefill_hidden_states = target_hidden_states[
                     last_hidden_index]
+                partial_prefill_mask.append(True)
             next_token_ids.append(next_token_id)
         next_token_ids = torch.tensor(next_token_ids,
                                       dtype=torch.int32,
@@ -1727,6 +1729,9 @@ def propose_draft_token_ids(
         full_prefill_mask = torch.tensor(full_prefill_mask,
                                          dtype=torch.bool,
                                          device=self.device)
+        partial_prefill_mask = torch.tensor(partial_prefill_mask,
+                                            dtype=torch.bool,
+                                            device=self.device)
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
@@ -1740,6 +1745,7 @@ def propose_draft_token_ids(
             prefill_first_hiddens=prefill_first_hiddens,
             decode_mask=decode_mask,
             full_prefill_mask=full_prefill_mask,
+            partial_prefill_mask=partial_prefill_mask,
         )
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids

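The runner-side change boils down to: a request whose sampled_token_ids entry is empty this step is still mid-prefill, so it is flagged in partial_prefill_mask (and its next token and first-prefill hidden state come from the request state instead). A minimal sketch of just the mask-building logic, using illustrative names rather than the runner's actual API:

import torch

def build_partial_prefill_mask(sampled_token_ids: list[list[int]],
                               device: torch.device) -> torch.Tensor:
    # A request that produced no sampled token this step is a partial prefill.
    mask = [len(token_ids) == 0 for token_ids in sampled_token_ids]
    return torch.tensor(mask, dtype=torch.bool, device=device)

# Request 0 finished its prefill and sampled token 42; request 1 is mid-prefill.
print(build_partial_prefill_mask([[42], []], torch.device("cpu")))
# tensor([False,  True])
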