
Commit eb6bcd7

Cleanup

1 parent 8a9110d · commit eb6bcd7
File tree

6 files changed: +714 −606 lines changed

vllm/beam/beam.py

Lines changed: 28 additions & 21 deletions
@@ -1,37 +1,44 @@
 from collections.abc import AsyncGenerator
+from typing import Union
+
 from vllm.beam.debug import BeamDebugInfo
 from vllm.beam.penalty import PenaltyComputer
 import torch
 from vllm.beam.ranking import RankingComputer
+from vllm.entrypoints.openai.protocol import CompletionResponse, ErrorResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class BeamScorer:
     def __init__(self, classi_idx):
         self.penalty_computer = PenaltyComputer(classi_idx)
         self.ranking_computer = RankingComputer(classi_idx)
 
-    async def collapse_beams(self, responses: list[AsyncGenerator], chunk_num = 0, max_chunks = 4):
-        debug_info = [BeamDebugInfo() for _ in responses]
-
-        scores = torch.zeros(len(responses), dtype=torch.float)
-
-        heads = [response.choices[0].additional_heads[0] for response in responses]
-        heads_tensor = torch.tensor(heads, dtype=torch.float)
-        if len(heads_tensor) > 0:
-            penalties = self.penalty_computer.compute(heads_tensor, debug_info)
-            scores -= penalties
-
-        ranking_scores = self.ranking_computer.compute(
+    async def pick_best_beam(self, responses: list[
+        Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]]) -> Union[
+            AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
+        debug_info = [BeamDebugInfo() for _ in responses]
+
+        scores = torch.zeros(len(responses), dtype=torch.float)
+
+        heads = [response.choices[0].additional_heads[0] for response in responses]
+        heads_tensor = torch.tensor(heads, dtype=torch.float)
+        if len(heads_tensor) > 0:
+            penalties = self.penalty_computer.compute(heads_tensor, debug_info)
+            scores -= penalties
+
+        ranking_scores = self.ranking_computer.compute(
             heads_tensor, debug_info
-        )
-        scores *= ranking_scores
+        )
+        scores *= ranking_scores
 
-        for i in range(len(responses)):
-            debug_info[i].final_score = scores[i]
-            debug_info[i].content = responses[i].choices[0].text
+        for i in range(len(responses)):
+            debug_info[i].final_score = scores[i]
+            debug_info[i].content = responses[i].choices[0].text
 
-        print('debug_info', debug_info)
+        logger.debug('debug_info: %s', debug_info)
 
-        best_idx = torch.argmax(scores).item()
-        return responses[best_idx]
-
+        best_idx = torch.argmax(scores).item()
+        return responses[best_idx]
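
To make the renamed pick_best_beam flow concrete, here is a small self-contained sketch (not part of this commit): penalties from the PenaltyComputer pass subtract from each beam's score, RankingComputer scores multiply the result, and argmax picks the winner. The helper name pick_best_index and all tensor values below are hypothetical stand-ins.

import torch

# Hypothetical stand-in for the scoring flow in BeamScorer.pick_best_beam;
# the real values come from PenaltyComputer and RankingComputer in vllm.beam.
def pick_best_index(penalties: torch.Tensor, ranking_scores: torch.Tensor) -> int:
    scores = torch.zeros(len(penalties), dtype=torch.float)
    scores -= penalties          # classifier penalties push a beam's score down
    scores *= ranking_scores     # ranking scores scale the penalized score
    return torch.argmax(scores).item()

# Three candidate beams with made-up penalty and ranking values:
best = pick_best_index(torch.tensor([0.1, 0.8, 0.3]),
                       torch.tensor([1.0, 0.5, 2.0]))
print(best)  # -> 0: the least penalized beam after ranking wins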

vllm/beam/filtering.py

Lines changed: 36 additions & 22 deletions
@@ -1,10 +1,14 @@
 import asyncio
 from collections.abc import AsyncGenerator
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from urllib.request import Request
 from vllm.beam.debug import BeamDebugInfo
 import torch
-from vllm.entrypoints.openai.protocol import CompletionRequest
+from vllm.entrypoints.openai.protocol import CompletionRequest, CompletionResponse, \
+    ErrorResponse
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 def format_filter(filter_params_list):
@@ -13,65 +17,75 @@ def format_filter(filter_params_list):
         raise ValueError(f"Expect 2 keys, seeing {filter_params_list=}")
     return [dict(zip(keys, vals)) for vals in filter_params_list]
 
+
 DEFAULT_CHAR_SERVER_FILTER = format_filter(
-    [
-        ("annotations_porn", 0.1098),
-        ("annotations_racist", 0.2814),
-        ("annotations_disturbing", 0.1827),
-        ("annotations_harmful_promotes_selfharm", 0.0749),
-        ("annotations_harmful_promotes_terrorism", 0.1129),
-    ]
-)
-
+    [
+        ("annotations_porn", 0.1098),
+        ("annotations_racist", 0.2814),
+        ("annotations_disturbing", 0.1827),
+        ("annotations_harmful_promotes_selfharm", 0.0749),
+        ("annotations_harmful_promotes_terrorism", 0.1129),
+    ]
+)
+
 MAX_GENERATIONS = 10
 _CHUNK_SIZE = 16
 
+
 class BeamValidator:
     def __init__(self, classi_idx, classifier_names):
         self.classi_idx = classi_idx
         self.classifier_names = classifier_names
 
-    async def get_n_valid_beams(self, create_completion: Callable, request: CompletionRequest, raw_request: Optional[Request] = None):
+    async def get_n_valid_beams(self, create_completion: Callable,
+                                request: CompletionRequest,
+                                raw_request: Optional[Request] = None) -> list[
+        Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]]:
         request.stream = False
         n = request.n
         request.n = 1
+        # TODO(@tanuj): accept max tokens as a parameter
        request.max_tokens = _CHUNK_SIZE
         request.echo = True
         tasks = []
+        # TODO(@tanuj): deep copy request and raw_request?
         for _ in range(n):
             request = request
             tasks.append(create_completion(
                 request,
+                raw_request=raw_request,
             ))
         res = await asyncio.gather(*tasks)
         request.n = n
         beam_validator_res = self.validate(res)
         filtered_res = [r for r, valid in zip(res, beam_validator_res) if valid]
-        print('everything is filtered', len(filtered_res) == 0)
+        logger.debug("Filtered count: %d", len(filtered_res))
         if len(filtered_res) == 0:
             return res
-
+
         return filtered_res
-
-    def validate(self, responses: list[AsyncGenerator], debug_infos_G: list[BeamDebugInfo] = None):
-        #TODO(@tanuj) - share this with the beam scorer
+
+    def validate(self, responses: list[AsyncGenerator],
+                 debug_infos_G: list[BeamDebugInfo] = None):
+        # TODO(@tanuj) - share this with the beam scorer
         heads = [response.choices[0].additional_heads[0] for response in responses]
         heads_tensor = torch.tensor(heads, dtype=torch.float)
         prob_GC = torch.sigmoid(heads_tensor)
         valid_G = torch.ones(prob_GC.shape[0], dtype=torch.bool)
-
+
         for g in range(heads_tensor.shape[0]):
-            filtered = self.get_filtered_classifiers(prob_GC[g], DEFAULT_CHAR_SERVER_FILTER)
+            filtered = self.get_filtered_classifiers(prob_GC[g],
+                                                     DEFAULT_CHAR_SERVER_FILTER)
             if debug_infos_G is not None:
                 debug_infos_G[g].filtered_classifiers = filtered
-
+
             if filtered:
                 valid_G[g] = False
                 for choice in responses[g].choices:
                     choice.is_filtered = True
 
         return valid_G
-
+
     def get_filtered_classifiers(self, prob_C, filter_params) -> list[str]:
         relevant_filters = [
             (p["name"], self.classi_idx[p["name"]], p["threshold"])
@@ -87,4 +101,4 @@ def get_filtered_classifiers(self, prob_C, filter_params) -> list[str]:
         if prob_C[idx] > threshold:
             ret.append(name)
 
-        return ret
+        return ret
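
For reference, a hedged, self-contained sketch (not part of the commit) of what the filter plumbing evaluates to. The keys ("name", "threshold") are inferred from the p["name"] / p["threshold"] usage in get_filtered_classifiers, and the classi_idx mapping and logit values are made up for illustration.

import torch

# format_filter turns (name, threshold) pairs into dicts; a beam is filtered
# when any classifier's sigmoid probability exceeds its threshold.
def format_filter(filter_params_list):
    keys = ("name", "threshold")  # inferred from downstream p["name"] / p["threshold"]
    return [dict(zip(keys, vals)) for vals in filter_params_list]

filters = format_filter([("annotations_porn", 0.1098),
                         ("annotations_racist", 0.2814)])
# -> [{'name': 'annotations_porn', 'threshold': 0.1098},
#     {'name': 'annotations_racist', 'threshold': 0.2814}]

classi_idx = {"annotations_porn": 0, "annotations_racist": 1}  # made-up indices
logits = torch.tensor([-3.0, 1.2])   # one beam's classifier heads (hypothetical)
prob_C = torch.sigmoid(logits)       # ~[0.047, 0.769]

filtered = [p["name"] for p in filters
            if prob_C[classi_idx[p["name"]]] > p["threshold"]]
print(filtered)  # ['annotations_racist'] -> this beam would be marked filtered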
