Skip to content

tanuj/beam clean #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
23 changes: 23 additions & 0 deletions temp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
306bdbbe2 very initial proto of beam search
439547313 beam
113f84d7e Remove unused classifier files and refactor OpenAI serving completion to integrate beam scoring functionality.
2a9a5fd5f beam + filtering
edade0dda filtering
8a9110d5d completion
343ca3c75 fixes
d7448df0b Revert "Add metrics"
b8fa07bab fix
f0c43e7eb fixes in
d18c40a7c serving completion
19657376b defautl size
359104f5a error response
44bcfefc5 fix
fa57e12c3 parallel lock
f861ca1d8 delete
d1c435137 metrics
1e16a3de2 update thresholds.json
073bff364 fix formatting
d0af397a7 faster
121b965fd tmp
bdb239f60 tracing
b778c045d fix
Empty file added vllm/beam/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions vllm/beam/beam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from collections.abc import AsyncGenerator
from typing import Union

from vllm.beam.debug import BeamDebugInfo
from vllm.beam.penalty import PenaltyComputer
import torch
from vllm.beam.ranking import RankingComputer
from vllm.beam.tracing import trace_async_method
from vllm.entrypoints.openai.protocol import CompletionResponse, ErrorResponse, CompletionResponseChoice

Check failure on line 9 in vllm/beam/beam.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/beam/beam.py:9:81: E501 Line too long (84 > 80)
from vllm.logger import init_logger

# Module-level logger; pick_best_beam() emits per-candidate debug info to it.
logger = init_logger(__name__)


class BeamScorer:
    """Score candidate completions and select the highest-scoring one.

    Every candidate starts at a score of zero; the penalty computed by
    ``PenaltyComputer`` is subtracted and the score computed by
    ``RankingComputer`` is added. The candidate with the largest total is
    returned.
    """

    def __init__(self, classi_idx):
        """Build the penalty and ranking computers.

        Args:
            classi_idx: Forwarded unchanged to both ``PenaltyComputer`` and
                ``RankingComputer``. NOTE(review): presumably describes
                which classifier head lives at which index -- confirm
                against those classes.
        """
        self.penalty_computer = PenaltyComputer(classi_idx)
        self.ranking_computer = RankingComputer(classi_idx)

    @trace_async_method(span_name='pick_best_beam')
    async def pick_best_beam(
        self,
        responses: list[Union[AsyncGenerator[str, None],
                              CompletionResponseChoice, ErrorResponse]],
    ) -> Union[AsyncGenerator[str, None], CompletionResponseChoice,
               ErrorResponse]:
        """Return the element of ``responses`` with the highest score.

        Args:
            responses: Candidate completions. Each one is read through its
                ``additional_heads[0]`` and ``text`` attributes, so in
                practice every element must provide both.
                NOTE(review): the annotation also admits generators and
                ``ErrorResponse``; confirm callers never pass those here.

        Returns:
            The highest-scoring candidate (the first one on ties, per
            ``torch.argmax``).

        Raises:
            Whatever ``torch.argmax`` raises on an empty tensor when
            ``responses`` is empty.
        """
        debug_info = [BeamDebugInfo() for _ in responses]

        scores = torch.zeros(len(responses), dtype=torch.float)

        # Only the first entry of each candidate's additional heads is used.
        heads = [response.additional_heads[0] for response in responses]
        heads_tensor = torch.tensor(heads, dtype=torch.float)
        if len(heads_tensor) > 0:
            penalties = self.penalty_computer.compute(heads_tensor,
                                                      debug_info)
            scores -= penalties

            ranking_scores = self.ranking_computer.compute(
                heads_tensor, debug_info)
            scores += ranking_scores

        # Convert once to plain Python floats so the debug records hold the
        # declared ``float`` type instead of 0-dim tensors.
        for info, response, score in zip(debug_info, responses,
                                         scores.tolist()):
            info.final_score = score
            info.content = response.text

        logger.debug('debug_info: %s', debug_info)

        best_idx = torch.argmax(scores).item()
        return responses[best_idx]
12 changes: 12 additions & 0 deletions vllm/beam/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import dataclasses

@dataclasses.dataclass
class BeamDebugInfo:
    """Diagnostic record describing how one beam candidate was scored.

    Every field has a default, so ``BeamDebugInfo()`` yields an empty
    record. The field names (including their spelling) are part of the
    public interface and are kept as-is.
    """

    # Total score of the candidate.
    final_score: float = 0.0
    # NOTE(review): presumably the summed penalty -- confirm with writer.
    cummulative_penalty: float = 0.0
    # NOTE(review): presumably the summed ranking score -- confirm.
    cummulative_ranking_score: float = 0.0
    # Mutable default, so it needs a factory to avoid sharing one list.
    penalty_classifiers_that_are_over_threshold: list[str] = (
        dataclasses.field(default_factory=list))
    # Text of the candidate.
    content: str = ""
    # Mutable default, so it needs a factory to avoid sharing one list.
    filtered_classifiers: list[str] = dataclasses.field(
        default_factory=list)


175 changes: 175 additions & 0 deletions vllm/beam/emoji.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from vllm.beam.emoji_data import EMOJI_DATA

# Cache for the emoji search tree; stays None until _get_search_tree()
# builds it on first use.
_EMOJI_SEARCH_TREE = None

def emoji_count(input: str) -> int:
    """Return the number of entries ``emoji_list`` reports for ``input``."""
    matches = emoji_list(input)
    return len(matches)

def emoji_list(input: str) -> list:
    """Return the location of each emoji that ``demojize`` reports.

    The result is a list of dicts in order of appearance, each with the
    keys ``'match_start'``, ``'match_end'`` and ``'emoji'`` (the matched
    characters).
    """
    found = []

    def _record(matched, data):
        # Returns None implicitly, so demojize() emits nothing for it.
        found.append(dict(
            match_start=data['match_start'],
            match_end=data['match_end'],
            emoji=matched,
        ))

    # With version=-1 the callback fires for every emoji whose 'E' value
    # is greater than -1. NOTE(review): presumably that is every entry in
    # EMOJI_DATA -- confirm.
    demojize(input, language='en', version=-1, handle_version=_record)
    return found

def demojize(
    string,
    delimiters=(":", ":"),
    language='en',
    version=None,
    handle_version=None,
):
    """Replace unicode emoji in a string with emoji shortcodes.

    Example (the names come from ``EMOJI_DATA``)::

        demojize("Python is fun 👍")
        # -> "Python is fun :thumbs_up:"
        demojize("Unicode is tricky 😯", delimiters=("__", "__"))
        # -> "Unicode is tricky __hushed_face__"

    At every position the longest emoji sequence known to the search tree
    is matched. Variation selectors (U+FE0E, U+FE0F) that are not part of
    a matched emoji are dropped from the output.

    :param string: String containing unicode characters.
    :param delimiters: (optional) Pair of strings placed before and after
        the emoji name; defaults to ``(":", ":")``.
    :param language: Language key of the emoji name to use ('en', 'es',
        'de', ...), or 'alias' to use the first English alias when an
        emoji has one. An emoji with no entry for the language is kept
        unchanged.
    :param version: (optional) Max version. If set, every emoji whose
        ``'E'`` value is above this version is removed, unless
        ``handle_version`` says otherwise.
    :param handle_version: (optional) Replacement for emoji above
        ``version``. Either a value, which is converted with ``str``, or
        a callable ``handle_version(emj: str, data: dict) -> str``. The
        callable receives the matched emoji and a copy of its data dict
        extended with ``'match_start'`` and ``'match_end'`` (the slice
        bounds of the match in ``string``), for example::

            handle_version('<emoji>', {
                'en': ':airplane_departure:',
                'E': 1,
                'alias': [':flight_departure:'],
                'de': ':abflug:',
                'match_start': 0,
                'match_end': 1,
                ...
            })

        An empty or ``None`` replacement removes the emoji.
    :return: The string with emoji replaced.
    """
    use_aliases = language == 'alias'
    if use_aliases:
        language = 'en'

    tree = _get_search_tree()
    result = []
    i = 0
    length = len(string)
    while i < length:
        char = string[i]

        # Walk the tree as far as the input allows, remembering the
        # longest prefix that is a complete emoji. Checking only the node
        # where the walk stops would miss an emoji that is followed by
        # the start of a longer, unfinished sequence.
        emj_data = None
        match_end = i
        node = tree.get(char)
        j = i + 1
        while node is not None:
            if 'data' in node:
                emj_data = node['data']
                match_end = j
            node = node.get(string[j]) if j < length else None
            j += 1

        if emj_data is None:
            # Not an emoji: copy the character, except stray variation
            # selectors.
            if char not in ('\ufe0e', '\ufe0f'):
                result.append(char)
            i += 1
            continue

        code_points = string[i:match_end]
        replace_str = None
        if version is not None and emj_data['E'] > version:
            if callable(handle_version):
                # Copy so the shared EMOJI_DATA entry is not mutated.
                emj_data = emj_data.copy()
                emj_data['match_start'] = i
                emj_data['match_end'] = match_end
                replace_str = handle_version(code_points, emj_data)
            elif handle_version is not None:
                replace_str = str(handle_version)
        elif language in emj_data:
            if use_aliases and 'alias' in emj_data:
                name = emj_data['alias'][0]
            else:
                name = emj_data[language]
            # Stored names carry their own delimiters; strip them.
            replace_str = delimiters[0] + name[1:-1] + delimiters[1]
        else:
            # The emoji exists but is not translated, so keep it.
            replace_str = code_points

        if replace_str:
            result.append(replace_str)
        i = match_end

    return "".join(result)

def _get_search_tree():
    """Return the character tree used by demojize(), building it once.

    Each key of ``EMOJI_DATA`` is inserted character by character; the
    node reached after the final character stores that entry under the
    key ``'data'``. Example::

        EMOJI_DATA =
            {'a': {'en': ':Apple:'},
             'b': {'en': ':Bus:'},
             'ba': {'en': ':Bat:'},
             'band': {'en': ':Beatles:'},
             'bandit': {'en': ':Outlaw:'},
             'bank': {'en': ':BankOfEngland:'},
             'bb': {'en': ':BB-gun:'},
             'c': {'en': ':Car:'}}

        _SEARCH_TREE =
            {'a': {'data': {'en': ':Apple:'}},
             'b': {'a': {'data': {'en': ':Bat:'},
                         'n': {'d': {'data': {'en': ':Beatles:'},
                                     'i': {'t': {'data': {'en': ':Outlaw:'}}}},
                               'k': {'data': {'en': ':BankOfEngland:'}}}},
                   'b': {'data': {'en': ':BB-gun:'}},
                   'data': {'en': ':Bus:'}},
             'c': {'data': {'en': ':Car:'}}}

                        _SEARCH_TREE
                       /     |      ⧵
                      a      b       c
                      |    / | ⧵     |
                :Apple:  ba :Bus: bb :Car:
                        /  ⧵       |
                    :Bat:  ban   :BB-gun:
                          /   ⧵
                       band    bank
                      /   ⧵      |
                 bandi :Beatles: :BankOfEngland:
                   |
                 bandit
                   |
                :Outlaw:
    """
    global _EMOJI_SEARCH_TREE
    if _EMOJI_SEARCH_TREE is not None:
        return _EMOJI_SEARCH_TREE

    _EMOJI_SEARCH_TREE = {}
    for sequence in EMOJI_DATA:
        node = _EMOJI_SEARCH_TREE
        for symbol in sequence:
            node = node.setdefault(symbol, {})
        # An empty key never descends, so it gets no 'data' entry.
        if sequence:
            node['data'] = EMOJI_DATA[sequence]
    return _EMOJI_SEARCH_TREE
Loading
Loading