V1 embeddings #277
base: main
Changes from 27 commits
7efcc95
370ebcd
4de58bc
94b6fe2
313a20a
ce50b01
cc5d996
052b28d
6552783
bff271d
49effcc
b0d08d4
9ac34b1
d391a28
8f1f12c
8d3c65b
1ee5a6f
8665f5f
ede0080
165917a
3f1123a
dc54b8c
fb98ef2
4532431
4510bde
5acb9d9
c650f1e
5085371
e4a84ef
5f46a53
c047171
2e8c25c
9a055c3
5018399
c8e2db7
204241c
f04dd49
aa93ebd
7b19f99
@@ -11,17 +11,17 @@
                                     init_builtin_logitsprocs)
 from vllm.v1.sample.metadata import SamplingMetadata

-from vllm_spyre.v1.worker.spyre_input_batch import (CachedRequestState,
-                                                    InputBatch)
+from vllm_spyre.v1.worker.spyre_input_batch import (SamplingInputBatch,
+                                                    SamplingRequestState)

Comment on lines -14 to +15

can you quickly explain why this is needed? is this related to the embedding support, or a general change upstream?

having read more of the PR, I now understand 😄

Yeah, just to make clear for others who might read this later, the …

 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
 MAX_PROMPT_SIZE = 100
 MAX_NUM_PROMPT_TOKENS = 64


-def _remove_requests(input_batch: InputBatch, batch_size: int,
-                     reqs: list[CachedRequestState]) -> set[str]:
+def _remove_requests(input_batch: SamplingInputBatch, batch_size: int,
+                     reqs: list[SamplingRequestState]) -> set[str]:
     """
     Remove some requests randomly from the batch and returns a set of
     request ids removed
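
For context on the rename discussed in the thread above: the tests now target the sampling-specific variants of the input-batch classes. A rough, hypothetical sketch of the assumed split (the real definitions live in vllm_spyre.v1.worker.spyre_input_batch and are more involved; the pooling counterpart shown here is illustrative only, not part of this diff):

# Illustrative sketch only -- not the actual vllm_spyre code.
from dataclasses import dataclass, field


@dataclass
class SamplingRequestState:
    """Per-request state for generate (sampling) requests."""
    req_id: str
    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)


@dataclass
class PoolingRequestState:
    """Hypothetical pooling counterpart: an embedding request carries no
    sampled output tokens, only the prompt to be pooled."""
    req_id: str
    prompt_token_ids: list[int]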

@@ -41,9 +41,9 @@ def _remove_requests(input_batch: InputBatch, batch_size: int,


 def _construct_expected_sampling_metadata(
-        reqs: list[CachedRequestState],
+        reqs: list[SamplingRequestState],
         req_ids_retained: set[int],
-        input_batch: InputBatch,
+        input_batch: SamplingInputBatch,
         device: torch.device,
 ) -> SamplingMetadata:
     """

@@ -162,7 +162,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
     ]
-    return CachedRequestState(
+    return SamplingRequestState(
         req_id=f"req_id_{req_id_suffix}",
         prompt_token_ids=prompt_token_ids,
         sampling_params=_create_sampling_params(),

@@ -226,18 +226,18 @@ def test_sampling_metadata_in_input_batch(batch_size: int):
     """

     device = torch.device('cpu')
-    input_batch: InputBatch = InputBatch(
+    input_batch: SamplingInputBatch = SamplingInputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
         device=device,
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
     )
-    reqs: list[CachedRequestState] = []
+    reqs: list[SamplingRequestState] = []
     req_id_reqs = {}
     # Add requests
     for req_index in range(batch_size):
-        req: CachedRequestState = _construct_cached_request_state(req_index)
+        req: SamplingRequestState = _construct_cached_request_state(req_index)
         input_batch.add_request(req, req_index)
         reqs.append(req)
         req_id_reqs[req.req_id] = req

@@ -260,8 +260,8 @@ def test_sampling_metadata_in_input_batch(batch_size: int):

     # Add more requests
     for req_index in range(len(req_ids_to_remove)):
-        req: CachedRequestState = _construct_cached_request_state(req_index +
-                                                                  batch_size)
+        req: SamplingRequestState = _construct_cached_request_state(req_index +
+                                                                    batch_size)
         input_batch.add_request(req)
         reqs.append(req)
         req_ids_retained.add(req.req_id)
@@ -34,17 +34,44 @@
 logger = init_logger(__name__)


+class classproperty:
+
+    def __init__(self, func):
+        self.func = func
+
+    def __get__(self, instance, owner):
+        return self.func(owner)
+
+
+@property # type: ignore
+def is_v1_compatible(self) -> bool:
+    architectures = getattr(self.hf_config, "architectures", [])
+    patterns = ["Bert", "Roberta"]
+    if any(pat in arch for arch in architectures for pat in patterns):
+        return True
+    import vllm.model_executor.models as me_models
+    return me_models.ModelRegistry.is_v1_compatible(architectures)
+
+
 class SpyrePlatform(Platform):
     _enum = PlatformEnum.OOT

     # "spyre" device_name no longer worked due to https://github.com/vllm-project/vllm/pull/16464
     device_name: str = "cpu"
-    device_type: str = "cpu"
+    _device_type: str = "cpu"
     supported_quantization: list[str] = ["gptq"]
     _warmup_shapes: Optional[tuple[dict[str, int], ...]] = None
     _block_size: int = 64 # hardcoded Spyre constraint for now
     _config: VllmConfig = None

+    @classproperty
+    def device_type(cls):
+        # TODO: temporary hack while BertModels
+        # inherit SupportsV0Only in vllm upstream.
+        from vllm.config import ModelConfig
+        ModelConfig.is_v1_compatible = is_v1_compatible
+        return cls._device_type
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         return "spyre"
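
The classproperty helper added above is a small descriptor that lets device_type be computed on attribute access at the class level, which is also where the ModelConfig.is_v1_compatible monkeypatch gets installed. A standalone, illustrative demo of the descriptor itself, separate from the PR code:

# Minimal demo of the descriptor pattern used by the classproperty helper above.
class classproperty:

    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        # Invoked for both Example.device_type and Example().device_type;
        # always binds to the owning class, never the instance.
        return self.func(owner)


class Example:
    _device_type = "cpu"

    @classproperty
    def device_type(cls):
        return cls._device_type


print(Example.device_type)    # -> "cpu", no instance required
print(Example().device_type)  # -> "cpu", works on instances too

Accessing SpyrePlatform.device_type therefore has the side effect described in the TODO above: it patches ModelConfig.is_v1_compatible before returning the stored value.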
@@ -68,18 +95,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             raise NotImplementedError

         is_decoder = model_config.task == "generate"
-        is_embedding = model_config.task == "embed"

why the name change?

Because embedding is a special case of pooling. There have been some changes upstream to clean up the confusion around runner types, tasks, and APIs. But I'm not sure whether that should already go into this PR or into a follow-up PR.

+        is_pooling = model_config.task == "embed"
         if model_config.task == "auto":
-            is_embedding = "embed" in model_config.supported_tasks
+            is_pooling = "embed" in model_config.supported_tasks
             is_decoder = "generate" in model_config.supported_tasks

-        # v0 is only supported for embedding models, and embedding models must
-        # be run on v0
-        if is_embedding and envs.VLLM_USE_V1:
-            raise ValueError("Embedding models are only supported on v0")
-        elif is_decoder and not envs.VLLM_USE_V1:
+        if is_decoder and not envs.VLLM_USE_V1:
             raise ValueError("Decoder models are only supported on v1")
-        elif not is_decoder and not is_embedding:
+        elif not is_decoder and not is_pooling:
             raise ValueError("Only the 'generate' and 'embed' tasks are "
                              "supported")
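
To illustrate the "embedding is a special case of pooling" point from the thread above: in upstream vLLM, several tasks are served by a pooling runner, and "embed" is one of them. The task names below follow upstream conventions and should be read as an assumption, not something defined in this PR:

# Illustrative only: "embed" is one of several pooling tasks; "generate" is not.
POOLING_TASKS = {"embed", "classify", "score", "reward"}  # assumed upstream task names


def is_pooling_task(task: str) -> bool:
    """Return True if the task is handled by a pooling runner rather than sampling."""
    return task in POOLING_TASKS


assert is_pooling_task("embed")
assert not is_pooling_task("generate")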

@@ -121,9 +144,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 scheduler_config.scheduler_cls = (
                     "vllm_spyre.v1.core.scheduler."\
                     "StaticBatchingSpyreScheduler")
-        elif is_embedding:
-            scheduler_config.scheduler_cls = (
-                "vllm_spyre.core.scheduler.SpyreScheduler")
+        elif is_pooling:
+            if not envs.VLLM_USE_V1:
+                scheduler_config.scheduler_cls = (
+                    "vllm_spyre.core.scheduler.SpyreScheduler")
+            else:
+                scheduler_config.scheduler_cls = (
+                    "vllm_spyre.v1.core.scheduler."\
+                    "StaticBatchingSpyreScheduler")

         # To disable any paged attention ops in the base scheduler, we:
         # - Set the block size (in tokens) to the maximum sequence length

@@ -243,7 +271,7 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         model configuration.
         """
         # We don't have an embedding runner for v1 yet

please remove the comment as well?

Yes, thanks for pointing this out!

-        return model_config.task != "embed"
+        return True

     @classmethod
     def validate_request(
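
With supports_v1 now returning True unconditionally, embedding models are no longer forced onto V0. A hedged usage sketch of what this enables (the model name is a placeholder, and the embed API shape follows upstream vLLM's LLM.embed, so treat the details as assumptions):

# Illustrative only: serve an embedding model with task="embed" and read a vector back.
from vllm import LLM

llm = LLM(model="sentence-transformers/all-roberta-large-v1", task="embed")  # placeholder model
outputs = llm.embed(["The quick brown fox jumps over the lazy dog."])
vector = outputs[0].outputs.embedding  # pooled embedding as a list of floats
print(len(vector))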