
Commit 7c81f06

Author: Your Name
Commit message: faster
1 parent 05f4cbf, commit 7c81f06

File tree

6 files changed: +90 -33 lines changed


tests/beam/__init__.py

Whitespace-only changes.

tests/beam/test_beam.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import pytest
+
+from vllm.beam.beam import BeamScorer
+from vllm.entrypoints.openai.protocol import CompletionResponse, CompletionResponseChoice, EmbeddingResponse, UsageInfo
+
+classi_idx = {
+    "annotations_sexually_suggestive": 0,
+    "annotations_racist": 1,
+}
+
+@pytest.mark.asyncio
+async def test_beam_scorer():
+    responses = [CompletionResponse(
+        choices=[CompletionResponseChoice(text="Hello", index=0, logprobs=None, finish_reason="length", additional_heads=[[10000, 0, 0]],),],
+        model="test",
+        usage=UsageInfo(),
+    ),
+        CompletionResponse(
+            choices=[CompletionResponseChoice(text="Hello", index=0, logprobs=None, finish_reason="length",
+                                              additional_heads=[[-100, 0, 0]],),],
+            model="test",
+            usage=UsageInfo(),
+        )
+    ]
+
+    scorer = BeamScorer(classi_idx)
+    res = await scorer.pick_best_beam(responses)
+    assert res == responses[1]
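
For context, a rough sketch of why the second response is expected to win (assumptions only, not part of the commit): treating each additional_heads row as classifier logits, the way vllm/beam/filtering.py applies torch.sigmoid to them, the beam whose "annotations_sexually_suggestive" logit is very negative carries the smaller penalty.

import torch

# Hypothetical illustration: column 0 is "annotations_sexually_suggestive" per classi_idx.
heads = torch.tensor([[10000.0, 0.0, 0.0], [-100.0, 0.0, 0.0]])
probs = torch.sigmoid(heads)   # per-beam, per-class probabilities
# probs[0, 0] is ~1.0 (heavily flagged), probs[1, 0] is ~0.0,
# so the second beam should carry the smaller penalty and be picked.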

tests/beam/test_beam_meow.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import pytest
+
+from vllm.beam.beam import BeamScorer
+from vllm.beam.penalty import MEOW_CLASSI_IDX
+from vllm.entrypoints.openai.protocol import CompletionResponse, CompletionResponseChoice, EmbeddingResponse, UsageInfo
+
+@pytest.fixture()
+async def meow_random_beams():
+    return (
+        " Aizawa: You haven't given me your name, age, and quirk",
+
+    )
+@pytest.mark.asyncio
+async def test_beam_scorer():
+    responses = [CompletionResponse(
+        choices=[CompletionResponseChoice(text="Hello", index=0, logprobs=None, finish_reason="length", additional_heads=[[10000, 0, 0]],),],
+        model="test",
+        usage=UsageInfo(),
+    ),
+        CompletionResponse(
+            choices=[CompletionResponseChoice(text="Hello", index=0, logprobs=None, finish_reason="length",
+                                              additional_heads=[[-100, 0, 0]],),],
+            model="test",
+            usage=UsageInfo(),
+        )
+    ]
+
+    scorer = BeamScorer(MEOW_CLASSI_IDX)
+    res = await scorer.pick_best_beam(responses)
+    assert res == responses[1]

vllm/beam/beam.py

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 from vllm.beam.penalty import PenaltyComputer
 import torch
 from vllm.beam.ranking import RankingComputer
-from vllm.entrypoints.openai.protocol import CompletionResponse, ErrorResponse
+from vllm.entrypoints.openai.protocol import CompletionResponse, ErrorResponse, CompletionResponseChoice
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -17,13 +17,13 @@ def __init__(self, classi_idx):
         self.ranking_computer = RankingComputer(classi_idx)
 
     async def pick_best_beam(self, responses: list[
-        Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]]) -> Union[
-        AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
+        Union[AsyncGenerator[str, None], CompletionResponseChoice, ErrorResponse]]) -> Union[
+        AsyncGenerator[str, None], CompletionResponseChoice, ErrorResponse]:
         debug_info = [BeamDebugInfo() for _ in responses]
 
         scores = torch.zeros(len(responses), dtype=torch.float)
 
-        heads = [response.choices[0].additional_heads[0] for response in responses]
+        heads = [response.additional_heads[0] for response in responses]
         heads_tensor = torch.tensor(heads, dtype=torch.float)
         if len(heads_tensor) > 0:
             penalties = self.penalty_computer.compute(heads_tensor, debug_info)
@@ -36,7 +36,7 @@ async def pick_best_beam(self, responses: list[
 
         for i in range(len(responses)):
             debug_info[i].final_score = scores[i]
-            debug_info[i].content = responses[i].choices[0].text
+            debug_info[i].content = responses[i].text
 
         logger.debug('debug_info: %s', debug_info)
 
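
With this change pick_best_beam receives CompletionResponseChoice objects rather than full CompletionResponse objects, so each beam's classifier logits are read straight from choice.additional_heads[0]. A minimal stand-in for the selection pattern, assuming sigmoid-style penalties rather than the real PenaltyComputer/RankingComputer:

import torch
from types import SimpleNamespace

def pick_best_stub(choices):
    # Stack each beam's head logits, penalize high annotation probabilities,
    # and keep the least-penalized choice (a stand-in, not the real scoring).
    heads = torch.tensor([c.additional_heads[0] for c in choices], dtype=torch.float)
    penalty = torch.sigmoid(heads).sum(dim=-1)
    return choices[int(torch.argmin(penalty))]

beams = [SimpleNamespace(additional_heads=[[10000, 0, 0]]),
         SimpleNamespace(additional_heads=[[-100, 0, 0]])]
assert pick_best_stub(beams) is beams[1]   # matches the expectation in tests/beam/test_beam.py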

vllm/beam/filtering.py

Lines changed: 22 additions & 23 deletions
@@ -7,7 +7,7 @@
 from starlette.datastructures import MutableHeaders
 
 from vllm.entrypoints.openai.protocol import CompletionRequest, CompletionResponse, \
-    ErrorResponse
+    ErrorResponse, CompletionResponseChoice
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
@@ -36,32 +36,32 @@ async def get_n_valid_beams(self, create_completion: Callable,
                                 request: CompletionRequest,
                                 chunk_num: int,
                                 raw_request: Optional[Request] = None) -> list[
-        Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]]:
+        Union[AsyncGenerator[str, None], CompletionResponseChoice, ErrorResponse]]:
         request.stream = False
-        n = request.n if request.n > 1 else _DEFAULT_BEAM_SIZE
-        request.n = 1
-        # TODO(@tanuj): accept max tokens as a parameter
+        original_n = request.n
+        request.n = request.n if request.n > 1 else _DEFAULT_BEAM_SIZE
         request.max_tokens = _CHUNK_SIZE
         request.echo = True
         original_request_id = None
         if raw_request is not None:
             original_request_id = raw_request.headers.get("X-Request-Id", None)
-
-        tasks = []
-        # TODO(@tanuj): deep copy request and raw_request?
-        for _ in range(n):
-            if original_request_id is not None:
-                mh = MutableHeaders(scope=raw_request.scope)
-                del mh["x-request-id"]
-                if hasattr(raw_request, "_headers"):
-                    delattr(raw_request, "_headers")
-
-            tasks.append(create_completion(
+
+        if original_request_id is not None:
+            mh = MutableHeaders(scope=raw_request.scope)
+            del mh["x-request-id"]
+            if hasattr(raw_request, "_headers"):
+                delattr(raw_request, "_headers")
+
+        raw_res = await create_completion(
             request,
             raw_request=raw_request,
-        ))
-        res = await asyncio.gather(*tasks)
-        request.n = n
+        )
+
+        if isinstance(raw_res, ErrorResponse):
+            return raw_res
+
+        res = raw_res.choices
+        request.n = original_n
         beam_validator_res = self.validate(res)
         if isinstance(beam_validator_res, ErrorResponse):
             return beam_validator_res
@@ -73,7 +73,7 @@ async def get_n_valid_beams(self, create_completion: Callable,
 
         return filtered_res
 
-    def validate(self, responses: list[AsyncGenerator],
+    def validate(self, responses: list[CompletionResponseChoice | ErrorResponse],
                  debug_infos_G: list[BeamDebugInfo] = None):
         error_responses = [r for r in responses if isinstance(r, ErrorResponse)]
         print(f"error_responses: {error_responses}")
@@ -86,7 +86,7 @@ def validate(self, responses: list[AsyncGenerator],
             )
 
         # TODO(@tanuj) - share this with the beam scorer
-        heads = [response.choices[0].additional_heads[0] for response in responses]
+        heads = [response.additional_heads[0] for response in responses]
         heads_tensor = torch.tensor(heads, dtype=torch.float)
         prob_GC = torch.sigmoid(heads_tensor)
         valid_G = torch.ones(prob_GC.shape[0], dtype=torch.bool)
@@ -99,8 +99,7 @@ def validate(self, responses: list[AsyncGenerator],
 
             if filtered:
                 valid_G[g] = False
-                for choice in responses[g].choices:
-                    choice.is_filtered = True
+                responses[g].is_filtered = True
 
         return valid_G
 
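
The speedup in this file comes from collapsing the per-beam fan-out: instead of issuing several single-choice completions and gathering them with asyncio.gather, the request's n is raised to the beam size and a single create_completion call returns every beam as a choice. A self-contained sketch of that shape, using hypothetical stand-in types rather than vLLM's real protocol classes:

import asyncio
from dataclasses import dataclass, field

@dataclass
class FakeChoice:
    text: str
    is_filtered: bool = False

@dataclass
class FakeResponse:
    choices: list = field(default_factory=list)

async def fake_create_completion(n: int) -> FakeResponse:
    # One round trip yields n choices, mirroring request.n = _DEFAULT_BEAM_SIZE.
    return FakeResponse(choices=[FakeChoice(text=f"beam-{i}") for i in range(n)])

async def main():
    res = await fake_create_completion(n=4)
    beams = res.choices            # the list of choices get_n_valid_beams now validates
    print([b.text for b in beams])

asyncio.run(main())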

vllm/entrypoints/openai/serving_completion.py

Lines changed: 5 additions & 5 deletions
@@ -104,7 +104,7 @@ async def _process_prefix(request: CompletionRequest):
         input_str_len = len(res.choices[0].text)
 
         async def _should_stop(final):
-            return final.choices[0].finish_reason == "stop" or final.choices[0].is_filtered
+            return final.finish_reason == "stop" or final.is_filtered
 
         max_chunks = math.ceil(request.max_tokens / _CHUNK_SIZE)
         async def _chunk_generator():
@@ -121,12 +121,12 @@ async def _chunk_generator():
                     break
 
                 final = await self.beam_scorer.pick_best_beam(beams)
-                request.prompt = final.choices[0].text
+                request.prompt = final.text
                 should_stop = await _should_stop(final)
-                final.choices[0].text = final.choices[0].text[input_str_len:]
-                output = final.choices[0].text
+                final.text = final.text[input_str_len:]
+                output = final.text
                 if self.request_logger:
-                    logger.info(f"yielding chunk {num_chunks} text: {final.choices[0].text}")
+                    logger.info(f"yielding chunk {num_chunks} text: {final.text}")
                 yield f"data: {final.model_dump_json()}\n\n"
 
                 if should_stop:
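
Since pick_best_beam now returns a single choice, the chunk loop reads text, finish_reason, and is_filtered directly off it. A condensed, synchronous sketch of that loop with stand-in types (an assumption-laden illustration, not the actual generator):

from dataclasses import dataclass

@dataclass
class Choice:
    text: str
    finish_reason: str = "length"
    is_filtered: bool = False

def run_chunks(pick_best, max_chunks: int, prompt: str):
    input_str_len = len(prompt)
    for _ in range(max_chunks):
        final = pick_best(prompt)              # best beam for the current prefix
        prompt = final.text                    # the next chunk continues from the full text
        yield final.text[input_str_len:]       # strip the echoed prefix, as the real loop does
        if final.finish_reason == "stop" or final.is_filtered:
            break

def best(p):
    # Toy picker: each "beam" just appends a token and then stops.
    return Choice(text=p + " tok", finish_reason="stop")

print(list(run_chunks(best, max_chunks=3, prompt="Hello")))   # [' tok']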
