V1 embeddings #277

Open · wants to merge 39 commits into base: main

Changes from 27 of 39 commits

Commits
7efcc95
Solve conflicts with upstream embedding branch
maxdebayser May 29, 2025
370ebcd
refactor input batch
maxdebayser May 29, 2025
4de58bc
add spyre pooling batch
maxdebayser Jun 2, 2025
94b6fe2
initial model runner prototype
maxdebayser Jun 3, 2025
313a20a
Merge branch 'main' into v1_embeddings
maxdebayser Jul 2, 2025
ce50b01
fix linting
maxdebayser Jul 2, 2025
cc5d996
appease isort
maxdebayser Jul 2, 2025
052b28d
Use upstream LogitsProcessors
maxdebayser Jul 9, 2025
6552783
Merge branch 'logits_processors' into v1_embeddings
maxdebayser Jul 9, 2025
bff271d
Remove attn_type from spec as this change hasn't made it upstream yet
maxdebayser Jul 10, 2025
49effcc
Revert "[Priority merge] NewRequestData parameter introduced in vllm …
maxdebayser Jul 10, 2025
b0d08d4
disable token type ids for now
maxdebayser Jul 10, 2025
9ac34b1
linting
maxdebayser Jul 10, 2025
d391a28
fix masking
maxdebayser Jul 10, 2025
8f1f12c
add missing arg
maxdebayser Jul 10, 2025
8d3c65b
fix off by one error
maxdebayser Jul 10, 2025
1ee5a6f
finish most of the model runner refactoring
maxdebayser Jul 10, 2025
8665f5f
small fixes
maxdebayser Jul 11, 2025
ede0080
Merge branch 'main' into v1_embeddings
maxdebayser Jul 14, 2025
165917a
add embedding tests for multiple requests
gmarinho2 Jul 14, 2025
3f1123a
Fix test typo and monkey patch Bert model support
maxdebayser Jul 15, 2025
dc54b8c
Merge branch 'main' into v1_embeddings
maxdebayser Jul 15, 2025
fb98ef2
fix assertion
maxdebayser Jul 15, 2025
4532431
fix _get_token_ids
maxdebayser Jul 15, 2025
4510bde
fix mistakes
maxdebayser Jul 15, 2025
5acb9d9
Merge branch 'main' into v1_embeddings
maxdebayser Jul 17, 2025
c650f1e
Merge branch 'main' into v1_embeddings
maxdebayser Jul 17, 2025
5085371
address review comments
maxdebayser Jul 21, 2025
e4a84ef
Merge branch 'main' into v1_embeddings
maxdebayser Jul 21, 2025
5f46a53
restore chili peppers
maxdebayser Jul 21, 2025
c047171
fix tests
maxdebayser Jul 21, 2025
2e8c25c
fix missing torch_sendnn initialization
maxdebayser Jul 21, 2025
9a055c3
support upstream changes
maxdebayser Jul 21, 2025
5018399
revert edit mistake
maxdebayser Jul 21, 2025
c8e2db7
appease mypy
maxdebayser Jul 21, 2025
204241c
work around upstream changes
maxdebayser Jul 21, 2025
f04dd49
compatibility with vllm 0.9.3
maxdebayser Jul 22, 2025
aa93ebd
Merge branch 'main' into v1_embeddings
maxdebayser Jul 25, 2025
7b19f99
fix merge problem
maxdebayser Jul 26, 2025
96 changes: 95 additions & 1 deletion tests/e2e/test_spyre_embeddings.py
@@ -3,10 +3,13 @@
Run `python -m pytest tests/e2e/test_spyre_embeddings.py`.
"""

import os

import pytest
from spyre_util import (compare_embedding_results, get_chicken_soup_prompts,
get_spyre_backend_list, get_spyre_model_list,
spyre_vllm_embeddings, st_embeddings)
from vllm import LLM


@pytest.mark.parametrize("model", get_spyre_model_list(isEmbeddings=True))
@@ -15,7 +15,7 @@
(128, 8)]) # (prompt_length/batch_size)
@pytest.mark.parametrize("backend", get_spyre_backend_list())
# TODO: Add it when v1 is supported.
@pytest.mark.parametrize("vllm_version", ["V0"])
@pytest.mark.parametrize("vllm_version", ["V0", "V1"])
def test_output(
model: str,
warmup_shape: tuple[int, int],
@@ -49,3 +52,94 @@ def test_output(
backend=backend,
vllm_results=vllm_results,
hf_results=hf_results)


@pytest.fixture
def example_prompts():
return [
"The capital of France is Paris.", "Hello",
"What is the weather today like?", "Who are you?"
]


@pytest.mark.parametrize("warmup_shapes", [
(64, 1),
(64, 2),
(64, 4),
]) # (prompt_length/batch_size)
@pytest.mark.parametrize("backend", get_spyre_backend_list())
@pytest.mark.parametrize("model", get_spyre_model_list(isEmbeddings=True))
@pytest.mark.parametrize("vllm_version", ["V0", "V1"])
def test_scheduling_invariance(
example_prompts,
model,
backend,
warmup_shapes,
vllm_version,
) -> None:

os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = backend
os.environ['VLLM_USE_V1'] = "1" if vllm_version == "V1" else "0"

prompts = [str(s).strip() for s in example_prompts]
reference_embeds = st_embeddings(model, example_prompts)

vllm_model = LLM(model=model,
task="embed",
tokenizer=model,
max_model_len=256,
block_size=256,
tensor_parallel_size=1)

# Four requests with one prompt each
results = []
for i in range(4):
results.append(vllm_model.embed(prompts[i]))

vllm_outputs = []
for req_output in results:
result = {'embeddings': req_output[0].outputs.embedding}
vllm_outputs.append(result)

compare_embedding_results(model=model,
prompts=example_prompts,
warmup_shapes=[warmup_shapes],
tensor_parallel_size=1,
backend=backend,
vllm_results=vllm_outputs,
hf_results=reference_embeds)

# Two requests with two prompts each
results = []
for i in range(2):
results.append(vllm_model.embed([prompts[i * 2], prompts[i * 2 + 1]]))

vllm_outputs = []
for req_output in results:
result1 = {'embeddings': req_output[0].outputs.embedding}
result2 = {'embeddings': req_output[1].outputs.embedding}

vllm_outputs.extend([result1, result2])

compare_embedding_results(model=model,
prompts=example_prompts,
warmup_shapes=[warmup_shapes],
tensor_parallel_size=1,
backend=backend,
vllm_results=vllm_outputs,
hf_results=reference_embeds)

# One request with four prompts
results = vllm_model.embed(prompts)
vllm_outputs = []
for req_output in results:
result = {'embeddings': req_output.outputs.embedding}
vllm_outputs.append(result)

compare_embedding_results(model=model,
prompts=example_prompts,
warmup_shapes=[warmup_shapes],
tensor_parallel_size=1,
backend=backend,
vllm_results=vllm_outputs,
hf_results=reference_embeds)
8 changes: 4 additions & 4 deletions tests/spyre_util.py
@@ -25,7 +25,7 @@

def force_engine_shutdown(llm: LLM):
"""
🌶️🌶️🌶
�️�️�
This hack is here because of an issue in vllm 0.9.2+ where a circular
reference occurs in vllm.executor.ray_utils if ray is not installed. This
circular reference holds a copy of the vllm config which contains a
@@ -35,7 +35,7 @@ def force_engine_shutdown(llm: LLM):
engine is never shut down then the TP worker processes are never killed.
When the TP worker processes are held open, all future attempts to create a
new engine will fail with an EADDRINUSE error.
🌶️🌶️🌶
�️�️�
"""
llm.llm_engine.engine_core.shutdown()

@@ -523,7 +523,7 @@ def create_text_prompt(model: str, min_tokens: int, max_tokens: int) -> str:
"""Create a text prompt for the specified model that will tokenize to within
the specified token length range."""
tokenizer = AutoTokenizer.from_pretrained(model)
pepper = "🌶️"
pepper = "️"
pepper_tokens = len(tokenizer.encode(pepper, add_special_tokens=False))

# Find a good starting number of peppers
@@ -549,11 +549,11 @@ def create_random_request(
mm_hashes=None,
mm_placeholders=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=None,
arrival_time=0,
lora_request=None,
data_parallel_rank=None,
pooling_params=None,
cache_salt=None)


24 changes: 12 additions & 12 deletions tests/v1/worker/test_spyre_input_batch.py
@@ -11,17 +11,17 @@
init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata

from vllm_spyre.v1.worker.spyre_input_batch import (CachedRequestState,
InputBatch)
from vllm_spyre.v1.worker.spyre_input_batch import (SamplingInputBatch,
SamplingRequestState)
Comment on lines -14 to +15

Collaborator: Can you quickly explain why this is needed? Is this related to the embedding support, or a general change upstream?

Collaborator: Having read more of the PR, I now understand 😄

maxdebayser (Author): Yeah, just to make this clear for others who might read this later: the CachedRequestState class was refactored into a base class, BaseRequestState, which holds the common attributes and methods, and two derived classes, SamplingRequestState and PoolingRequestState, which are specialized for the generation and pooling use cases respectively.
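For readers of this thread, here is a minimal self-contained sketch of the hierarchy described above. The class names match the comment; the fields and types are illustrative placeholders, not the actual vllm-spyre definitions. The updated test imports in this file reflect the sampling side of that split.

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class BaseRequestState:
    """Attributes shared by generation and pooling requests (illustrative)."""
    req_id: str
    prompt_token_ids: list[int]

    @property
    def num_prompt_tokens(self) -> int:
        return len(self.prompt_token_ids)


@dataclass
class SamplingRequestState(BaseRequestState):
    """Generation-specific state: sampling params plus generated tokens."""
    sampling_params: Optional[Any] = None  # vllm SamplingParams in practice
    output_token_ids: list[int] = field(default_factory=list)


@dataclass
class PoolingRequestState(BaseRequestState):
    """Pooling/embedding-specific state: no output tokens are generated."""
    pooling_params: Optional[Any] = None  # vllm PoolingParams in practice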


VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
MAX_NUM_PROMPT_TOKENS = 64


def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
def _remove_requests(input_batch: SamplingInputBatch, batch_size: int,
reqs: list[SamplingRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns a set of
request ids removed
@@ -41,9 +41,9 @@ def _remove_requests(input_batch: InputBatch,


def _construct_expected_sampling_metadata(
reqs: list[CachedRequestState],
reqs: list[SamplingRequestState],
req_ids_retained: set[int],
input_batch: InputBatch,
input_batch: SamplingInputBatch,
device: torch.device,
) -> SamplingMetadata:
"""
@@ -162,7 +162,7 @@ def _construct_cached_request_state(req_id_suffix: int):
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(
return SamplingRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
@@ -226,18 +226,18 @@ def test_sampling_metadata_in_input_batch(batch_size: int):
"""

device = torch.device('cpu')
input_batch: InputBatch = InputBatch(
input_batch: SamplingInputBatch = SamplingInputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
device=device,
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
reqs: list[CachedRequestState] = []
reqs: list[SamplingRequestState] = []
req_id_reqs = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
req: SamplingRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
reqs.append(req)
req_id_reqs[req.req_id] = req
@@ -260,8 +260,8 @@

# Add more requests
for req_index in range(len(req_ids_to_remove)):
req: CachedRequestState = _construct_cached_request_state(req_index +
batch_size)
req: SamplingRequestState = _construct_cached_request_state(req_index +
batch_size)
input_batch.add_request(req)
reqs.append(req)
req_ids_retained.add(req.req_id)
54 changes: 41 additions & 13 deletions vllm_spyre/platform.py
@@ -34,17 +34,44 @@
logger = init_logger(__name__)


class classproperty:

def __init__(self, func):
self.func = func

def __get__(self, instance, owner):
return self.func(owner)


@property # type: ignore
def is_v1_compatible(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
patterns = ["Bert", "Roberta"]
if any(pat in arch for arch in architectures for pat in patterns):
return True
import vllm.model_executor.models as me_models
return me_models.ModelRegistry.is_v1_compatible(architectures)


class SpyrePlatform(Platform):
_enum = PlatformEnum.OOT

# "spyre" device_name no longer worked due to https://github.com/vllm-project/vllm/pull/16464
device_name: str = "cpu"
device_type: str = "cpu"
_device_type: str = "cpu"
supported_quantization: list[str] = ["gptq"]
_warmup_shapes: Optional[tuple[dict[str, int], ...]] = None
_block_size: int = 64 # hardcoded Spyre constraint for now
_config: VllmConfig = None

@classproperty
def device_type(cls):
# TODO: temporary hack while BertModels
# inherit SupportsV0Only in vllm upstream.
from vllm.config import ModelConfig
ModelConfig.is_v1_compatible = is_v1_compatible
return cls._device_type

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "spyre"
@@ -68,18 +95,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
raise NotImplementedError

is_decoder = model_config.task == "generate"
is_embedding = model_config.task == "embed"
Collaborator: Why the name change?

maxdebayser (Author), Jul 21, 2025: Because embedding is a special case of pooling. There have been some changes upstream to clean up the confusion around runner types, tasks, and APIs, but I'm not sure whether that should already go into this PR or into a follow-up PR.
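To make that concrete, here is a toy sketch of the task-to-runner mapping the rename hints at. The pooling task names other than "embed" are assumptions about upstream vLLM and are not part of this diff.

# Toy sketch: "embedding is a special case of pooling". Only "generate"
# selects the decoding path; the pooling task names besides "embed" are
# assumptions about upstream vLLM, not something introduced by this PR.
POOLING_TASKS = {"embed", "classify", "score", "reward"}


def runner_kind(task: str) -> str:
    if task == "generate":
        return "generate"
    if task in POOLING_TASKS:
        return "pooling"
    raise ValueError(f"unsupported task: {task}")


assert runner_kind("embed") == "pooling"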

is_pooling = model_config.task == "embed"
if model_config.task == "auto":
is_embedding = "embed" in model_config.supported_tasks
is_pooling = "embed" in model_config.supported_tasks
is_decoder = "generate" in model_config.supported_tasks

# v0 is only supported for embedding models, and embedding models must
# be run on v0
if is_embedding and envs.VLLM_USE_V1:
raise ValueError("Embedding models are only supported on v0")
elif is_decoder and not envs.VLLM_USE_V1:
if is_decoder and not envs.VLLM_USE_V1:
raise ValueError("Decoder models are only supported on v1")
elif not is_decoder and not is_embedding:
elif not is_decoder and not is_pooling:
raise ValueError("Only the 'generate' and 'embed' tasks are "
"supported")

@@ -121,9 +144,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config.scheduler_cls = (
"vllm_spyre.v1.core.scheduler."\
"StaticBatchingSpyreScheduler")
elif is_embedding:
scheduler_config.scheduler_cls = (
"vllm_spyre.core.scheduler.SpyreScheduler")
elif is_pooling:
if not envs.VLLM_USE_V1:
scheduler_config.scheduler_cls = (
"vllm_spyre.core.scheduler.SpyreScheduler")
else:
scheduler_config.scheduler_cls = (
"vllm_spyre.v1.core.scheduler."\
"StaticBatchingSpyreScheduler")

# To disable any paged attention ops in the base scheduler, we:
# - Set the block size (in tokens) to the maximum sequence length
@@ -243,7 +271,7 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
model configuration.
"""
# We don't have an embedding runner for v1 yet
Collaborator: Please remove the comment as well?

maxdebayser (Author): Yes, thanks for pointing this out!

return model_config.task != "embed"
return True

@classmethod
def validate_request(