Commit 9d53c55

Author: Your Name (committed)
Commit message: fixes
1 parent a5d9dd2 · commit 9d53c55

File tree: 2 files changed (+32 / -23 lines)

vllm/beam/beam.py

Lines changed: 28 additions & 21 deletions
@@ -1,37 +1,44 @@
 from collections.abc import AsyncGenerator
+from typing import Union
+
 from vllm.beam.debug import BeamDebugInfo
 from vllm.beam.penalty import PenaltyComputer
 import torch
 from vllm.beam.ranking import RankingComputer
+from vllm.entrypoints.openai.protocol import CompletionResponse, ErrorResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)


 class BeamScorer:
     def __init__(self, classi_idx):
         self.penalty_computer = PenaltyComputer(classi_idx)
         self.ranking_computer = RankingComputer(classi_idx)

-    async def collapse_beams(self, responses: list[AsyncGenerator], chunk_num = 0, max_chunks = 4):
-        debug_info = [BeamDebugInfo() for _ in responses]
-
-        scores = torch.zeros(len(responses), dtype=torch.float)
-
-        heads = [response.choices[0].additional_heads[0] for response in responses]
-        heads_tensor = torch.tensor(heads, dtype=torch.float)
-        if len(heads_tensor) > 0:
-            penalties = self.penalty_computer.compute(heads_tensor, debug_info)
-            scores -= penalties
-
-        ranking_scores = self.ranking_computer.compute(
+    async def pick_best_beam(self, responses: list[
+        Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]]) -> Union[
+        AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
+        debug_info = [BeamDebugInfo() for _ in responses]
+
+        scores = torch.zeros(len(responses), dtype=torch.float)
+
+        heads = [response.choices[0].additional_heads[0] for response in responses]
+        heads_tensor = torch.tensor(heads, dtype=torch.float)
+        if len(heads_tensor) > 0:
+            penalties = self.penalty_computer.compute(heads_tensor, debug_info)
+            scores -= penalties
+
+        ranking_scores = self.ranking_computer.compute(
             heads_tensor, debug_info
-        )
-        scores *= ranking_scores
+        )
+        scores += ranking_scores

-        for i in range(len(responses)):
-            debug_info[i].final_score = scores[i]
-            debug_info[i].content = responses[i].choices[0].text
+        for i in range(len(responses)):
+            debug_info[i].final_score = scores[i]
+            debug_info[i].content = responses[i].choices[0].text

-        print('debug_info', debug_info)
+        logger.debug('debug_info: %s', debug_info)

-        best_idx = torch.argmax(scores).item()
-        return responses[best_idx]
-
+        best_idx = torch.argmax(scores).item()
+        return responses[best_idx]
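Aside from renaming collapse_beams to pick_best_beam (with explicit Union typing over streaming and non-streaming responses) and swapping the stray print for logger.debug, the substantive change in this file is the score combination: ranking scores are now added to the penalty-adjusted scores rather than multiplied in, and the beam with the highest combined score is returned. A minimal standalone sketch of that arithmetic, using made-up tensors in place of the real PenaltyComputer/RankingComputer outputs:

import torch

# Hypothetical per-beam values standing in for compute() results on heads_tensor.
penalties = torch.tensor([0.2, 1.5, 0.1])       # stand-in for PenaltyComputer.compute(...)
ranking_scores = torch.tensor([0.9, 2.0, 0.7])  # stand-in for RankingComputer.compute(...)

scores = torch.zeros(3, dtype=torch.float)
scores -= penalties       # heavily penalized beams lose score
scores += ranking_scores  # additive combination (this commit; previously scores *= ranking_scores)

best_idx = torch.argmax(scores).item()
print(best_idx)  # 0, since -0.2 + 0.9 = 0.7 beats 0.5 and 0.6

With the old multiplicative form, scores held only the negated penalties at that point, so a larger ranking score made a penalized beam look even worse; the additive form keeps the two signals independent.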

vllm/entrypoints/openai/serving_completion.py

Lines changed: 4 additions & 2 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import asyncio
+import math
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
@@ -13,7 +14,7 @@
 from typing_extensions import assert_never

 from vllm.beam.beam import BeamScorer
-from vllm.beam.filtering import BeamValidator
+from vllm.beam.filtering import _CHUNK_SIZE, BeamValidator
 from vllm.beam.metrics import report_metrics
 from vllm.beam.penalty import MEOW_CLASSI_IDX, PenaltyComputer
 from vllm.config import ModelConfig
@@ -105,13 +106,14 @@ async def _process_prefix(request: CompletionRequest):
         async def _should_stop(final):
             return final.choices[0].finish_reason == "stop" or final.choices[0].is_filtered

+        max_chunks = math.ceil(request.max_tokens / _CHUNK_SIZE)
         async def _chunk_generator():
             num_chunks = 0
             should_stop = False
             output = None

             # TODO(@tanuj): calc created tokens
-            while num_chunks < 4 and not should_stop:
+            while num_chunks < max_chunks and not should_stop:
                 num_chunks += 1
                 beams = await self.beam_validator.get_n_valid_beams(create_completion=self.create_completion, request=request, raw_request=raw_request)
                 final = await self.beam_scorer.collapse_beams(beams, num_chunks)
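On the serving side, the hard-coded four-chunk cap is replaced by a per-request limit, max_chunks = ceil(request.max_tokens / _CHUNK_SIZE), so the chunk loop runs just long enough to cover the requested token budget. A worked example, assuming a hypothetical _CHUNK_SIZE of 16 (the real constant is imported from vllm.beam.filtering and is not shown in this diff):

import math

_CHUNK_SIZE = 16   # assumed value, for illustration only
max_tokens = 100   # stand-in for request.max_tokens

max_chunks = math.ceil(max_tokens / _CHUNK_SIZE)
print(max_chunks)  # ceil(6.25) = 7, so the while loop runs at most 7 iterations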
