diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index fc5ca23787be..1114033d5cea 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -12,7 +12,10 @@ def parse_args(): parser = EngineArgs.add_cli_args(parser) # Set example specific arguments parser.set_defaults( - model="intfloat/e5-mistral-7b-instruct", task="embed", enforce_eager=True + model="intfloat/e5-mistral-7b-instruct", + task="embed", + enforce_eager=True, + max_model_len=1024, ) return parser.parse_args() diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/vision_language_embedding.py index 1f5bd4ad72b0..9451825f0b73 100644 --- a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/vision_language_embedding.py @@ -94,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData: engine_args = EngineArgs( model="TIGER-Lab/VLM2Vec-Full", task="embed", + max_model_len=4096, trust_remote_code=True, mm_processor_kwargs={"num_crops": 4}, limit_mm_per_prompt={"image": 1}, diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index dc6cfe9daccd..1ee9b234d9f4 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -31,7 +31,7 @@ class TestSetting: # basic llama model TestSetting( model="meta-llama/Llama-3.2-1B-Instruct", - model_args=[], + model_args=["--max-model-len", "2048"], pp_size=2, tp_size=2, attn_backend="FLASHINFER", @@ -41,7 +41,7 @@ class TestSetting: # llama model with quantization TestSetting( model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - model_args=["--quantization", "gptq"], + model_args=["--quantization", "gptq", "--max-model-len", "2048"], pp_size=1, tp_size=1, attn_backend="FLASH_ATTN", @@ -51,7 +51,7 @@ class TestSetting: # MoE model TestSetting( model="ibm/PowerMoE-3b", - model_args=[], + model_args=["--max-model-len", "2048"], pp_size=1, tp_size=2, attn_backend="FLASH_ATTN", @@ -61,23 +61,27 @@ class TestSetting: # embedding model TestSetting( model="BAAI/bge-multilingual-gemma2", - model_args=["--task", "embed", "--dtype", "bfloat16"], + model_args=[ + "--task", "embed", "--dtype", "bfloat16", "--max-model-len", + "2048" + ], pp_size=1, tp_size=1, attn_backend="FLASH_ATTN", method="encode", fullgraph=True, ), - # encoder-based embedding model (BERT) - TestSetting( - model="BAAI/bge-base-en-v1.5", - model_args=["--task", "embed"], - pp_size=1, - tp_size=1, - attn_backend="XFORMERS", - method="encode", - fullgraph=True, - ), + # TODO: bert models are not supported in V1 yet + # # encoder-based embedding model (BERT) + # TestSetting( + # model="BAAI/bge-base-en-v1.5", + # model_args=["--task", "embed"], + # pp_size=1, + # tp_size=1, + # attn_backend="XFORMERS", + # method="encode", + # fullgraph=True, + # ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", diff --git a/tests/conftest.py b/tests/conftest.py index 294805a8164f..ff564b2b8ed5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -145,6 +145,7 @@ def run_with_both_engines(request, monkeypatch): # Automatically runs tests twice, once with V1 and once without use_v1 = request.param # Tests decorated with `@skip_v1` are only run without v1 + skip_v0 = request.node.get_closest_marker("skip_v0") skip_v1 = request.node.get_closest_marker("skip_v1") if use_v1: @@ -152,6 +153,8 @@ def run_with_both_engines(request, monkeypatch): pytest.skip("Skipping test on vllm V1") monkeypatch.setenv('VLLM_USE_V1', '1') else: + if skip_v0: + pytest.skip("Skipping test on vllm V0") monkeypatch.setenv('VLLM_USE_V1', '0') yield diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index f0fa54aa3131..b930f05bebd0 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -8,6 +8,8 @@ from vllm import LLM, PoolingParams, PoolingRequestOutput from vllm.distributed import cleanup_dist_env_and_memory +from ...models.utils import check_embeddings_close + MODEL_NAME = "intfloat/multilingual-e5-small" PROMPTS = [ @@ -27,6 +29,14 @@ ] +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.fixture(scope="module") def llm(): # pytest caches the fixture so we use weakref.proxy to @@ -46,9 +56,15 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: list[PoolingRequestOutput], +def assert_outputs_match(o1: list[PoolingRequestOutput], o2: list[PoolingRequestOutput]): - assert [o.outputs for o in o1] == [o.outputs for o in o2] + check_embeddings_close( + embeddings_0_lst=[o.outputs.data for o in o1], + embeddings_1_lst=[o.outputs.data for o in o2], + name_0="hf", + name_1="vllm", + tol=1e-2, + ) @pytest.mark.skip_global_cleanup @@ -63,7 +79,7 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) + assert_outputs_match(v1_output, v2_output) @pytest.mark.skip_global_cleanup @@ -80,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): } for p in TOKEN_IDS], pooling_params=pooling_params, ) - assert_outputs_equal(v1_output, v2_output) + assert_outputs_match(v1_output, v2_output) @pytest.mark.skip_global_cleanup diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 80640a2e1a8b..adb094127e40 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -21,6 +21,14 @@ DTYPE = "bfloat16" +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.fixture(scope="module") def server(): args = [ diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/openai/test_pooling.py index cf16ace6537a..41c30e71684b 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/openai/test_pooling.py @@ -7,6 +7,7 @@ import pytest import requests +from tests.models.utils import check_embeddings_close from vllm.entrypoints.openai.protocol import PoolingResponse from vllm.transformers_utils.tokenizer import get_tokenizer @@ -223,8 +224,11 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, np.frombuffer(base64.b64decode(data.data), dtype="float32").tolist()) - assert responses_float.data[0].data == decoded_responses_base64_data[0] - assert responses_float.data[1].data == decoded_responses_base64_data[1] + check_embeddings_close( + embeddings_0_lst=[d.data for d in responses_float.data], + embeddings_1_lst=decoded_responses_base64_data, + name_0="float32", + name_1="base64") # Default response is float32 decoded from base64 by OpenAI Client default_response = requests.post( @@ -237,5 +241,8 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, default_response.raise_for_status() responses_default = PoolingResponse.model_validate(default_response.json()) - assert responses_float.data[0].data == responses_default.data[0].data - assert responses_float.data[1].data == responses_default.data[1].data + check_embeddings_close( + embeddings_0_lst=[d.data for d in responses_default.data], + embeddings_1_lst=[d.data for d in responses_default.data], + name_0="float32", + name_1="base64") diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index 19eba320c279..e40bbca9a8ad 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -12,6 +12,14 @@ DTYPE = "bfloat16" +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.fixture(scope="module") def server(): args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE] diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index af51a0a3eeeb..8927fe771809 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -11,6 +11,15 @@ from ...utils import RemoteOpenAIServer + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + MODELS = [ { "name": "BAAI/bge-reranker-v2-m3", diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 4a6d781ce6f0..77df6d16a367 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -6,6 +6,14 @@ from vllm.platforms import current_platform +# TODO: enable when float32 is supported by V1 +# @pytest.fixture(autouse=True) +# def v1(run_with_both_engines): +# # Simple autouse wrapper to run both engines for each test +# # This can be promoted up to conftest.py to run for every +# # test in a package +# pass + @pytest.mark.parametrize( "model", @@ -29,7 +37,7 @@ def test_models( # switch to use ROCm CK FA backend monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) with hf_runner(model, diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 9516a01421cb..e29b4f6e8bec 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -8,6 +8,14 @@ from ...utils import check_embeddings_close +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize( "model", [ @@ -20,15 +28,27 @@ marks=[pytest.mark.core_model]), pytest.param("intfloat/e5-mistral-7b-instruct", marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), + # the qwen models interfere with each other (see PR + # https://github.com/vllm-project/vllm/pull/18720). + # To avoid this problem, for now we skip v0 since it will be + # deprecated anyway. + pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", + marks=[pytest.mark.skip_v0]), # [Encoder-only] pytest.param("BAAI/bge-base-en-v1.5", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), - pytest.param("sentence-transformers/all-MiniLM-L12-v2"), - pytest.param("intfloat/multilingual-e5-small"), - pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), + marks=[ + pytest.mark.core_model, pytest.mark.cpu_model, + pytest.mark.skip_v1 + ]), + pytest.param("sentence-transformers/all-MiniLM-L12-v2", + marks=[pytest.mark.skip_v1]), + pytest.param("intfloat/multilingual-e5-small", + marks=[pytest.mark.skip_v1]), + pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + marks=[pytest.mark.skip_v1]), # [Cross-Encoder] - pytest.param("sentence-transformers/stsb-roberta-base-v2"), + pytest.param("sentence-transformers/stsb-roberta-base-v2", + marks=[pytest.mark.skip_v1]), ], ) def test_models( @@ -62,7 +82,7 @@ def test_models( with vllm_runner(model, task="embed", - max_model_len=None, + max_model_len=512, **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) diff --git a/tests/models/registry.py b/tests/models/registry.py index fb93ba60c2e8..82253a1c94b3 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -265,8 +265,8 @@ def check_available_online( _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] - "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), - "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), + "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), + "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -279,16 +279,16 @@ def check_available_online( "LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"), "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base", - trust_remote_code=True), + trust_remote_code=True, v0_only=True), "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", - trust_remote_code=True), + trust_remote_code=True, v0_only=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 - "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 - "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 - "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"), + "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 + "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full", @@ -300,10 +300,10 @@ def check_available_online( _CROSS_ENCODER_EXAMPLE_MODELS = { # [Text-only] - "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 - "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 - "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501 + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 } _MULTIMODAL_EXAMPLE_MODELS = { diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9f2414eca24f..f8aeba8301b1 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer, None, params, None, + None, 0.0, None, cache_salt=None, diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 347f98c772ff..e80ad8a68151 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -43,6 +43,7 @@ def make_request(request_id, multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, eos_token_id=100, lora_request=None, cache_salt=cache_salt, diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 394336624aca..7a42778831c5 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -39,6 +39,7 @@ def make_request(request_id, multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, eos_token_id=100, lora_request=None, cache_salt=cache_salt, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index d348956aa177..b0b1116eb536 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -135,6 +135,7 @@ def create_requests(num_requests: int, request_id=f"{i}", prompt_token_ids=[i] * num_tokens, sampling_params=sampling_params, + pooling_params=None, multi_modal_inputs=mm_inputs, multi_modal_placeholders=mm_position, multi_modal_hashes=None, @@ -283,6 +284,7 @@ def test_schedule_partial_requests(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) @@ -333,6 +335,7 @@ def test_no_mm_input_chunking(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) @@ -396,6 +399,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) scheduler.update_from_output(output, model_runner_output) @@ -420,6 +424,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) scheduler.update_from_output(output1, model_runner_output) output2 = scheduler.schedule() @@ -473,7 +478,8 @@ def test_stop_via_update_from_output(): 11]], # First request hits EOS, second continues spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}) + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -523,7 +529,8 @@ def test_stop_via_update_from_output(): [13, 14]], # First request hits stop token spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}) + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -572,7 +579,8 @@ def test_stop_via_update_from_output(): [13]], # First request exceeds max_tokens spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}) + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -614,7 +622,8 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], spec_token_ids=None, logprobs=None, - prompt_logprobs_dict={}) + prompt_logprobs_dict={}, + pooler_output=[]) scheduler.update_from_output(scheduler_output, model_output) @@ -663,6 +672,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) scheduler.update_from_output(scheduler_output0, model_runner_output) @@ -680,6 +690,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) scheduler.update_from_output(scheduler_output1, model_runner_output) @@ -730,6 +741,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) @@ -769,6 +781,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) @@ -896,6 +909,7 @@ def test_kv_connector_basic(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) # Ensure ScheduleOutput is correct. @@ -941,6 +955,7 @@ def test_kv_connector_basic(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) # We should get a local cache hit of NUM_TOKENS_PREFIX and @@ -1007,6 +1022,7 @@ def test_kv_connector_unable_to_allocate(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) # Just one request should be running. @@ -1087,6 +1103,7 @@ def test_kv_connector_handles_preemption(): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) # All can be scheduled - 1st token. @@ -1181,6 +1198,7 @@ def make_output(scheduler: Scheduler): spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], ) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index bc7894e92814..bbdc73e9608a 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -39,6 +39,7 @@ def make_request() -> EngineCoreRequest: mm_hashes=None, mm_placeholders=None, sampling_params=SamplingParams(), + pooling_params=None, eos_token_id=None, arrival_time=time.time(), lora_request=None, diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index d4db16fe86fa..16c36cd5c6b9 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -53,6 +53,7 @@ def make_request( mm_hashes=None, mm_placeholders=None, sampling_params=params, + pooling_params=None, eos_token_id=None, arrival_time=time.time(), lora_request=None, diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py index 5c844e0e7095..f028b4ab1d73 100644 --- a/tests/v1/engine/test_fast_incdec_prefix_err.py +++ b/tests/v1/engine/test_fast_incdec_prefix_err.py @@ -33,6 +33,7 @@ def test_fast_inc_detok_invalid_utf8_err_case(): None, params, None, + None, 0.0, None, cache_salt=None, diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 6b88b0cf17e3..1c8c5f25e29b 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, output_kind=request_output_kind, stop=[], include_stop_str_in_output=False, - )) + ), + pooling_params=None) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, include_stop_str_in_output=False, logprobs=num_sample_logprobs, prompt_logprobs=num_prompt_logprobs, - )) + ), + pooling_params=None) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool, logprobs=num_sample_logprobs, prompt_logprobs=None, ignore_eos=ignore_eos, - )) + ), + pooling_params=None) # Add request to the detokenizer. output_processor.add_request(request, prompt_string) @@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool, include_stop_str_in_output=include_stop_str_in_output, logprobs=num_sample_logprobs, prompt_logprobs=None, - )) + ), + pooling_params=None) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] @@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors): cache_salt=None, data_parallel_rank=None, sampling_params=SamplingParams(), + pooling_params=None, ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ] diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 4a9e3a7ad807..61f59f35f75b 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -150,6 +150,7 @@ def create_request( request_id=f"id-{request_id}", prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, + pooling_params=None, multi_modal_inputs=None, multi_modal_placeholders=None, multi_modal_hashes=None, @@ -183,6 +184,7 @@ def create_model_runner_output( spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=None, finished_sending=finished_sending, finished_recving=finished_recving, ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index de6ebe4f6716..9e5e06cdc1f5 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -46,7 +47,7 @@ def _compare_objs(obj1, obj2): for a_i, b_i in zip(a.block_tables, b.block_tables): _compare_objs(a_i, b_i) is_same = True - elif isinstance(a, (BlockTable, SamplingMetadata)): + elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)): _compare_objs(a, b) is_same = True # if we make it here must be same elif a == b: @@ -201,6 +202,7 @@ def _construct_cached_request_state(req_id_suffix: int): req_id=f"req_id_{req_id_suffix}", prompt_token_ids=prompt_token_ids, sampling_params=_create_sampling_params(), + pooling_params=None, mm_inputs=[], mm_positions=[], block_ids=([], ), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 994432dfd593..abf14a8fb625 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -122,6 +122,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), + pooling_params=None, block_ids=([0], ), num_computed_tokens=0, lora_request=None, diff --git a/vllm/config.py b/vllm/config.py index 7a9bc8a4f7af..54c7a497b261 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4496,11 +4496,31 @@ def __post_init__(self): if self.compilation_config.full_cuda_graph and \ not self.model_config.disable_cascade_attn: - logger.warning_once( - "full_cuda_graph is not supported with " - "cascade attention. Disabling cascade attention.") + logger.info("full_cuda_graph is not supported with " + "cascade attention. Disabling cascade attention.") self.model_config.disable_cascade_attn = True + disable_chunked_prefill_reasons: list[str] = [] + + if self.model_config and self.model_config.pooler_config: + pooling_type = self.model_config.pooler_config.pooling_type + if pooling_type is None or pooling_type.lower() != "last": + disable_chunked_prefill_reasons.append( + "Only \"last\" pooling supports chunked " + "prefill and prefix caching; disabling both.") + + if disable_chunked_prefill_reasons: + for reason in disable_chunked_prefill_reasons: + logger.info(reason) + self.scheduler_config.chunked_prefill_enabled = False + self.scheduler_config.long_prefill_token_threshold = 0 + self.scheduler_config.max_num_batched_tokens = max( + self.scheduler_config.max_model_len, + DEFAULT_MAX_NUM_BATCHED_TOKENS) + + if self.cache_config is not None: + self.cache_config.enable_prefix_caching = False + if (self.kv_events_config is not None and self.kv_events_config.enable_kv_cache_events and not self.cache_config.enable_prefix_caching): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4ca645b91b33..7a88e3269a5e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1041,7 +1041,7 @@ def create_engine_config( # Set default arguments for V0 or V1 Engine. if use_v1: - self._set_default_args_v1(usage_context) + self._set_default_args_v1(usage_context, model_config) else: self._set_default_args_v0(model_config) @@ -1349,13 +1349,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No Embedding Models so far. - if model_config.task not in ["generate"]: - _raise_or_fallback(feature_name=f"--task {model_config.task}", - recommend_to_remove=False) - return False - - # No Encoder-Decoder, not all Mamba so far. + # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: _raise_or_fallback(feature_name=model_config.architectures, recommend_to_remove=False) @@ -1523,15 +1517,38 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None: if self.max_num_seqs is None: self.max_num_seqs = 256 - def _set_default_args_v1(self, usage_context: UsageContext) -> None: + def _set_default_args_v1(self, usage_context: UsageContext, + model_config: ModelConfig) -> None: """Set Default Arguments for V1 Engine.""" - # V1 always uses chunked prefills. - self.enable_chunked_prefill = True + # V1 always uses chunked prefills and prefix caching + # for non-pooling tasks. + # For pooling tasks the default is False + if model_config.runner_type != "pooling": + self.enable_chunked_prefill = True + if self.enable_prefix_caching is None: + self.enable_prefix_caching = True + else: + + pooling_type = model_config.pooler_config.pooling_type + + # TODO: when encoder models are supported we'll have to + # check for causal attention here. + incremental_prefill_supported = (pooling_type is not None and + pooling_type.lower() == "last") - # V1 enables prefix caching by default. - if self.enable_prefix_caching is None: - self.enable_prefix_caching = True + action = "Enabling" if \ + incremental_prefill_supported else "Disabling" + + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = incremental_prefill_supported + logger.info("(%s) chunked prefill by default", action) + if self.enable_prefix_caching is None: + self.enable_prefix_caching = incremental_prefill_supported + logger.info("(%s) prefix caching by default", action) + + if not self.enable_chunked_prefill: + self.max_num_batched_tokens = model_config.max_model_len # V1 should use the new scheduler by default. # Swap it only if this arg is set to the original V0 default diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c11e627ee236..f3170fa30fce 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1266,7 +1266,7 @@ def score( # the tokenizer for models such as # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing # lists of tokens to the `text` and `text_pair` kwargs - tokenizer = self.llm_engine.get_tokenizer() + tokenizer = self.get_tokenizer() def ensure_str(prompt: SingletonPrompt): if isinstance(prompt, dict): diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index b896cc46b9d0..c2ed50d04d12 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -9,6 +9,7 @@ import jinja2 import numpy as np +import torch from fastapi import Request from typing_extensions import assert_never @@ -39,7 +40,8 @@ def _get_data( elif encoding_format == "base64": # Force to use float32 for base64 encoding # to match the OpenAI python client behavior - pooling_bytes = np.array(output.data, dtype="float32").tobytes() + pt_float32 = output.data.to(dtype=torch.float32) + pooling_bytes = np.array(pt_float32, dtype="float32").tobytes() return base64.b64encode(pooling_bytes).decode("utf-8") assert_never(encoding_format) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 6829d93d2d6c..eb2148d76452 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -10,11 +10,15 @@ from typing_extensions import assert_never from vllm.config import ModelConfig, PoolerConfig -from vllm.model_executor.pooling_metadata import (PoolingMetadata, - PoolingTensors) +from vllm.model_executor.pooling_metadata import ( # noqa: E501 + PoolingMetadata as V0PoolingMetadata) +from vllm.model_executor.pooling_metadata import PoolingTensors from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) +from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata + +PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] class PoolingType(IntEnum): @@ -75,15 +79,18 @@ def __init__(self, *, normalize: bool, softmax: bool) -> None: def get_prompt_lens( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + assert isinstance(hidden_states, torch.Tensor) return PoolingTensors.from_pooling_metadata( pooling_metadata, hidden_states.device).prompt_lens def extract_states( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: raise NotImplementedError @@ -93,7 +100,7 @@ def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput: def forward( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.extract_states(hidden_states, pooling_metadata) @@ -106,11 +113,19 @@ class CLSPool(SimplePooler): def extract_states( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + if isinstance(hidden_states, list): + result = [] + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with CLS pooling" + result.append(req_state[0]) + return result + first_token_flat_indices = torch.zeros_like(prompt_lens) first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] return hidden_states[first_token_flat_indices] @@ -120,9 +135,12 @@ class LastPool(SimplePooler): def extract_states( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: + if isinstance(hidden_states, list): + return [h[-1] for h in hidden_states] + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 @@ -133,11 +151,17 @@ class AllPool(SimplePooler): def extract_states( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + if isinstance(hidden_states, list): + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with ALL pooling" + return hidden_states + offset = 0 pooled_data = list[torch.Tensor]() for prompt_len in prompt_lens: @@ -151,11 +175,20 @@ class MeanPool(SimplePooler): def extract_states( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + if isinstance(hidden_states, list): + result = [] + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with mean pooling" + result.append(torch.mean(req_state, dim=0, + dtype=torch.float32)) + return result + # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) @@ -184,30 +217,53 @@ def __init__( self.step_tag_id = step_tag_id self.returned_token_ids = returned_token_ids + def get_prompt_token_ids( + self, + pooling_metadata: PoolingMetadata, + ) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) + ] + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + def extract_states( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> Union[list[torch.Tensor], torch.Tensor]: prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) - returned_token_ids = self.returned_token_ids - if returned_token_ids is not None and len(returned_token_ids) > 0: - hidden_states = hidden_states[:, returned_token_ids] + pooled_data: list[torch.Tensor] = [] + if isinstance(hidden_states, list): + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with mean pooling" + pooled_data = hidden_states + else: + offset = 0 + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + offset += prompt_len + pooled_data.append(pooled_data_i) + + pooled_data = [] + returned_token_ids = self.returned_token_ids step_tag_id = self.step_tag_id - offset = 0 - pooled_data = list[torch.Tensor]() - for prompt_len, seq_data_i in zip(prompt_lens, - pooling_metadata.seq_data.values()): - pooled_data_i = hidden_states[offset:offset + prompt_len] - if step_tag_id is not None: - token_ids = torch.tensor(seq_data_i.prompt_token_ids) - pooled_data_i = pooled_data_i[token_ids == step_tag_id] + for data, token_id in zip(pooled_data, prompt_token_ids): + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] - offset += prompt_len - pooled_data.append(pooled_data_i) + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) return pooled_data @@ -230,10 +286,17 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], else: pooled_data = pooled_data.to(torch.float32) - dimensions_list = [ - pooling_param.dimensions - for _, pooling_param in pooling_metadata.seq_groups - ] + if isinstance(pooling_metadata, V0PoolingMetadata): + dimensions_list = [ + pooling_param.dimensions + for _, pooling_param in pooling_metadata.seq_groups + ] + else: + assert isinstance(pooled_data, list) + dimensions_list = [ + pooling_param.dimensions + for pooling_param in pooling_metadata.pooling_params + ] if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) @@ -325,20 +388,41 @@ def __init__( raise NotImplementedError(f"task={config.task!r} is not supported" " with the classification pooler") + def get_prompt_lens( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + assert isinstance(hidden_states, torch.Tensor) + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + def forward( self, - hidden_states: torch.Tensor, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: """Pools sentence pair scores from the hidden_states.""" + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens + pooled_data = list[torch.Tensor]() + if isinstance(hidden_states, list): + for req_state, prompt_len in zip(hidden_states, prompt_lens): + assert prompt_len == req_state.shape[0], \ + "partial prefill not supported with classifier" + pooled_data = hidden_states + else: + offset = 0 + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + offset += prompt_len + pooled_data.append(pooled_data_i) offset = 0 pooled_data_lst = [] - for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] + for pooled_data_i in pooled_data: if self.pooler is not None: final_shape_tensor = self.pooler(pooled_data_i) @@ -346,7 +430,6 @@ def forward( final_shape_tensor = self.classifier(pooled_data_i) pooled_data_lst.append(final_shape_tensor) - offset += prompt_len pooled_output = torch.stack(pooled_data_lst) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 389393987c81..d6f6d9d1fb59 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -446,8 +446,8 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: softmax=False) -class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, - SupportsQuant): +class BertForSequenceClassification(nn.Module, SupportsV0Only, + SupportsCrossEncoding, SupportsQuant): """A model that uses Bert to provide embedding functionalities. This class encapsulates the BertModel and provides an interface for diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 35f416a6e21e..7c1f889e8f38 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -21,7 +21,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsCrossEncoding +from .interfaces import SupportsCrossEncoding, SupportsV0Only from .utils import WeightsMapper, maybe_prefix @@ -270,7 +270,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output -class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): +class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, + SupportsCrossEncoding): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index bad0f6b1ffb7..216c1f1c7ff7 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -375,7 +375,12 @@ def pooler( ) -> Optional[PoolerOutput]: hidden_states = self._pooler.extract_states(hidden_states, pooling_metadata) - logits, _ = self.score(hidden_states) + + if isinstance(hidden_states, list): + logits = [self.score(state)[0] for state in hidden_states] + else: + logits, _ = self.score(hidden_states) + pooled_data = self._pooler.head(logits, pooling_metadata) pooled_outputs = [ self._pooler.build_output(data.squeeze(-1)) for data in pooled_data diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 322f9ed3efa9..b5c327bdd256 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -5,6 +5,8 @@ import msgspec +from vllm.sampling_params import RequestOutputKind + if TYPE_CHECKING: from vllm.config import ModelConfig @@ -23,6 +25,7 @@ class PoolingParams( dimensions: Optional[int] = None additional_data: Optional[Any] = None + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" @@ -52,3 +55,7 @@ def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " f"additional_metadata={self.additional_data})") + + def __post_init__(self) -> None: + assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ + "For pooling output_kind has to be FINAL_ONLY" diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 99531e7d213d..08bb0efb2f3d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -146,7 +146,8 @@ def get_computed_blocks(self, # Prefix caching is disabled or # When the request requires prompt logprobs, we skip prefix caching. if (not self.enable_caching - or request.sampling_params.prompt_logprobs is not None): + or (request.sampling_params is not None + and request.sampling_params.prompt_logprobs is not None)): return self.create_empty_block_list(), 0 # The block hashes for the request may already be computed diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 9b0a439fe7dc..6f31031a1086 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -14,6 +14,7 @@ KVConnectorMetadata) from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange + from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.request import Request @@ -26,7 +27,8 @@ class NewRequestData: mm_inputs: list[MultiModalKwargs] mm_hashes: list[str] mm_positions: list[PlaceholderRange] - sampling_params: SamplingParams + sampling_params: Optional[SamplingParams] + pooling_params: Optional[PoolingParams] block_ids: tuple[list[int], ...] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -44,6 +46,7 @@ def from_request( mm_hashes=request.mm_hashes, mm_positions=request.mm_positions, sampling_params=request.sampling_params, + pooling_params=request.pooling_params, block_ids=block_ids, num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2d2274ab6a4d..16e76defdf72 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -402,6 +402,15 @@ def schedule(self) -> SchedulerOutput: < num_new_tokens): num_new_tokens = ( self.scheduler_config.long_prefill_token_threshold) + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if not self.scheduler_config.chunked_prefill_enabled and \ + num_new_tokens > token_budget: + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 @@ -707,6 +716,7 @@ def update_from_output( logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pooler_outputs = model_runner_output.pooler_output new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) @@ -724,7 +734,8 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[req_index] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] scheduled_spec_token_ids = ( scheduler_output.scheduled_spec_decode_tokens.get(req_id)) @@ -776,8 +787,17 @@ def update_from_output( del new_token_ids[num_new:] # Trim new tokens if needed. break + pooler_output = None + if pooler_outputs: + pooler_output = pooler_outputs[req_index] + stopped = check_stop(request, self.max_model_len, + pooler_output) + if stopped: + kv_transfer_params = self._free_request(request) + # Extract sample logprobs if needed. - if request.sampling_params.logprobs is not None and logprobs: + if request.sampling_params is not None \ + and request.sampling_params.logprobs is not None and logprobs: # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) @@ -802,7 +822,8 @@ def update_from_output( # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or kv_transfer_params: + if new_token_ids or pooler_output is not None \ + or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( @@ -812,6 +833,7 @@ def update_from_output( finish_reason=request.get_finished_reason(), new_logprobs=new_logprobs, new_prompt_logprobs_tensors=prompt_logprobs_tensors, + pooling_output=pooler_output, stop_reason=request.stop_reason, events=request.take_events(), kv_transfer_params=kv_transfer_params, diff --git a/vllm/v1/core/sched/utils.py b/vllm/v1/core/sched/utils.py index 1397c5f4c9a6..42ec95091f96 100644 --- a/vllm/v1/core/sched/utils.py +++ b/vllm/v1/core/sched/utils.py @@ -1,15 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + from vllm.v1.request import Request, RequestStatus -def check_stop(request: Request, max_model_len: int) -> bool: +def check_stop(request: Request, + max_model_len: int, + pooler_output: Optional[torch.Tensor] = None) -> bool: if (request.num_tokens >= max_model_len or request.num_output_tokens >= request.max_tokens): request.status = RequestStatus.FINISHED_LENGTH_CAPPED return True + if request.pooling_params: + if pooler_output is not None: + request.status = RequestStatus.FINISHED_STOPPED + return True + return False + sampling_params = request.sampling_params + assert sampling_params is not None last_token_id = request.output_token_ids[-1] if (not sampling_params.ignore_eos and last_token_id == request.eos_token_id): diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 59463f1ba99f..4d1696a9b43a 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -7,10 +7,12 @@ from typing import Any, Optional, Union import msgspec +import torch from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import LogprobsLists, LogprobsTensors @@ -50,7 +52,8 @@ class EngineCoreRequest( mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] mm_hashes: Optional[list[str]] mm_placeholders: Optional[list[PlaceholderRange]] - sampling_params: SamplingParams + sampling_params: Optional[SamplingParams] + pooling_params: Optional[PoolingParams] eos_token_id: Optional[int] arrival_time: float lora_request: Optional[LoRARequest] @@ -104,6 +107,8 @@ class EngineCoreOutput( new_logprobs: Optional[LogprobsLists] = None new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + pooling_output: Optional[torch.Tensor] = None + finish_reason: Optional[FinishReason] = None stop_reason: Union[int, str, None] = None events: Optional[list[EngineCoreEvent]] = None diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 7fb36cf5941e..998c4c5ea3cf 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -17,7 +17,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.outputs import RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -228,8 +228,7 @@ async def add_request( if self.errored: raise EngineDeadError() - assert isinstance(params, SamplingParams), \ - "Pooling is not supported in V1" + is_pooling = isinstance(params, PoolingParams) # Create a new output collector for the request. queue = RequestOutputCollector(output_kind=params.output_kind) @@ -240,7 +239,7 @@ async def add_request( tokenization_kwargs, trace_headers, prompt_adapter_request, priority, data_parallel_rank) - if params.n == 1: + if is_pooling or params.n == 1: await self._add_request(request, prompt_str, None, 0, queue) return queue @@ -443,7 +442,7 @@ def _record_stats( stat_logger.record(scheduler_stats=scheduler_stats, iteration_stats=iteration_stats) - def encode( + async def encode( self, prompt: PromptType, pooling_params: PoolingParams, @@ -451,8 +450,75 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ): - raise ValueError("Not Supported on V1 yet.") + ) -> AsyncGenerator[PoolingRequestOutput, None]: + """ + Main function called by the API server to kick off a request + * 1) Making an AsyncStream corresponding to the Request. + * 2) Processing the Input. + * 3) Adding the Request to the EngineCore (separate process). + + A separate output_handler loop runs in a background AsyncIO task, + pulling outputs from EngineCore and putting them into the + per-request AsyncStream. + + The caller of generate() iterates the returned AsyncGenerator, + returning the RequestOutput back to the caller. + """ + + try: + # We start the output_handler on the first call to generate() so + # we can call __init__ before the event loop, which enables us + # to handle startup failure gracefully in the OpenAI server. + self._run_output_handler() + + q = await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ) + + # The output_handler task pushes items into the queue. + # This task pulls from the queue and yields to caller. + finished = False + while not finished: + # Note: drain queue without await if possible (avoids + # task switching under load which helps performance). + out = q.get_nowait() or await q.get() + assert isinstance(out, PoolingRequestOutput) + # Note: both OutputProcessor and EngineCore handle their + # own request cleanup based on finished. + finished = out.finished + yield out + + # If the request is disconnected by the client, generate() + # is cancelled. So, we abort the request if we end up here. + except asyncio.CancelledError: + await self.abort(request_id) + if self.log_requests: + logger.info("Request %s aborted.", request_id) + raise + + # Engine is dead. Do not abort since we shut down. + except EngineDeadError: + if self.log_requests: + logger.info("Request %s failed (engine dead).", request_id) + raise + + # Request validation error. + except ValueError: + if self.log_requests: + logger.info("Request %s failed (bad request).", request_id) + raise + + # Unexpected error in the generate() task (possibly recoverable). + except Exception as e: + await self.abort(request_id) + if self.log_requests: + logger.info("Request %s failed.", request_id) + raise EngineGenerateError() from e async def get_vllm_config(self) -> VllmConfig: return self.vllm_config diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 57fcf8daa5a1..da65550354d0 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -60,7 +60,6 @@ def __init__(self, executor_class: type[Executor], log_stats: bool, executor_fail_callback: Optional[Callable] = None): - assert vllm_config.model_config.runner_type != "pooling" # plugins need to be loaded at the engine/scheduler level too from vllm.plugins import load_general_plugins diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 35aceba0fe76..2f5504ea14b4 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -50,6 +50,8 @@ def from_new_request( request: EngineCoreRequest, ) -> "IncrementalDetokenizer": + assert request.sampling_params is not None + if tokenizer is None: # No tokenizer => skipping detokenization. return IncrementalDetokenizer() @@ -70,6 +72,7 @@ def __init__(self, request: EngineCoreRequest): # Stop strings params = request.sampling_params + assert params is not None self.stop = stop = params.stop self.include_stop_str_in_output = params.include_stop_str_in_output @@ -164,6 +167,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, super().__init__(request) sampling_params = request.sampling_params + assert sampling_params is not None self.request_id = request.request_id self.skip_special_tokens = sampling_params.skip_special_tokens @@ -245,20 +249,20 @@ def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest): super().__init__(request) self.tokenizer = tokenizer + params = request.sampling_params + assert params is not None # Metadata for incremental detokenization. self.tokens, self.prefix_offset, self.read_offset = ( convert_prompt_ids_to_tokens( tokenizer=tokenizer, prompt_ids=request.prompt_token_ids, - skip_special_tokens=request.sampling_params. - skip_special_tokens, + skip_special_tokens=params.skip_special_tokens, )) self.token_ids.extend(request.prompt_token_ids) self.prompt_len = len(request.prompt_token_ids) - params = request.sampling_params self.skip_special_tokens = params.skip_special_tokens self.spaces_between_special_tokens = ( params.spaces_between_special_tokens) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 736ffd8b40f0..1932cd10bb1b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -15,7 +15,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.outputs import RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -221,7 +221,7 @@ def add_request( # Add the request to EngineCore. self.engine_core.add_request(child_request) - def step(self) -> list[RequestOutput]: + def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: if self.should_execute_dummy_batch: self.should_execute_dummy_batch = False diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index edc3be5b0120..e95da0a5e5aa 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -38,6 +38,7 @@ def from_new_request( tokenizer: Optional[AnyTokenizer], request: EngineCoreRequest, ) -> "LogprobsProcessor": + assert request.sampling_params is not None num_logprobs = request.sampling_params.logprobs num_prompt_logprobs = request.sampling_params.prompt_logprobs return cls( diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 1dcfbab30cfb..2bcd61d1f0aa 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -4,9 +4,12 @@ import asyncio from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast -from vllm.outputs import CompletionOutput, RequestOutput +import torch + +from vllm.outputs import (CompletionOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput) from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -29,20 +32,22 @@ class RequestOutputCollector: def __init__(self, output_kind: RequestOutputKind): self.aggregate = output_kind == RequestOutputKind.DELTA - self.output: Optional[Union[RequestOutput, Exception]] = None + self.output: Optional[Union[RequestOutput, PoolingRequestOutput, + Exception]] = None self.ready = asyncio.Event() - def put(self, output: Union[RequestOutput, Exception]) -> None: + def put(self, output: Union[RequestOutput, PoolingRequestOutput, + Exception]) -> None: """Non-blocking put operation.""" if self.output is None or isinstance(output, Exception): self.output = output self.ready.set() - elif isinstance(self.output, RequestOutput): + elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)): # This ensures that request outputs with different request indexes # (if n > 1) do not override each other. self.output.add(output, aggregate=self.aggregate) - async def get(self) -> RequestOutput: + async def get(self) -> Union[RequestOutput, PoolingRequestOutput]: """Get operation blocks on put event.""" while (output := self.output) is None: await self.ready.wait() @@ -52,7 +57,8 @@ async def get(self) -> RequestOutput: raise output return output - def get_nowait(self) -> Optional[RequestOutput]: + def get_nowait( + self) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: """Non-blocking get operation.""" output = self.output if output is not None: @@ -66,7 +72,7 @@ def get_nowait(self) -> Optional[RequestOutput]: @dataclass class OutputProcessorOutput: - request_outputs: list[RequestOutput] + request_outputs: list[Union[RequestOutput, PoolingRequestOutput]] reqs_to_abort: list[str] @@ -81,8 +87,8 @@ def __init__( output_kind: RequestOutputKind, prompt: Optional[str], prompt_token_ids: list[int], - logprobs_processor: LogprobsProcessor, - detokenizer: IncrementalDetokenizer, + logprobs_processor: Optional[LogprobsProcessor], + detokenizer: Optional[IncrementalDetokenizer], max_tokens_param: Optional[int], arrival_time: float, queue: Optional[RequestOutputCollector], @@ -116,27 +122,39 @@ def from_new_request( queue: Optional[RequestOutputCollector], log_stats: bool, ) -> "RequestState": - if not request.sampling_params.detokenize: - tokenizer = None + + if sampling_params := request.sampling_params: + if not sampling_params.detokenize: + tokenizer = None + output_kind = sampling_params.output_kind + logprobs_processor = LogprobsProcessor.from_new_request( + tokenizer=tokenizer, + request=request, + ) + detokenizer = IncrementalDetokenizer.from_new_request( + tokenizer=tokenizer, + request=request, + ) + max_tokens_param = sampling_params.max_tokens + else: + logprobs_processor = None + detokenizer = None + max_tokens_param = None + assert request.pooling_params is not None + output_kind = request.pooling_params.output_kind + return cls( request_id=request.request_id, parent_req=parent_req, request_index=request_index, lora_name=(request.lora_request.name if request.lora_request is not None else None), - output_kind=request.sampling_params.output_kind, + output_kind=output_kind, prompt=prompt, prompt_token_ids=request.prompt_token_ids, - logprobs_processor=LogprobsProcessor.from_new_request( - tokenizer=tokenizer, - request=request, - ), - detokenizer=IncrementalDetokenizer.from_new_request( - tokenizer=tokenizer, - request=request, - ), - max_tokens_param=(request.sampling_params.max_tokens if - request.sampling_params is not None else None), + logprobs_processor=logprobs_processor, + detokenizer=detokenizer, + max_tokens_param=max_tokens_param, arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, @@ -145,11 +163,12 @@ def from_new_request( def make_request_output( self, new_token_ids: list[int], + pooling_output: Optional[torch.Tensor], finish_reason: Optional[FinishReason], stop_reason: Union[int, str, None], kv_transfer_params: Optional[dict[str, Any]] = None, num_cached_tokens: int = 0, - ) -> Optional[RequestOutput]: + ) -> Optional[Union[RequestOutput, PoolingRequestOutput]]: finished = finish_reason is not None final_only = self.output_kind == RequestOutputKind.FINAL_ONLY @@ -158,15 +177,20 @@ def make_request_output( # Only the final output is required in FINAL_ONLY mode. return None - completion_output = self._new_completion_output( - new_token_ids, finish_reason, stop_reason) - request_id = self.request_id + if pooling_output is not None: + return self._new_request_output( + request_id, [self._new_pooling_output(pooling_output)], + finished) + + output = self._new_completion_output(new_token_ids, finish_reason, + stop_reason) + if self.parent_req is None: - outputs = [completion_output] + outputs = [output] else: request_id, outputs, finished = self.parent_req.get_outputs( - request_id, completion_output) + request_id, output) if not outputs: return None @@ -176,12 +200,21 @@ def make_request_output( def _new_request_output( self, request_id: str, - outputs: list[CompletionOutput], + outputs: Union[list[CompletionOutput], list[PoolingOutput]], finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, num_cached_tokens: int = 0, - ) -> RequestOutput: - + ) -> Union[RequestOutput, PoolingRequestOutput]: + + if isinstance(outputs[0], PoolingOutput): + assert len(outputs) == 1 + return PoolingRequestOutput( + request_id=request_id, + outputs=outputs[0], + prompt_token_ids=self.prompt_token_ids, + finished=finished, + ) + assert self.logprobs_processor is not None if self.output_kind == RequestOutputKind.DELTA: # Side effect: logprobs processor forgets prompt logprobs prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() @@ -193,7 +226,7 @@ def _new_request_output( prompt=self.prompt, prompt_token_ids=self.prompt_token_ids, prompt_logprobs=prompt_logprobs, - outputs=outputs, + outputs=cast(list[CompletionOutput], outputs), finished=finished, kv_transfer_params=kv_transfer_params, num_cached_tokens=num_cached_tokens, @@ -206,6 +239,8 @@ def _new_completion_output( stop_reason: Union[int, str, None], ) -> CompletionOutput: + assert self.detokenizer is not None + assert self.logprobs_processor is not None finished = finish_reason is not None delta = self.output_kind == RequestOutputKind.DELTA @@ -228,6 +263,13 @@ def _new_completion_output( finish_reason=str(finish_reason) if finished else None, stop_reason=stop_reason if finished else None) + def _new_pooling_output( + self, + pooling_output: torch.Tensor, + ) -> PoolingOutput: + + return PoolingOutput(data=pooling_output) + class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" @@ -326,7 +368,8 @@ def process_outputs( within the loop below. """ - request_outputs: list[RequestOutput] = [] + request_outputs: Union[list[RequestOutput], + list[PoolingRequestOutput]] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id @@ -341,25 +384,31 @@ def process_outputs( iteration_stats) new_token_ids = engine_core_output.new_token_ids + pooling_output = engine_core_output.pooling_output finish_reason = engine_core_output.finish_reason stop_reason = engine_core_output.stop_reason kv_transfer_params = engine_core_output.kv_transfer_params num_cached_tokens = engine_core_output.num_cached_tokens req_state.is_prefilling = False - # 2) Detokenize the token ids into text and perform stop checks. - stop_string = req_state.detokenizer.update( - new_token_ids, finish_reason == FinishReason.STOP) - if stop_string: - finish_reason = FinishReason.STOP - stop_reason = stop_string - - # 3) Compute sample and prompt logprobs for request, if required. - req_state.logprobs_processor.update_from_output(engine_core_output) + if pooling_output is None: + assert req_state.detokenizer is not None + assert req_state.logprobs_processor is not None + # 2) Detokenize the token ids into text and perform stop checks. + stop_string = req_state.detokenizer.update( + new_token_ids, finish_reason == FinishReason.STOP) + if stop_string: + finish_reason = FinishReason.STOP + stop_reason = stop_string + + # 3) Compute sample and prompt logprobs for request, + # if required. + req_state.logprobs_processor.update_from_output( + engine_core_output) # 4) Create and handle RequestOutput objects. if request_output := req_state.make_request_output( - new_token_ids, finish_reason, stop_reason, + new_token_ids, pooling_output, finish_reason, stop_reason, kv_transfer_params, num_cached_tokens): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index e28879d40460..b00f1444c7b3 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -136,8 +136,8 @@ def _validate_params( Should raise ValueError if unsupported for API Server. """ - if not isinstance(params, SamplingParams): - raise ValueError("V1 does not yet support Pooling models.") + if isinstance(params, PoolingParams): + return self._validate_logprobs(params) self._validate_sampling_params(params, lora_request) @@ -263,18 +263,22 @@ def process_inputs( if encoder_inputs is not None: raise NotImplementedError - assert isinstance(params, SamplingParams) - # TODO: can we avoid cloning here in multiproc case? - sampling_params = params.clone() - # If unset max tokens, then generate up to the max_model_len. - if sampling_params.max_tokens is None: - sampling_params.max_tokens = ( - self.model_config.max_model_len - - len(decoder_inputs["prompt_token_ids"])) - sampling_params.update_from_generation_config( - self.generation_config_fields, eos_token_id) - sampling_params.update_from_tokenizer( - self.tokenizer.get_lora_tokenizer(lora_request)) + sampling_params = None + pooling_params = None + if isinstance(params, SamplingParams): + # TODO: can we avoid cloning here in multiproc case? + sampling_params = params.clone() + # If unset max tokens, then generate up to the max_model_len. + if sampling_params.max_tokens is None: + sampling_params.max_tokens = ( + self.model_config.max_model_len - + len(decoder_inputs["prompt_token_ids"])) + sampling_params.update_from_generation_config( + self.generation_config_fields, eos_token_id) + sampling_params.update_from_tokenizer( + self.tokenizer.get_lora_tokenizer(lora_request)) + else: + pooling_params = params.clone() # Multimodal related. sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None @@ -331,6 +335,7 @@ def process_inputs( mm_hashes=sorted_mm_hashes, mm_placeholders=sorted_mm_positions, sampling_params=sampling_params, + pooling_params=pooling_params, eos_token_id=eos_token_id, arrival_time=arrival_time, lora_request=lora_request, diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 11865a0fd1f2..c720ca13e51b 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -481,8 +481,9 @@ def record(self, scheduler_stats: Optional[SchedulerStats], finished_request.num_prompt_tokens) self.histogram_num_generation_tokens_request.observe( finished_request.num_generation_tokens) - self.histogram_max_tokens_request.observe( - finished_request.max_tokens_param) + if finished_request.max_tokens_param: + self.histogram_max_tokens_request.observe( + finished_request.max_tokens_param) if self.gauge_lora_info is not None: running_lora_adapters = \ diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 4a5d5fac49d1..716f40fffb28 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -106,7 +106,6 @@ def update_from_output(self, output: "EngineCoreOutput", self.num_generation_tokens += num_new_generation_tokens if is_prefilling: - assert num_new_generation_tokens > 0 self.num_prompt_tokens += prompt_len first_token_latency = self._time_since(req_stats.arrival_time) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 17a299d57cba..2234843293cc 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -101,6 +101,9 @@ class ModelRunnerOutput: # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] + # [num_reqs, hidden_size] + pooler_output: list[Optional[torch.Tensor]] + # [req_ids] finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None @@ -112,5 +115,6 @@ class ModelRunnerOutput: spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, + pooler_output=[], finished_sending=None, finished_recving=None) diff --git a/vllm/v1/pool/__init__.py b/vllm/v1/pool/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py new file mode 100644 index 000000000000..d70a0d044661 --- /dev/null +++ b/vllm/v1/pool/metadata.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Optional + +import torch + +from vllm.pooling_params import PoolingParams + + +@dataclass +class PoolingMetadata: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + prompt_token_ids: Optional[torch.Tensor] + pooling_params: list[PoolingParams] diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 694e271e5ad7..e3f3a418755c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import is_list_of from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, @@ -25,7 +26,8 @@ def __init__( multi_modal_inputs: Optional[list[MultiModalKwargs]], multi_modal_hashes: Optional[list[str]], multi_modal_placeholders: Optional[list[PlaceholderRange]], - sampling_params: SamplingParams, + sampling_params: Optional[SamplingParams], + pooling_params: Optional[PoolingParams], eos_token_id: Optional[int], client_index: int = 0, lora_request: Optional["LoRARequest"] = None, @@ -35,18 +37,35 @@ def __init__( self.request_id = request_id self.client_index = client_index self.sampling_params = sampling_params + self.pooling_params = pooling_params # Because of LoRA, the eos token id can be different for each request. self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request - self.status = (RequestStatus.WAITING_FOR_FSM - if sampling_params.guided_decoding is not None else - RequestStatus.WAITING) + self.status = RequestStatus.WAITING + if sampling_params and sampling_params.guided_decoding is not None: + self.status = RequestStatus.WAITING_FOR_FSM self.events: list[EngineCoreEvent] = [] self.stop_reason: Union[int, str, None] = None - assert sampling_params.max_tokens is not None - self.max_tokens = sampling_params.max_tokens + + # P/D: Connector-specific KV transfer parameters. + self.kv_transfer_params: Optional[dict[str, Any]] = None + + if pooling_params is not None: + self.max_tokens = 1 + elif sampling_params is not None: + assert sampling_params.max_tokens is not None + self.max_tokens = sampling_params.max_tokens + if sampling_params.guided_decoding is not None: + self.status = RequestStatus.WAITING_FOR_FSM + + if sampling_params.extra_args is not None: + self.kv_transfer_params = \ + sampling_params.extra_args.get("kv_transfer_params") + else: + raise ValueError( + "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids self.num_prompt_tokens = len(self.prompt_token_ids) @@ -63,11 +82,6 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 - # P/D: Connector-specific KV transfer parameters. - kv_params = (None if sampling_params.extra_args is None else - sampling_params.extra_args.get("kv_transfer_params")) - self.kv_transfer_params: Optional[dict[str, Any]] = kv_params - # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: @@ -98,10 +112,12 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, + pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( - sampling_params=request.sampling_params), + sampling_params=request.sampling_params) \ + if request.sampling_params else None, cache_salt=request.cache_salt, ) @@ -141,7 +157,8 @@ def get_num_encoder_tokens(self, input_id: int) -> int: @property def use_structured_output(self) -> bool: - return self.sampling_params.guided_decoding is not None + return self.sampling_params is not None and \ + self.sampling_params.guided_decoding is not None def record_event( self, diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index b2b0ee796954..c5500b9a384d 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -62,13 +62,15 @@ def grammar_init(self, request: Request) -> None: return if TYPE_CHECKING: - assert request.sampling_params.guided_decoding is not None + assert request.sampling_params is not None and \ + request.sampling_params.guided_decoding is not None # Initialize the backend the first time it is needed. # # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: + assert request.sampling_params is not None backend = request.sampling_params.guided_decoding.backend vocab_size = self.vllm_config.model_config.get_vocab_size() if backend == "xgrammar": diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e76293f98a51..3a2c9ef7dfac 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -10,9 +10,11 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors +from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import MultiGroupBlockTable @@ -27,7 +29,8 @@ class CachedRequestState: prompt_token_ids: list[int] mm_inputs: list[MultiModalKwargs] mm_positions: list[PlaceholderRange] - sampling_params: SamplingParams + sampling_params: Optional[SamplingParams] + pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] block_ids: tuple[list[int], ...] @@ -226,6 +229,8 @@ def __init__( # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() + self.pooling_params: dict[str, PoolingParams] = {} + @property def req_ids(self) -> list[str]: # None elements should only be present transiently @@ -269,77 +274,83 @@ def add_request( self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens self.block_table.add_row(request.block_ids, req_index) - sampling_params = request.sampling_params - if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. - self.temperature_cpu[req_index] = -1.0 - self.greedy_reqs.add(req_id) - else: - self.temperature_cpu[req_index] = sampling_params.temperature - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - top_k = sampling_params.top_k - if 0 < top_k < self.vocab_size: - self.top_k_reqs.add(req_id) - else: - top_k = self.vocab_size - self.top_k_cpu[req_index] = top_k - self.min_p_cpu[req_index] = sampling_params.min_p - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty - if sampling_params.min_p > _SAMPLING_EPS: - self.min_p_reqs.add(req_id) - if sampling_params.frequency_penalty != 0.0: - self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty - if sampling_params.presence_penalty != 0.0: - self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty - if sampling_params.repetition_penalty != 1.0: - self.repetition_penalties_reqs.add(req_id) - if sampling_params.min_tokens: - self.min_tokens[req_index] = (sampling_params.min_tokens, - sampling_params.all_stop_token_ids) - - # NOTE(woosuk): self.generators should not include the requests that - # do not have their own generator. - if request.generator is not None: - self.generators[req_index] = request.generator - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = sampling_params.logprobs - if sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs - if sampling_params.logit_bias is not None: - self.logit_bias[req_index] = sampling_params.logit_bias - - if sampling_params.allowed_token_ids: - self.has_allowed_token_ids.add(req_id) - if self.allowed_token_ids_mask_cpu_tensor is None: - # Lazy allocation for this tensor, which can be large. + if sampling_params := request.sampling_params: + if sampling_params.sampling_type == SamplingType.GREEDY: + # Avoid later division by zero. + self.temperature_cpu[req_index] = -1.0 + self.greedy_reqs.add(req_id) + else: + self.temperature_cpu[req_index] = sampling_params.temperature + self.random_reqs.add(req_id) + + self.top_p_cpu[req_index] = sampling_params.top_p + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + top_k = sampling_params.top_k + if 0 < top_k < self.vocab_size: + self.top_k_reqs.add(req_id) + else: + top_k = self.vocab_size + self.top_k_cpu[req_index] = top_k + self.min_p_cpu[req_index] = sampling_params.min_p + self.frequency_penalties_cpu[ + req_index] = sampling_params.frequency_penalty + if sampling_params.min_p > _SAMPLING_EPS: + self.min_p_reqs.add(req_id) + if sampling_params.frequency_penalty != 0.0: + self.frequency_penalties_reqs.add(req_id) + self.presence_penalties_cpu[ + req_index] = sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: + self.presence_penalties_reqs.add(req_id) + self.repetition_penalties_cpu[ + req_index] = sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: + self.repetition_penalties_reqs.add(req_id) + if sampling_params.min_tokens: + self.min_tokens[req_index] = ( + sampling_params.min_tokens, + sampling_params.all_stop_token_ids) + + # NOTE(woosuk): self.generators should not include the requests that + # do not have their own generator. + if request.generator is not None: + self.generators[req_index] = request.generator + + if sampling_params.logprobs is not None: + self.num_logprobs[req_id] = sampling_params.logprobs + if sampling_params.prompt_logprobs is not None: + self.num_prompt_logprobs[ + req_id] = sampling_params.prompt_logprobs + if sampling_params.logit_bias is not None: + self.logit_bias[req_index] = sampling_params.logit_bias + + if sampling_params.allowed_token_ids: + self.has_allowed_token_ids.add(req_id) + if self.allowed_token_ids_mask_cpu_tensor is None: + # Lazy allocation for this tensor, which can be large. + # False means we don't fill with -inf. + self.allowed_token_ids_mask = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device=self.device) + self.allowed_token_ids_mask_cpu_tensor = torch.zeros( + self.max_num_reqs, + self.vocab_size, + dtype=torch.bool, + device="cpu") + self.allowed_token_ids_mask_cpu_tensor[req_index] = True # False means we don't fill with -inf. - self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( - self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device="cpu") - self.allowed_token_ids_mask_cpu_tensor[req_index] = True - # False means we don't fill with -inf. - self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False + self.allowed_token_ids_mask_cpu_tensor[req_index][ + sampling_params.allowed_token_ids] = False - if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids + if sampling_params.bad_words_token_ids: + self.bad_words_token_ids[ + req_index] = sampling_params.bad_words_token_ids + else: + assert request.pooling_params is not None + self.pooling_params[req_id] = request.pooling_params # Add request lora ID if request.lora_request: @@ -392,6 +403,7 @@ def remove_request(self, req_id: str) -> Optional[int]: # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) + self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: @@ -602,6 +614,25 @@ def _make_sampling_metadata(self) -> SamplingMetadata: bad_words_token_ids=self.bad_words_token_ids, ) + @property + def pooling_metadata(self) -> PoolingMetadata: + if len(self.pooling_params) == 0: + pooling_params = [] + else: + # Note, for now this assumes that all request in the batch + # are either sampling or pooling requests + assert len(self.req_ids) == len(self.pooling_params) + pooling_params = [ + self.pooling_params[req_id] for req_id in self.req_ids + ] + + return PoolingMetadata( + prompt_lens=torch.from_numpy( + self.num_prompt_tokens[:self.num_reqs]).to(self.device), + prompt_token_ids=self.sampling_metadata.prompt_token_ids, + pooling_params=pooling_params, + ) + def _make_prompt_token_ids_tensor(self) -> torch.Tensor: max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c4163eb2b8f5..f96fb64342c9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -36,6 +36,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, @@ -51,6 +52,7 @@ SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) +from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -119,6 +121,7 @@ def __init__( cache_config.cache_dtype] self.is_multimodal_model = model_config.is_multimodal_model + self.is_pooling_model = model_config.pooler_config is not None self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -394,7 +397,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: + pooling_params = new_req_data.pooling_params + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -406,6 +411,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, + pooling_params=pooling_params, generator=generator, block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, @@ -563,7 +569,7 @@ def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata]]: + Optional[SpecDecodeMetadata], np.ndarray]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -750,7 +756,7 @@ def _prepare_inputs( self.set_active_loras(self.input_batch, num_scheduled_tokens) return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata) + spec_decode_metadata, num_scheduled_tokens) def _compute_cascade_attn_prefix_len( self, @@ -1197,6 +1203,51 @@ def get_dp_padding(self, dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + def _pool( + self, + hidden_states: torch.Tensor, + num_scheduled_tokens: int, + num_scheduled_tokens_np: np.ndarray, + finished_sending: Optional[set[str]], + finished_recving: Optional[set[str]], + ) -> ModelRunnerOutput: + assert self.input_batch.num_reqs ==\ + len(self.input_batch.pooling_params), \ + "Either all or none of the requests in" \ + " a batch must be pooling request" + + extracted_hidden_states = list( + torch.split(hidden_states[:num_scheduled_tokens], + num_scheduled_tokens_np.tolist())) + + pooling_metadata = self.input_batch.pooling_metadata + + raw_pooler_output = self.model.pooler( + hidden_states=extracted_hidden_states, + pooling_metadata=pooling_metadata) + + pooler_output: list[Optional[torch.Tensor]] = [] + seq_lens = self.seq_lens[:self.input_batch.num_reqs] + for raw_output, seq_len, prompt_len in zip( + raw_pooler_output, seq_lens, pooling_metadata.prompt_lens): + + if seq_len == prompt_len: + pooler_output.append(raw_output.data.cpu()) + else: + pooler_output.append(None) + + return ModelRunnerOutput( + req_ids=self.input_batch.req_ids, + req_id_to_index=self.input_batch.req_id_to_index, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=pooler_output, + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + @torch.inference_mode() def execute_model( self, @@ -1214,7 +1265,8 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata) = (self._prepare_inputs(scheduler_output)) + spec_decode_metadata, + num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1284,7 +1336,7 @@ def execute_model( # compiled with full CUDA graphs, we have to skip them entirely. skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs - # Run the decoder. + # Run the model. # Use persistent buffers for CUDA graphs. with set_forward_context( attn_metadata, @@ -1326,6 +1378,11 @@ def execute_model( all_gather_group=get_tp_group()) logits = None else: + if self.input_batch.pooling_params: + return self._pool(hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np, finished_sending, + finished_recving) + sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: @@ -1541,6 +1598,7 @@ def execute_model( spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], finished_sending=finished_sending, finished_recving=finished_recving, ) @@ -1802,7 +1860,7 @@ def _dummy_run( self, num_tokens: int, capture_attn_cudagraph: bool = False, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) @@ -1899,7 +1957,7 @@ def _dummy_run( self.drafter.dummy_run(num_tokens) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - return hidden_states[logit_indices] + return hidden_states, hidden_states[logit_indices] @torch.inference_mode() def _dummy_sampler_run( @@ -1978,6 +2036,48 @@ def _dummy_sampler_run( ) return sampler_output + @torch.inference_mode() + def _dummy_pooler_run( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + num_tokens = hidden_states.shape[0] + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + hidden_states_list = list( + torch.split(hidden_states, num_scheduled_tokens_list)) + + req_num_tokens = num_tokens // num_reqs + + dummy_metadata = PoolingMetadata( + prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list], + device=self.device), + prompt_token_ids=torch.zeros((num_reqs, req_num_tokens), + dtype=torch.int32, + device=self.device), + pooling_params=[PoolingParams()] * num_reqs) + + try: + pooler_output = self.model.pooler(hidden_states=hidden_states_list, + pooling_metadata=dummy_metadata) + except RuntimeError as e: + if 'out of memory' in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up pooler with " + f"{num_reqs} dummy requests. Please try lowering " + "`max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine.") from e + else: + raise e + return pooler_output + def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. # TODO: handle encoder-decoder models once we support them. @@ -2048,13 +2148,17 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - hidden_states = self._dummy_run(self.max_num_tokens) + hidden_states, last_hidden_states \ + = self._dummy_run(self.max_num_tokens) if get_pp_group().is_last_rank: - sampler_output = self._dummy_sampler_run(hidden_states) + if self.is_pooling_model: + output = self._dummy_pooler_run(hidden_states) + else: + output = self._dummy_sampler_run(last_hidden_states) else: - sampler_output = None + output = None self._sync_device() - del hidden_states, sampler_output + del hidden_states, output self.encoder_cache.clear() gc.collect() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 58795e3fe292..b0f80c701325 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -273,9 +273,14 @@ def compile_or_warm_up_model(self) -> None: if get_pp_group().is_last_rank: max_num_reqs = min(self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens) - self.model_runner._dummy_sampler_run( - hidden_states=self.model_runner._dummy_run( - num_tokens=max_num_reqs)) + + hidden_states, last_hidden_states = \ + self.model_runner._dummy_run(num_tokens=max_num_reqs) + if self.model_runner.is_pooling_model: + self.model_runner._dummy_pooler_run(hidden_states) + else: + self.model_runner._dummy_sampler_run( + hidden_states=last_hidden_states) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py index 3f105ccc5d92..81c798685cb3 100644 --- a/vllm/v1/worker/tpu_input_batch.py +++ b/vllm/v1/worker/tpu_input_batch.py @@ -231,6 +231,7 @@ def add_request( self.block_table.add_row(request.block_ids, req_index) sampling_params = request.sampling_params + assert sampling_params is not None, "pooling requests not supported yet" if sampling_params.sampling_type == SamplingType.GREEDY: # Avoid later division by zero. self.temperature_cpu[req_index] = -1.0 diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index de5a0a1f5597..774caa1a3d98 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -386,6 +386,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: req_ids_to_add: list[str] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: + assert new_req_data.sampling_params is not None,\ + "Pooling is not supported in TPU yet" req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params @@ -395,6 +397,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: mm_inputs=new_req_data.mm_inputs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, + pooling_params=None, generator=None, block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, @@ -956,6 +959,7 @@ def execute_model( spec_token_ids=None, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + pooler_output=[], ) # Check there are no new graphs compiled - all the graphs should be