
Commit d8848a2

[V1] Support ngram spec decode
Signed-off-by: ponix-j <657511300@qq.com>
1 parent 00e0243 commit d8848a2

File tree

11 files changed: +1707, -10 lines changed


tests/sample/__init__.py

Whitespace-only changes.

tests/sample/test_rejection_sampler.py

Lines changed: 611 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import random
from typing import Any

import pytest

from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
    prompt_types = ["repeat", "sentence"]
    num_prompts = 100
    prompts = []

    random.seed(0)
    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)

    # Generate a mixed batch of prompts, some of which can be easily
    # predicted by n-gram matching and some which likely cannot.
    for kind in random_prompt_type_choices:
        word_choices = ["test", "temp", "hello", "where"]
        word = random.choice(word_choices)
        if kind == "repeat":
            prompt = f"""
            please repeat the word '{word}' 10 times.
            give no other output than the word at least ten times in a row,
            in lowercase with spaces between each word and without quotes.
            """
        elif kind == "sentence":
            prompt = f"""
            please give a ten-word sentence that
            uses the word {word} at least once.
            give no other output than that simple sentence without quotes.
            """
        else:
            raise ValueError(f"Unknown prompt type: {kind}")
        prompts.append([{"role": "user", "content": prompt}])

    return prompts


@pytest.fixture
def sampling_config():
    return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
    return "meta-llama/Llama-3.1-8B-Instruct"


def eagle_model_name():
    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"


def eagle3_model_name():
    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"


def test_ngram_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
):
    '''
    Compare the outputs of an original LLM and a speculative LLM.
    They should be the same when using ngram speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=1024)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

        spec_llm = LLM(
            model=model_name,
            speculative_config={
                "method": "ngram",
                "prompt_lookup_max": 5,
                "prompt_lookup_min": 3,
                "num_speculative_tokens": 3,
            },
            max_model_len=1024,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 70% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.7 * len(ref_outputs))
        del spec_llm


@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
    monkeypatch: pytest.MonkeyPatch,
    test_prompts: list[list[dict[str, Any]]],
    sampling_config: SamplingParams,
    model_name: str,
    use_eagle3: bool,
):
    '''
    Compare the outputs of an original LLM and a speculative LLM.
    They should be the same when using EAGLE speculative decoding.
    '''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        ref_llm = LLM(model=model_name, max_model_len=2048)
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
        del ref_llm

        spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
        spec_llm = LLM(
            model=model_name,
            trust_remote_code=True,
            speculative_config={
                "method": "eagle3" if use_eagle3 else "eagle",
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": 2048,
            },
            max_model_len=2048,
        )
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
        matches = 0
        misses = 0
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
                matches += 1
            else:
                misses += 1
                print(f"ref_output: {ref_output.outputs[0].text}")
                print(f"spec_output: {spec_output.outputs[0].text}")

        # Heuristic: expect at least 66% of the prompts to match exactly.
        # Upon failure, inspect the outputs to check for inaccuracy.
        assert matches > int(0.66 * len(ref_outputs))
        del spec_llm
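
For context on the speculative_config used in test_ngram_correctness above: the "ngram" method drafts tokens by prompt lookup, matching the trailing n-gram of the context (between prompt_lookup_min and prompt_lookup_max tokens long) against an earlier occurrence and proposing the tokens that followed it, up to num_speculative_tokens. Below is a minimal, self-contained sketch of that idea; it is not vLLM's implementation, and the token IDs are made up.

def propose_ngram_draft(token_ids: list[int], min_n: int, max_n: int,
                        num_speculative_tokens: int) -> list[int]:
    """Toy prompt-lookup proposer: find an earlier occurrence of the
    trailing n-gram and return the tokens that followed it."""
    for n in range(max_n, min_n - 1, -1):  # prefer longer n-grams
        if len(token_ids) <= n:
            continue
        tail = token_ids[-n:]
        # Scan earlier occurrences, most recent first.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == tail:
                follow = token_ids[start + n:start + n + num_speculative_tokens]
                if follow:
                    return follow
    return []


# A repetitive context: the trailing n-gram repeats earlier, so the tokens
# that followed it become the draft.
print(propose_ngram_draft([7, 8, 9, 7, 8, 9, 7, 8, 9], 3, 5, 3))  # [7, 8, 9]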

vllm_ascend/attention/attention_v1.py

Lines changed: 5 additions & 0 deletions
@@ -111,6 +111,7 @@ class AscendMetadata:
     block_tables: torch.Tensor
     # (batch_size,). The sequence length per sequence. Sequence length means
     # the computed tokens + new tokens None if it is a decoding.
+    query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
     # Maximum query length in the batch. None for decoding.
@@ -141,6 +142,9 @@ def reorder_batch(self, input_batch: "InputBatch",

     def build(self, num_reqs, num_actual_tokens, max_query_len,
               common_prefix_len):
+        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
+        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+                                                 non_blocking=True)
         if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
             block_table = (self.runner.input_batch.block_table.
                            get_device_tensor()[:num_reqs])
@@ -159,6 +163,7 @@ def build(self, num_reqs, num_actual_tokens, max_query_len,

         attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
                                        block_tables=block_table,
+                                       query_start_loc=query_start_loc,
                                        query_lens=query_lens,
                                        seq_lens=seq_lens,
                                        max_query_len=max_query_len,
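
The new query_start_loc metadata field holds the cumulative start offset of each request's query tokens in the flattened batch; this is the same [batch_size + 1] layout that the spec-decode patch below consumes as cu_target_query_lens. A minimal illustration with made-up lengths (plain PyTorch, not vllm-ascend code):

import torch

# Hypothetical per-request query lengths for a batch of three requests.
query_lens = torch.tensor([3, 4, 5])
query_start_loc = torch.zeros(query_lens.numel() + 1, dtype=torch.long)
query_start_loc[1:] = torch.cumsum(query_lens, dim=0)
print(query_start_loc)  # tensor([ 0,  3,  7, 12])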

vllm_ascend/patch/spec_decode/__init__.py

Whitespace-only changes.

vllm_ascend/patch/spec_decode/patch_common/__init__.py

Whitespace-only changes.
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn

from vllm.v1.spec_decode.eagle import EagleProposer


@staticmethod
def prepare_inputs(
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # cu_target_query_lens: [0, a, a + b, a + b + c]
    # num_rejected_tokens: [n1, n2, n3]
    # num_tokens_per_req: [a - n1, b - n2, c - n3]
    # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    # token_indices: [0, 1, ..., a - n1 - 1,
    #                 a, a + 1, ..., a + b - n2 - 1,
    #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = (cu_target_query_lens[1:] -
                         cu_target_query_lens[:-1])
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens

    cu_num_tokens = torch.empty_like(cu_target_query_lens)
    torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
    cu_num_tokens[0] = 0

    # FIXME(woosuk): Avoid synchronization.
    num_tokens = cu_num_tokens[-1].item()
    token_indices = torch.empty(
        num_tokens,
        dtype=torch.int32,
        device=cu_num_tokens.device,
    )

    BLOCK_SIZE = 1024
    prepare_input_pytorch(
        token_indices,
        cu_target_query_lens,
        cu_num_tokens,
        block_size=BLOCK_SIZE,
    )
    return cu_num_tokens, token_indices


def prepare_input_pytorch(
    out_ptr: torch.Tensor,
    cu_query_lens: torch.Tensor,
    cu_num_tokens: torch.Tensor,
    block_size: int,
):
    num_pids = cu_num_tokens.shape[0] - 1

    for pid in range(num_pids):
        start_pos = cu_num_tokens[pid].item()
        end_pos = cu_num_tokens[pid + 1].item()
        num_tokens = end_pos - start_pos

        index_start = cu_query_lens[pid].item()
        # Number of block_size-sized chunks needed to cover num_tokens.
        num_blocks = (num_tokens + block_size - 1) // block_size

        for i in range(num_blocks):
            # Offsets handled by this chunk.
            offset = i * block_size + torch.arange(
                0, block_size, dtype=out_ptr.dtype,
                device=cu_query_lens.device)
            global_indices = start_pos + offset
            values = index_start + offset
            mask = offset < num_tokens
            out_ptr[global_indices[mask]] = values[mask]


EagleProposer.prepare_inputs = prepare_inputs
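
A small worked example of the transformation the patched prepare_inputs performs, using made-up numbers on CPU; it assumes the patch module above has been imported so the assignment to EagleProposer.prepare_inputs has taken effect:

import torch

from vllm.v1.spec_decode.eagle import EagleProposer

# Three requests with query lengths a=3, b=4, c=5 and rejected draft
# counts [n1, n2, n3] = [1, 0, 2].
cu_target_query_lens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
    cu_target_query_lens, num_rejected_tokens)
print(cu_num_tokens)   # tensor([0, 2, 6, 9], dtype=torch.int32)
print(token_indices)   # tensor([0, 1, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)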

vllm_ascend/sample/__init__.py

Whitespace-only changes.
