From 462b26928051742c4dcac259fd0f04cc56fe581d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 6 Jul 2025 18:32:13 -0700 Subject: [PATCH 01/27] Implement OpenAI Responses API [1/N] (#20504) Signed-off-by: Woosuk Kwon --- .../entrypoints/openai/test_openai_schema.py | 4 + .../entrypoints/openai/responses/__init__.py | 0 .../entrypoints/openai/responses/conftest.py | 32 ++ .../openai/responses/test_basic.py | 75 +++ .../openai/responses/test_stateful.py | 137 ++++++ .../responses/test_structured_output.py | 92 ++++ vllm/entrypoints/chat_utils.py | 4 +- vllm/entrypoints/openai/api_server.py | 91 +++- vllm/entrypoints/openai/protocol.py | 201 ++++++++ vllm/entrypoints/openai/serving_engine.py | 8 +- vllm/entrypoints/openai/serving_responses.py | 464 ++++++++++++++++++ vllm/reasoning/abs_reasoning_parsers.py | 6 +- 12 files changed, 1106 insertions(+), 8 deletions(-) create mode 100644 tests/v1/entrypoints/openai/responses/__init__.py create mode 100644 tests/v1/entrypoints/openai/responses/conftest.py create mode 100644 tests/v1/entrypoints/openai/responses/test_basic.py create mode 100644 tests/v1/entrypoints/openai/responses/test_stateful.py create mode 100644 tests/v1/entrypoints/openai/responses/test_structured_output.py create mode 100644 vllm/entrypoints/openai/serving_responses.py diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 4ded37595384..aa87cd22fe44 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -95,6 +95,10 @@ def test_openapi_stateless(case: schemathesis.Case): case.operation.method.upper(), case.operation.path, ) + if case.operation.path.startswith("/v1/responses"): + # Skip responses API as it is meant to be stateful. + return + timeout = { # requires a longer timeout ("POST", "/v1/chat/completions"): diff --git a/tests/v1/entrypoints/openai/responses/__init__.py b/tests/v1/entrypoints/openai/responses/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/v1/entrypoints/openai/responses/conftest.py b/tests/v1/entrypoints/openai/responses/conftest.py new file mode 100644 index 000000000000..2dcdda04ecb5 --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/conftest.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +# Use a small reasoning model to test the responses API. +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + "--max-model-len", + "8192", + "--enforce-eager", # For faster startup. 
+ "--reasoning-parser", + "deepseek_r1", + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py new file mode 100644 index 000000000000..974ea8673c44 --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import openai # use the official client for correctness check +import pytest + + +@pytest.mark.asyncio +async def test_simple_input(client: openai.AsyncOpenAI): + response = await client.responses.create(input="What is 13 * 24?") + print(response) + + outputs = response.output + # Whether the output contains the answer. + assert outputs[-1].type == "message" + assert "312" in outputs[-1].content[0].text + + # Whether the output contains the reasoning. + assert outputs[0].type == "reasoning" + assert outputs[0].text != "" + + +@pytest.mark.asyncio +async def test_instructions(client: openai.AsyncOpenAI): + response = await client.responses.create( + instructions="Finish the answer with QED.", + input="What is 13 * 24?", + ) + print(response) + + output_text = response.output[-1].content[0].text + assert "312" in output_text + assert "QED" in output_text + + +@pytest.mark.asyncio +async def test_chat(client: openai.AsyncOpenAI): + response = await client.responses.create(input=[ + { + "role": "system", + "content": "Finish the answer with QED." + }, + { + "role": "user", + "content": "What is 5 * 3?" + }, + { + "role": "assistant", + "content": "15. QED." + }, + { + "role": "user", + "content": "Multiply the result by 2." + }, + ], ) + print(response) + + output_text = response.output[-1].content[0].text + assert "30" in output_text + assert "QED" in output_text + + +@pytest.mark.asyncio +async def test_chat_with_input_type(client: openai.AsyncOpenAI): + response = await client.responses.create(input=[ + { + "role": "user", + "content": [{ + "type": "input_text", + "text": "Hello!" + }], + }, + ], ) + print(response) + assert response.status == "completed" diff --git a/tests/v1/entrypoints/openai/responses/test_stateful.py b/tests/v1/entrypoints/openai/responses/test_stateful.py new file mode 100644 index 000000000000..a2d581ef7ced --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_stateful.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio + +import openai +import pytest + + +@pytest.mark.asyncio +async def test_store(client: openai.AsyncOpenAI): + # By default, store is True. + response = await client.responses.create(input="Hello!") + assert response.status == "completed" + + # Retrieve the response. + response = await client.responses.retrieve(response.id) + assert response.status == "completed" + + # Test store=False. + response = await client.responses.create( + input="Hello!", + store=False, + ) + assert response.status == "completed" + + # The response should not be found. 
+ with pytest.raises(openai.NotFoundError, + match="Response with id .* not found."): + await client.responses.retrieve(response.id) + + +@pytest.mark.asyncio +async def test_background(client: openai.AsyncOpenAI): + # NOTE: This query should be easy enough for the model to answer + # within the 10 seconds. + response = await client.responses.create( + input="Hello!", + background=True, + ) + assert response.status == "queued" + + max_retries = 10 + for _ in range(max_retries): + await asyncio.sleep(1) + response = await client.responses.retrieve(response.id) + if response.status != "queued": + break + print(response) + + assert response.status == "completed" + + +@pytest.mark.asyncio +async def test_background_error(client: openai.AsyncOpenAI): + with pytest.raises( + openai.BadRequestError, + match="background can only be used when `store` is true"): + _ = await client.responses.create( + input="What is 13 * 24?", + background=True, + store=False, + ) + + +@pytest.mark.asyncio +async def test_background_cancel(client: openai.AsyncOpenAI): + response = await client.responses.create( + input="Write a long story about a cat.", + background=True, + ) + assert response.status == "queued" + + # Cancel the response before it is completed. + # FIXME: This test can be flaky. + await asyncio.sleep(0.5) + response = await client.responses.cancel(response.id) + assert response.status == "cancelled" + + # Make sure the response status remains unchanged. + await asyncio.sleep(5) + response = await client.responses.retrieve(response.id) + assert response.status == "cancelled" + + +@pytest.mark.asyncio +async def test_cancel_completed(client: openai.AsyncOpenAI): + response = await client.responses.create(input="Hello") + assert response.status == "completed" + + with pytest.raises(openai.BadRequestError, + match="Cannot cancel a synchronous response."): + await client.responses.cancel(response.id) + + +@pytest.mark.asyncio +async def test_previous_response_id(client: openai.AsyncOpenAI): + response1 = await client.responses.create( + instructions="You are tested on your ability to retrieve the correct " + "information from the previous response.", + input="Hello, my name is John.") + + response2 = await client.responses.create( + input="Actually, my name is not John. My real name is Mark.", + previous_response_id=response1.id, + ) + + response3 = await client.responses.create( + input="What is my real name again? Answer in one word.", + previous_response_id=response2.id, + ) + print(response3) + assert "Mark" in response3.output[-1].content[0].text + assert "John" not in response3.output[-1].content[0].text + + +@pytest.mark.asyncio +async def test_two_responses_with_same_prev_id(client: openai.AsyncOpenAI): + response1 = await client.responses.create( + instructions="You are tested on your ability to retrieve the correct " + "information from the previous response.", + input="Hello, my name is John.") + + # Both response 2 and 3 use response 1 as the previous response. + response2 = client.responses.create( + input="Actually, my name is not John. My name is Mark.", + previous_response_id=response1.id, + ) + response3 = client.responses.create( + input="What is my name again? 
Answer in one word.", + previous_response_id=response1.id, + ) + + _ = await response2 + response3_result = await response3 + print(response3_result) + assert "John" in response3_result.output[-1].content[0].text + assert "Mark" not in response3_result.output[-1].content[0].text diff --git a/tests/v1/entrypoints/openai/responses/test_structured_output.py b/tests/v1/entrypoints/openai/responses/test_structured_output.py new file mode 100644 index 000000000000..c4c43a87b601 --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_structured_output.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json + +import openai +import pytest +from pydantic import BaseModel + + +@pytest.mark.asyncio +async def test_structured_output(client: openai.AsyncOpenAI): + response = await client.responses.create( + input=[ + { + "role": "system", + "content": "Extract the event information." + }, + { + "role": "user", + "content": + "Alice and Bob are going to a science fair on Friday.", + }, + ], + text={ + "format": { + "type": "json_schema", + "name": "calendar_event", + "schema": { + "type": "object", + "properties": { + "event_name": { + "type": "string" + }, + "date": { + "type": "string" + }, + "participants": { + "type": "array", + "items": { + "type": "string" + } + }, + }, + "required": ["event_name", "date", "participants"], + "additionalProperties": False, + }, + "description": "A calendar event.", + "strict": True, + } + }, + ) + print(response) + + # NOTE: The JSON schema is applied to the output text, not reasoning. + output_text = response.output[-1].content[0].text + event = json.loads(output_text) + + assert event["event_name"].lower() == "science fair" + assert event["date"] == "Friday" + participants = event["participants"] + assert len(participants) == 2 + assert participants[0] == "Alice" + assert participants[1] == "Bob" + + +@pytest.mark.asyncio +async def test_structured_output_with_parse(client: openai.AsyncOpenAI): + + class CalendarEvent(BaseModel): + event_name: str + date: str + participants: list[str] + + response = await client.responses.parse( + model=None, + instructions="Extract the event information.", + input="Alice and Bob are going to a science fair on Friday.", + text_format=CalendarEvent, + ) + print(response) + + # The output is successfully parsed. + event = response.output_parsed + assert event is not None + + # The output is correct. 
+ assert event.event_name.lower() == "science fair" + assert event.date == "Friday" + participants = event.participants + assert len(participants) == 2 + assert participants[0] == "Alice" + assert participants[1] == "Bob" diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4b6c50526b10..012ea1d75f44 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -902,6 +902,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], ] = { "text": lambda part: _TextParser(part).get("text", None), + "input_text": + lambda part: _TextParser(part).get("text", None), "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None), "image_embeds": @@ -1040,7 +1042,7 @@ def _parse_chat_message_content_part( "with empty / unparsable content.", part, part_type) return None - if part_type in ("text", "refusal"): + if part_type in ("text", "input_text", "refusal"): str_content = cast(str, content) if wrap_dicts: return {'type': 'text', 'text': str_content} diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6c0a95ebb1ee..d3b1a3802bba 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -69,8 +69,9 @@ PoolingCompletionRequest, PoolingRequest, PoolingResponse, RerankRequest, RerankResponse, - ScoreRequest, ScoreResponse, - TokenizeRequest, + ResponsesRequest, + ResponsesResponse, ScoreRequest, + ScoreResponse, TokenizeRequest, TokenizeResponse, TranscriptionRequest, TranscriptionResponse, @@ -87,6 +88,7 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath, OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) @@ -368,6 +370,10 @@ def models(request: Request) -> OpenAIServingModels: return request.app.state.openai_serving_models +def responses(request: Request) -> Optional[OpenAIServingResponses]: + return request.app.state.openai_serving_responses + + def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat @@ -531,6 +537,71 @@ async def show_version(): return JSONResponse(content=ver) +@router.post("/v1/responses", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +async def create_responses(request: ResponsesRequest, raw_request: Request): + handler = responses(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Responses API") + + generator = await handler.create_responses(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ResponsesResponse): + return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.get("/v1/responses/{response_id}") +async def retrieve_responses(response_id: str, raw_request: Request): + 
handler = responses(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Responses API") + + response = await handler.retrieve_responses(response_id) + + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + return JSONResponse(content=response.model_dump()) + + +@router.post("/v1/responses/{response_id}/cancel") +async def cancel_responses(response_id: str, raw_request: Request): + handler = responses(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Responses API") + + response = await handler.cancel_responses(response_id) + + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + return JSONResponse(content=response.model_dump()) + + @router.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)], responses={ @@ -1272,6 +1343,22 @@ async def init_app_state( prompt_adapters=args.prompt_adapters, ) await state.openai_serving_models.init_static_loras() + state.openai_serving_responses = OpenAIServingResponses( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + expand_tools_even_if_tool_choice_none=args. + expand_tools_even_if_tool_choice_none, + tool_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) if model_config.runner_type == "generate" else None state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d4db238f456e..14b2253d1dba 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,6 +11,12 @@ import regex as re import torch from fastapi import HTTPException, UploadFile +from openai.types.responses import (ResponseInputParam, ResponseOutputItem, + ResponseOutputMessage, ResponsePrompt, + ResponseStatus, ResponseTextConfig) +from openai.types.responses.response import ToolChoice +from openai.types.responses.tool import Tool +from openai.types.shared import Metadata, Reasoning from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) from typing_extensions import TypeAlias @@ -220,6 +226,124 @@ def get_logits_processors(processors: Optional[LogitsProcessors], return None +class ResponsesRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/responses/create + background: Optional[bool] = False + include: Optional[list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ]] = None + input: Union[str, ResponseInputParam] + instructions: Optional[str] = None + max_output_tokens: Optional[int] = None + max_tool_calls: Optional[int] = None + metadata: Optional[Metadata] = None + model: Optional[str] = None + parallel_tool_calls: 
Optional[bool] = True + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", + "priority"] = "auto" + store: Optional[bool] = True + stream: Optional[bool] = False + temperature: Optional[float] = None + text: Optional[ResponseTextConfig] = None + tool_choice: ToolChoice = "auto" + tools: list[Tool] = Field(default_factory=list) + top_logprobs: Optional[int] = 0 + top_p: Optional[float] = None + truncation: Optional[Literal["auto", "disabled"]] = "disabled" + user: Optional[str] = None + + # --8<-- [start:responses-extra-params] + request_id: str = Field( + default_factory=lambda: f"resp_{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + # --8<-- [end:responses-extra-params] + + _DEFAULT_SAMPLING_PARAMS = { + "temperature": 1.0, + "top_p": 1.0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: + if self.max_output_tokens is None: + max_tokens = default_max_tokens + else: + max_tokens = min(self.max_output_tokens, default_max_tokens) + + default_sampling_params = default_sampling_params or {} + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + + # Structured output + guided_decoding = None + if self.text is not None and self.text.format is not None: + response_format = self.text.format + if response_format.type == "json_schema": + guided_decoding = GuidedDecodingParams.from_optional( + json=response_format.schema_) + elif response_format.type == "json_object": + raise NotImplementedError("json_object is not supported") + + # TODO: add more parameters + return SamplingParams.from_optional( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + logprobs=self.top_logprobs, + output_kind=(RequestOutputKind.DELTA + if self.stream else RequestOutputKind.FINAL_ONLY), + guided_decoding=guided_decoding, + ) + + @model_validator(mode="before") + def validate_background(cls, data): + if not data.get("background"): + return data + if not data.get("store", True): + raise ValueError( + "background can only be used when `store` is true") + return data + + @model_validator(mode="before") + def validate_prompt(cls, data): + if data.get("prompt") is not None: + raise ValueError("prompt template is not supported") + return data + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -1473,6 +1597,83 @@ class TranscriptionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +class ResponseReasoningItem(OpenAIBaseModel): + id: str = 
Field(default_factory=lambda: f"rs_{random_uuid()}") + text: str + summary: list = Field(default_factory=list) + type: Literal["reasoning"] = "reasoning" + encrypted_content: Optional[str] = None + status: Optional[Literal["in_progress", "completed", "incomplete"]] + + +class ResponsesResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") + created_at: int = Field(default_factory=lambda: int(time.time())) + # error: Optional[ResponseError] = None + # incomplete_details: Optional[IncompleteDetails] = None + instructions: Optional[str] = None + metadata: Optional[Metadata] = None + model: str + object: Literal["response"] = "response" + output: list[Union[ResponseOutputMessage, ResponseReasoningItem]] + parallel_tool_calls: bool + temperature: float + tool_choice: ToolChoice + tools: list[Tool] + top_p: float + background: bool + max_output_tokens: int + max_tool_calls: Optional[int] = None + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] + status: ResponseStatus + text: Optional[ResponseTextConfig] = None + top_logprobs: int + truncation: Literal["auto", "disabled"] + usage: Optional[UsageInfo] = None + user: Optional[str] = None + + @classmethod + def from_request( + cls, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + created_time: int, + output: list[ResponseOutputItem], + status: ResponseStatus, + usage: Optional[UsageInfo] = None, + ) -> "ResponsesResponse": + return cls( + id=request.request_id, + created_at=created_time, + instructions=request.instructions, + metadata=request.metadata, + model=model_name, + output=output, + parallel_tool_calls=request.parallel_tool_calls, + temperature=sampling_params.temperature, + tool_choice=request.tool_choice, + tools=request.tools, + top_p=sampling_params.top_p, + background=request.background, + max_output_tokens=sampling_params.max_tokens, + max_tool_calls=request.max_tool_calls, + previous_response_id=request.previous_response_id, + prompt=request.prompt, + reasoning=request.reasoning, + service_tier=request.service_tier, + status=status, + text=request.text, + top_logprobs=sampling_params.logprobs, + truncation=request.truncation, + user=request.user, + usage=usage, + ) + + BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest] diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index cf2b738ba55e..c4ebb7141d09 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -53,7 +53,8 @@ EmbeddingRequest, EmbeddingResponse, ErrorResponse, PoolingResponse, RerankRequest, - ScoreRequest, ScoreResponse, + ResponsesRequest, ScoreRequest, + ScoreResponse, TokenizeChatRequest, TokenizeCompletionRequest, TokenizeResponse, @@ -91,7 +92,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, + ResponsesRequest] AnyResponse = Union[ CompletionResponse, @@ -762,7 +764,7 @@ async def _preprocess_completion( async def _preprocess_chat( self, - request: ChatLikeRequest, + request: Union[ChatLikeRequest, ResponsesRequest], tokenizer: 
AnyTokenizer, messages: list[ChatCompletionMessageParam], chat_template: Optional[str], diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py new file mode 100644 index 000000000000..ac2b3dfafec3 --- /dev/null +++ b/vllm/entrypoints/openai/serving_responses.py @@ -0,0 +1,464 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from collections.abc import AsyncGenerator, AsyncIterator +from http import HTTPStatus +from typing import Callable, Final, Optional, Union + +import jinja2 +from fastapi import Request +from openai.types.responses import ResponseOutputMessage, ResponseOutputText + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption) +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ErrorResponse, + PromptTokenUsageInfo, + RequestResponseMetadata, + ResponseReasoningItem, + ResponsesRequest, + ResponsesResponse, UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +class OpenAIServingResponses(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + reasoning_parser: str = "", + enable_auto_tools: bool = False, + expand_tools_even_if_tool_choice_none: bool = False, + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + ) -> None: + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage, + ) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + self.reasoning_parser: Optional[Callable[[AnyTokenizer], + ReasoningParser]] = None + if reasoning_parser: + try: + self.reasoning_parser = ( + ReasoningParserManager.get_reasoning_parser( + reasoning_parser)) + assert self.reasoning_parser is not None + except Exception as e: + raise TypeError( + f"{reasoning_parser=} has not been registered") from e + + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_force_include_usage = enable_force_include_usage + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + source = self.model_config.generation_config + source = "model" if source == "auto" else source + logger.info("Using default chat sampling params from %s: %s", + source, self.default_sampling_params) + + # HACK(woosuk): This is a hack. 
We should use a better store. + # FIXME: This causes a memory leak since we never remove responses + # from the store. + self.response_store: dict[str, ResponsesResponse] = {} + self.response_store_lock = asyncio.Lock() + + # HACK(woosuk): This is a hack. We should use a better store. + # FIXME: This causes a memory leak since we never remove messages + # from the store. + self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {} + + self.background_tasks: dict[str, asyncio.Task] = {} + + async def create_responses( + self, + request: ResponsesRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Handle the previous response ID. + prev_response_id = request.previous_response_id + if prev_response_id is not None: + if not prev_response_id.startswith("resp_"): + return self._make_invalid_id_error(prev_response_id) + async with self.response_store_lock: + prev_response = self.response_store.get(prev_response_id) + if prev_response is None: + return self._make_not_found_error(prev_response_id) + else: + prev_response = None + # Construct the input messages. + messages = self._construct_input_messages(request, prev_response) + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + model_name = self._get_model_name(request.model, lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + _, request_prompts, engine_prompts = await self._preprocess_chat( + request, + tokenizer, + messages, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + except (ValueError, TypeError, RuntimeError, + jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(f"{e} {e.__cause__}") + + request_metadata = RequestResponseMetadata( + request_id=request.request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # Schedule the request and get the result generator. 
+ generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + self._log_inputs(request.request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request.request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + result_generator, = generators + + # Store the input messages. + if request.store: + self.msg_store[request.request_id] = messages + + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, + ) + async with self.response_store_lock: + self.response_store[response.id] = response + + # Run the request in the background. + task = asyncio.create_task( + self._run_background_request( + request, + sampling_params, + result_generator, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{response.id}", + ) + + # For cleanup. + response_id = response.id + self.background_tasks[response_id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response_id, None)) + return response + + if request.stream: + raise NotImplementedError("Streaming responses are not supported") + + try: + return await self.responses_full_generator( + request, + sampling_params, + result_generator, + model_name, + tokenizer, + request_metadata, + ) + except Exception as e: + return self.create_error_response(str(e)) + + async def responses_full_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[RequestOutput], + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> Union[ErrorResponse, ResponsesResponse]: + if created_time is None: + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + assert len(final_res.outputs) == 1 + final_output = final_res.outputs[0] + + if self.reasoning_parser: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content(final_output.text, + request=request)) + else: + reasoning_content = None + content = final_output.text + + output = [] + if reasoning_content: + reasoning_item = ResponseReasoningItem( + 
text=reasoning_content, + status=None, # NOTE: Only the last output item has status. + ) + output.append(reasoning_item) + if content: + output_text = ResponseOutputText( + text=content, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + message = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + output.append(message) + + # Calculate usage. + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = len(final_output.token_ids) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + request_metadata.final_usage_info = usage + + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=output, + status="completed", + usage=usage, + ) + + if request.store: + async with self.response_store_lock: + stored_response = self.response_store.get(response.id) + # If the response is already cancelled, don't update it. + if (stored_response is None + or stored_response.status != "cancelled"): + self.response_store[response.id] = response + return response + + def _construct_input_messages( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse] = None, + ) -> list[ChatCompletionMessageParam]: + messages: list[ChatCompletionMessageParam] = [] + if request.instructions: + messages.append({ + "role": "system", + "content": request.instructions, + }) + + # Prepend the conversation history. + if prev_response is not None: + # Add the previous messages. + prev_msg = self.msg_store[prev_response.id] + messages.extend(prev_msg) + + # Add the previous output. + for output_item in prev_response.output: + # NOTE: We skip the reasoning output. + if isinstance(output_item, ResponseOutputMessage): + for content in output_item.content: + messages.append({ + "role": "assistant", + "content": content.text, + }) + + # Append the new input. + # Reponses API supports simple text inputs without chat format. + if isinstance(request.input, str): + messages.append({"role": "user", "content": request.input}) + else: + messages.extend(request.input) # type: ignore + return messages + + async def _run_background_request( + self, + request: ResponsesRequest, + *args, + **kwargs, + ): + try: + response = await self.responses_full_generator( + request, *args, **kwargs) + except Exception as e: + logger.exception("Background request failed for %s", + request.request_id) + response = self.create_error_response(str(e)) + + if isinstance(response, ErrorResponse): + # If the request has failed, update the status to "failed". 
+ response_id = request.request_id + async with self.response_store_lock: + stored_response = self.response_store.get(response_id) + assert stored_response is not None + if stored_response.status not in ("completed", "cancelled"): + stored_response.status = "failed" + + async def retrieve_responses( + self, + response_id: str, + ) -> Union[ErrorResponse, ResponsesResponse]: + if not response_id.startswith("resp_"): + return self._make_invalid_id_error(response_id) + + async with self.response_store_lock: + response = self.response_store.get(response_id) + + if response is None: + return self._make_not_found_error(response_id) + return response + + async def cancel_responses( + self, + response_id: str, + ) -> Union[ErrorResponse, ResponsesResponse]: + if not response_id.startswith("resp_"): + return self._make_invalid_id_error(response_id) + + async with self.response_store_lock: + response = self.response_store.get(response_id) + if response is None: + return self._make_not_found_error(response_id) + + prev_status = response.status + if prev_status not in ("queued", "in_progress"): + return self.create_error_response( + err_type="invalid_request_error", + message="Cannot cancel a synchronous response.", + ) + + # Update the status to "cancelled". + response.status = "cancelled" + + # Abort the request. + if (task := self.background_tasks.get(response_id)): + task.cancel() + try: + await task + except asyncio.CancelledError: + logger.exception("Background task for %s was cancelled", + response_id) + return response + + def _make_invalid_id_error(self, response_id: str) -> ErrorResponse: + return self.create_error_response( + err_type="invalid_request_error", + message=(f"Invalid 'response_id': '{response_id}'. " + "Expected an ID that begins with 'resp'."), + ) + + def _make_not_found_error(self, response_id: str) -> ErrorResponse: + return self.create_error_response( + err_type="invalid_request_error", + message=f"Response with id '{response_id}' not found.", + status_code=HTTPStatus.NOT_FOUND, + ) diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index e827d381ca1d..c34189013d99 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -10,7 +10,7 @@ from typing import Callable, Optional, Union from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) + DeltaMessage, ResponsesRequest) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import import_from_path, is_list_of @@ -66,7 +66,9 @@ def extract_content_ids(self, input_ids: list[int]) -> list[int]: @abstractmethod def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, + model_output: str, + request: Union[ChatCompletionRequest, ResponsesRequest], ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from a complete model-generated string. 
From 47db8c2c15209ca03dc57422d98518ca0199e657 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:42:06 +0800 Subject: [PATCH 02/27] [Misc] add a tip for pre-commit (#20536) Signed-off-by: reidliu41 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d962252eb3dd..720c06acf144 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -170,7 +170,7 @@ repos: # Keep `suggestion` last - id: suggestion name: Suggestion - entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' + entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=."' language: system verbose: true pass_filenames: false From 6e2c19ce227ecf285ed24a138b91570b3a2d57a6 Mon Sep 17 00:00:00 2001 From: Yang Yang Date: Mon, 7 Jul 2025 12:32:32 +0800 Subject: [PATCH 03/27] [Refactor]Abstract Platform Interface for Distributed Backend and Add xccl Support for Intel XPU (#19410) Signed-off-by: dbyoung18 Signed-off-by: Kunshang Ji Co-authored-by: Kunshang Ji --- docs/getting_started/installation/gpu/xpu.inc.md | 5 +++++ vllm/platforms/__init__.py | 13 +++++++++++-- vllm/platforms/cpu.py | 1 + vllm/platforms/cuda.py | 1 + vllm/platforms/hpu.py | 1 + vllm/platforms/interface.py | 3 +++ vllm/platforms/neuron.py | 1 + vllm/platforms/rocm.py | 1 + vllm/platforms/tpu.py | 1 + vllm/platforms/xpu.py | 1 + vllm/utils/__init__.py | 6 ++++++ vllm/v1/worker/cpu_worker.py | 4 +++- vllm/v1/worker/gpu_worker.py | 3 ++- vllm/v1/worker/tpu_worker.py | 3 ++- vllm/worker/hpu_worker.py | 3 ++- vllm/worker/neuron_worker.py | 2 +- vllm/worker/worker.py | 3 ++- 17 files changed, 44 insertions(+), 8 deletions(-) diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md index 4469be36c007..1514a0c2d3cd 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -81,4 +81,9 @@ python -m vllm.entrypoints.openai.api_server \ By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the helper script. # --8<-- [end:supported-features] +# --8<-- [start:distributed-backend] + +XPU platform uses **torch-ccl** for torch<2.8 and **xccl** for torch>=2.8 as distributed backend, since torch 2.8 supports **xccl** as built-in backend for XPU. + +# --8<-- [end:distributed-backend] # --8<-- [end:extra-information] diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 13453d2c4b4b..7b8953fd75bb 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional from vllm.plugins import load_plugins_by_group -from vllm.utils import resolve_obj_by_qualname +from vllm.utils import resolve_obj_by_qualname, supports_xccl from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum @@ -139,10 +139,19 @@ def xpu_platform_plugin() -> Optional[str]: try: # installed IPEX if the machine has XPUs. 
import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 import torch + if supports_xccl(): + dist_backend = "xccl" + else: + dist_backend = "ccl" + import oneccl_bindings_for_pytorch # noqa: F401 + if hasattr(torch, 'xpu') and torch.xpu.is_available(): is_xpu = True + from vllm.platforms.xpu import XPUPlatform + XPUPlatform.dist_backend = dist_backend + logger.debug("Confirmed %s backend is available.", + XPUPlatform.dist_backend) logger.debug("Confirmed XPU platform is available.") except Exception as e: logger.debug("XPU platform is not available because: %s", str(e)) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 1050d3c59344..676a440a79db 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -37,6 +37,7 @@ class CpuPlatform(Platform): device_name: str = "cpu" device_type: str = "cpu" dispatch_key: str = "CPU" + dist_backend: str = "gloo" @property def supported_dtypes(self) -> list[torch.dtype]: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0a5f4004e448..50eedfa3c412 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -56,6 +56,7 @@ class CudaPlatformBase(Platform): device_type: str = "cuda" dispatch_key: str = "CUDA" ray_device_key: str = "GPU" + dist_backend: str = "nccl" device_control_env_var: str = "CUDA_VISIBLE_DEVICES" @property diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 3cf28950190c..0b1e2f232790 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -26,6 +26,7 @@ class HpuPlatform(Platform): device_type: str = "hpu" dispatch_key: str = "HPU" ray_device_key: str = "HPU" + dist_backend: str = "hccl" device_control_env_var: str = "HABANA_VISIBLE_MODULES" @classmethod diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 567d5cbf503f..b0ef9905481b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -129,6 +129,9 @@ class Platform: # compilation strategy. simple_compile_backend: str = "inductor" + # The backend used for distributed communication. 
+ dist_backend: str = "" + supported_quantization: list[str] = [] additional_env_vars: list[str] = [] diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 04e918d7aebe..cb8ac8db669f 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -30,6 +30,7 @@ class NeuronPlatform(Platform): device_type: str = "neuron" ray_device_key: str = "neuron_cores" supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"] + dist_backend: str = "gloo" device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" @classmethod diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4550ef570684..31f4699cd1b0 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -164,6 +164,7 @@ class RocmPlatform(Platform): device_type: str = "cuda" dispatch_key: str = "CUDA" ray_device_key: str = "GPU" + dist_backend: str = "nccl" # rocm shares the same device control env var as CUDA device_control_env_var: str = "CUDA_VISIBLE_DEVICES" diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index a8c8cb46de2c..6810944c848d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -31,6 +31,7 @@ class TpuPlatform(Platform): device_type: str = "tpu" dispatch_key: str = "XLA" ray_device_key: str = "TPU" + dist_backend: str = "gloo" device_control_env_var: str = "TPU_VISIBLE_CHIPS" simple_compile_backend: str = "openxla" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 5bd34033233a..de715fd894c3 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -29,6 +29,7 @@ class XPUPlatform(Platform): # Intel XPU's device key is "GPU" for Ray. # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501 ray_device_key: str = "GPU" + dist_backend: str = "ccl" # ccl | xccl device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR" @classmethod diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9550b056fbba..9322e3cc477a 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1886,6 +1886,12 @@ def supports_dynamo() -> bool: return base_torch_version >= Version("2.4.0") +# Supports xccl with PyTorch versions >= 2.8.0 for XPU platform +def supports_xccl() -> bool: + return is_torch_equal_or_newer( + "2.8.0") and torch.distributed.is_xccl_available() + + # Some backends use pytorch version < 2.4.0 which doesn't # support `torch.library.custom_op`. def supports_custom_op() -> bool: diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index de575d604055..7712b7974544 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -11,6 +11,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput @@ -58,7 +59,8 @@ def init_device(self): # Initialize the distributed environment. init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank, "gloo") + self.local_rank, + current_platform.dist_backend) # Set random seed. 
set_random_seed(self.model_config.seed) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 9e7e44d06861..d1df0fd959b5 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -157,7 +157,8 @@ def init_device(self): # Initialize the distributed environment. init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank) + self.local_rank, + current_platform.dist_backend) # Set random seed. set_random_seed(self.model_config.seed) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index a64ce881fe31..ade4d0821168 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -18,6 +18,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput @@ -300,7 +301,7 @@ def _init_tpu_worker_distributed_environment( rank=rank, local_rank=local_rank, distributed_init_method=distributed_init_method, - backend="gloo", + backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 6d76ea499a90..560110df0a32 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -23,6 +23,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest from vllm.utils import bind_kv_cache @@ -413,7 +414,7 @@ def init_worker_distributed_environment( rank, distributed_init_method, local_rank, - backend='hccl') + backend=current_platform.dist_backend) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 662bde6bc07b..4e1408300fb8 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -156,7 +156,7 @@ def init_distributed_environment(self): rank=self.rank, local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, - backend="gloo", + backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9a928632688a..21e684a3fb5a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -530,7 +530,8 @@ def init_worker_distributed_environment( set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank) + distributed_init_method, local_rank, + current_platform.dist_backend) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) From 2e610deb72dfc1e34b904d9b6b02c85eefa451d2 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 7 Jul 2025 13:10:41 +0800 Subject: [PATCH 04/27] [CI/Build] Enable phi2 lora test (#20540) Signed-off-by: Jee Jee Li --- tests/lora/test_phi.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 
9d75512a248b..3090941e6367 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - import vllm from vllm.lora.request import LoRARequest @@ -49,9 +47,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: return generated_texts -# Skipping for V1 for now as we are hitting, -# "Head size 80 is not supported by FlashAttention." error. -@pytest.mark.skip(reason="Head size 80 is not supported by FlashAttention") def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. From 2c5ebec064bf3684c8f02b70b5963615daa81b28 Mon Sep 17 00:00:00 2001 From: Liangliang Ma Date: Mon, 7 Jul 2025 16:16:40 +0800 Subject: [PATCH 05/27] [XPU][CI] add v1/core test in xpu hardware ci (#20537) Signed-off-by: Ma, Liangliang --- .buildkite/scripts/hardware_ci/run-xpu-test.sh | 6 ++++-- docker/Dockerfile.xpu | 2 +- vllm/platforms/xpu.py | 6 +----- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index cf3aaab8493b..a23abdc1ed6c 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -11,8 +11,8 @@ container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head docker build -t ${image_name} -f docker/Dockerfile.xpu . # Setup cleanup -remove_docker_container() { - docker rm -f "${container_name}" || true; +remove_docker_container() { + docker rm -f "${container_name}" || true; docker image rm -f "${image_name}" || true; docker system prune -f || true; } @@ -27,4 +27,6 @@ docker run \ "${image_name}" \ sh -c ' VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + cd tests + pytest -v -s v1/core ' diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 466ba9833363..41b4c42e4c4b 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' + pip install accelerate hf_transfer pytest 'modelscope!=1.15.0' ENV VLLM_USAGE_SOURCE production-docker-image \ TRITON_XPU_PROFILE 1 diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index de715fd894c3..39828d321ede 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -93,10 +93,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "mode.") model_config.enforce_eager = True - if vllm_config.speculative_config is not None: - raise NotImplementedError( - "XPU does not support speculative decoding") - if vllm_config.device_config is not None: assert vllm_config.device_config.device_type == "xpu" @@ -181,4 +177,4 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: @classmethod def device_count(cls) -> int: - return torch.xpu.device_count() \ No newline at end of file + return torch.xpu.device_count() From 1fd471e957526a34a0cb4b60d2e830cd6ca79fdc Mon Sep 17 00:00:00 2001 From: Michael Yao Date: Mon, 7 Jul 2025 16:31:49 +0800 Subject: [PATCH 06/27] Add docstrings to url_schemes.py to improve readability (#20545) Signed-off-by: windsonsea --- docs/mkdocs/hooks/url_schemes.py | 70 
+++++++++++++++++++++++++++++++- 1 file changed, 69 insertions(+), 1 deletion(-) diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py index 6484581ed947..6fce6bd8130e 100644 --- a/docs/mkdocs/hooks/url_schemes.py +++ b/docs/mkdocs/hooks/url_schemes.py @@ -1,5 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This is basically a port of MyST parser’s external URL resolution mechanism +(https://myst-parser.readthedocs.io/en/latest/syntax/cross-referencing.html#customising-external-url-resolution) +to work with MkDocs. + +It allows Markdown authors to use GitHub shorthand links like: + + - [Text](gh-issue:123) + - + - [File](gh-file:path/to/file.py#L10) + +These are automatically rewritten into fully qualified GitHub URLs pointing to +issues, pull requests, files, directories, or projects in the +`vllm-project/vllm` repository. + +The goal is to simplify cross-referencing common GitHub resources +in project docs. +""" + import regex as re from mkdocs.config.defaults import MkDocsConfig from mkdocs.structure.files import Files @@ -7,11 +26,42 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, - files: Files): + files: Files) -> str: + """ + Custom MkDocs plugin hook to rewrite special GitHub reference links + in Markdown. + + This function scans the given Markdown content for specially formatted + GitHub shorthand links, such as: + - `[Link text](gh-issue:123)` + - `` + + And rewrites them into fully-qualified GitHub URLs with GitHub icons: + - `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)` + - `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)` + + Supported shorthand types: + - `gh-issue` + - `gh-pr` + - `gh-project` + - `gh-dir` + - `gh-file` + + Args: + markdown (str): The raw Markdown content of the page. + page (Page): The MkDocs page object being processed. + config (MkDocsConfig): The MkDocs site configuration. + files (Files): The collection of files in the MkDocs build. + + Returns: + str: The updated Markdown content with GitHub shorthand links replaced. + """ gh_icon = ":octicons-mark-github-16:" gh_url = "https://github.com" repo_url = f"{gh_url}/vllm-project/vllm" org_url = f"{gh_url}/orgs/vllm-project" + + # Mapping of shorthand types to their corresponding GitHub base URLs urls = { "issue": f"{repo_url}/issues", "pr": f"{repo_url}/pull", @@ -19,6 +69,8 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, "dir": f"{repo_url}/tree/main", "file": f"{repo_url}/blob/main", } + + # Default title prefixes for auto links titles = { "issue": "Issue #", "pr": "Pull Request #", @@ -27,11 +79,19 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, "file": "", } + # Regular expression to match GitHub shorthand links scheme = r"gh-(?P.+?):(?P.+?)(#(?P.+?))?" inline_link = re.compile(r"\[(?P[^\[]+?)\]\(" + scheme + r"\)") auto_link = re.compile(f"<{scheme}>") def replace_inline_link(match: re.Match) -> str: + """ + Replaces a matched inline-style GitHub shorthand link + with a full Markdown link. 
+ + Example: + [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) + """ url = f'{urls[match.group("type")]}/{match.group("path")}' if fragment := match.group("fragment"): url += f"#{fragment}" @@ -39,6 +99,13 @@ def replace_inline_link(match: re.Match) -> str: return f'[{gh_icon} {match.group("title")}]({url})' def replace_auto_link(match: re.Match) -> str: + """ + Replaces a matched autolink-style GitHub shorthand + with a full Markdown link. + + Example: + <gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456) + """ type = match.group("type") path = match.group("path") title = f"{titles[type]}{path}" @@ -48,6 +115,7 @@ def replace_auto_link(match: re.Match) -> str: return f"[{gh_icon} {title}]({url})" + # Replace both inline and autolinks markdown = inline_link.sub(replace_inline_link, markdown) markdown = auto_link.sub(replace_auto_link, markdown) From 3112271f6e5d50b3d94a2efa88a5a8e77826b897 Mon Sep 17 00:00:00 2001 From: Yan Ma <yan.ma@intel.com> Date: Mon, 7 Jul 2025 16:38:22 +0800 Subject: [PATCH 07/27] [XPU] log clean up for XPU platform (#20553) Signed-off-by: yan <yan.ma@intel.com> --- vllm/_custom_ops.py | 3 ++- vllm/platforms/xpu.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eb9d0b405892..92db27f5b8dc 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -13,7 +13,8 @@ logger = init_logger(__name__) -if not current_platform.is_tpu() and not current_platform.is_hpu(): +if not current_platform.is_tpu() and not current_platform.is_hpu()\ + and not current_platform.is_xpu(): try: import vllm._C except ImportError as e: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 39828d321ede..e2871c106492 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -37,7 +37,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool) -> str: - if selected_backend != _Backend.IPEX: + if selected_backend is not None and selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) use_v1 = envs.VLLM_USE_V1 if not use_v1: @@ -133,8 +133,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: @classmethod def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on XPU.") - return False + return True @classmethod def get_current_memory_usage(cls, From eb0b2d2f08b622f4b93fb0a811a047ad987a46ca Mon Sep 17 00:00:00 2001 From: Michael Yao <haifeng.yao@daocloud.io> Date: Mon, 7 Jul 2025 16:46:31 +0800 Subject: [PATCH 08/27] [Docs] Clean up tables in supported_models.md (#20552) Signed-off-by: windsonsea <haifeng.yao@daocloud.io> --- docs/models/supported_models.md | 320 ++++++++++++++++---------------- 1 file changed, 160 insertions(+), 160 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7ec91df98b28..422c406d5f31 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -314,85 +314,85 @@ See [this page][generative-models] for more information on how to use generative Specified using `--task generate`. 
-| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|---------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | -| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | -| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | -| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | -| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅︎ | ✅︎ | -| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst` etc. | | ✅︎ | ✅︎ | -| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`,etc. | | ✅︎ | ✅︎ | -| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | | ✅︎ | ✅︎ | -| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | -| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | | -| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | -| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. 
| | ✅︎ | ✅︎ | -| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | -| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | -| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | | -| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`etc. | | | ✅︎ | -| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | -| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | -| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | -| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | -| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | -| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | -| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | -| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | -| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | -| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | -| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | -| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | ✅︎ | -| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ | ✅︎ | -| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | -| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | -| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`etc. | | | | -| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. 
| | ✅︎ | ✅︎ | +| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | +| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | +| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | +| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ | +| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | +| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | | ✅︎ | ✅︎ | +| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | | ✅︎ | ✅︎ | +| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | +| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | | +| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | +| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | +| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | +| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | +| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | | +| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | +| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | | | ✅︎ | +| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | +| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | +| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | +| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | +| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | +| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | +| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | +| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ | +| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | +| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | ✅︎ | +| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ | ✅︎ | +| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | +| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | +| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. @@ -412,19 +412,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling Specified using `--task embed`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|-----------------------| -| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | -| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | | | -| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ | | -| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. 
| ︎ | ︎ | | -| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ | | -| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | +| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | +| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | +| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | +| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | +| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | !!! note `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. @@ -448,12 +448,12 @@ of the whole prompt are extracted from the normalized hidden state corresponding Specified using `--task reward`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|---------------------------|-----------------|------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | If your model is not in the above list, we will try to automatically convert the model using [as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. @@ -466,10 +466,10 @@ If your model is not in the above list, we will try to automatically convert the Specified using `--task classify`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|----------------------------------|----------|----------------------------------------|------------------------|-----------------------------|-----------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | -| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | +| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | If your model is not in the above list, we will try to automatically convert the model using [as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. @@ -478,13 +478,13 @@ If your model is not in the above list, we will try to automatically convert the Specified using `--task score`. -| Architecture | Models | Example HF Models | [V1](gh-issue:8779) | -|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | -| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | -| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | +| Architecture | Models | Example HF Models | [V1](gh-issue:8779) | +|--------------|--------|-------------------|---------------------| +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | +| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | +| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | !!! 
note Load the official original `mxbai-rerank-v2` by using the following command. @@ -555,50 +555,50 @@ See [this page][generative-models] for more information on how to use generative Specified using `--task generate`. -| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|----------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ | -| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | -| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | -| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | | ✅︎ | ✅︎ | -| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | | ✅︎ | ✅︎ | -| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | | -| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ | -| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | -| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinkg`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | -| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | -| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | -| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | -| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. 
| | ✅︎ | ✅︎ | -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | -| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | -| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | -| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | -| `Mistral3ForConditionalGeneration` | Mistral3 | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | -| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | -| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | -| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | -| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | -| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | -| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | -| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* | -| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | -| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | -| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | -| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| +| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | +| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | +| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | +| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | +| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | +| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinkg`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | +| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | +| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | +| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. 
| | ✅︎ | ✅︎ | +| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | | ✅︎ | ✅︎ | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | +| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | +| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | +| `Mistral3ForConditionalGeneration` | Mistral3 | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | +| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | +| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | +| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | +| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | +| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | +| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* | +| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | +| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | +| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | +| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | <sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: @@ -677,9 +677,9 @@ Specified using `--task transcription`. Speech2Text models trained specifically for Automatic Speech Recognition. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|----------------------------------------------|------------------|------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | ### Pooling Models @@ -700,10 +700,10 @@ Any text generation model can be converted into an embedding model by passing `- The following table lists those that are tested in vLLM. 
-| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------| -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | -| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | | +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | +| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | | --- From 448acad31eae34508a7e0ed0877b95dad8df8bb9 Mon Sep 17 00:00:00 2001 From: Abirdcfly <fp544037857@gmail.com> Date: Mon, 7 Jul 2025 17:14:12 +0800 Subject: [PATCH 09/27] [Misc] remove unused jinaai_serving_reranking (#18878) Signed-off-by: Abirdcfly <fp544037857@gmail.com> --- vllm/entrypoints/openai/api_server.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d3b1a3802bba..e3285a9bf76d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1409,11 +1409,6 @@ async def init_app_state( enable_serving_reranking = (model_config.task == "classify" and getattr( model_config.hf_config, "num_labels", 0) == 1) - state.jinaai_serving_reranking = ServingScores( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger) if enable_serving_reranking else None state.openai_serving_scores = ServingScores( engine_client, model_config, From 4ff79a136ec466684e74502057acba578cfe947c Mon Sep 17 00:00:00 2001 From: Jee Jee Li <pandaleefree@gmail.com> Date: Mon, 7 Jul 2025 17:15:26 +0800 Subject: [PATCH 10/27] [Misc] Set the minimum openai version (#20539) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> --- requirements/common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/common.txt b/requirements/common.txt index 8bc0be7779af..90946df00d5d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -13,7 +13,7 @@ tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. 
aiohttp -openai >= 1.52.0, <= 1.90.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) +openai >= 1.87.0, <= 1.90.0 # Ensure modern openai package (ensure ResponsePrompt exists in type.responses and max_completion_tokens field support) pydantic >= 2.10 prometheus_client >= 0.18.0 pillow # Required for image processing From 6e4bef1bea89c06100699bad4d4ad27ef0519e7f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 7 Jul 2025 11:35:47 +0100 Subject: [PATCH 11/27] [Doc] Remove extra whitespace from CI failures doc (#20565) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/contributing/ci-failures.md | 40 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/contributing/ci-failures.md b/docs/contributing/ci-failures.md index 7caaf10ceb5c..573efb3b05f6 100644 --- a/docs/contributing/ci-failures.md +++ b/docs/contributing/ci-failures.md @@ -6,9 +6,9 @@ the failure? - Check the dashboard of current CI test failures: 👉 [CI Failures Dashboard](https://github.com/orgs/vllm-project/projects/20) -- If your failure **is already listed**, it's likely unrelated to your PR. - Help fixing it is always welcome! - - Leave comments with links to additional instances of the failure. +- If your failure **is already listed**, it's likely unrelated to your PR. + Help fixing it is always welcome! + - Leave comments with links to additional instances of the failure. - React with a 👍 to signal how many are affected. - If your failure **is not listed**, you should **file an issue**. @@ -19,25 +19,25 @@ the failure? 👉 [New CI Failure Report](https://github.com/vllm-project/vllm/issues/new?template=450-ci-failure.yml) - **Use this title format:** - + ``` [CI Failure]: failing-test-job - regex/matching/failing:test ``` - **For the environment field:** - + ``` Still failing on main as of commit abcdef123 ``` - **In the description, include failing tests:** - + ``` - FAILED failing/test.py:failing_test1 - Failure description - FAILED failing/test.py:failing_test2 - Failure description - https://github.com/orgs/vllm-project/projects/20 - https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml - FAILED failing/test.py:failing_test3 - Failure description + FAILED failing/test.py:failing_test1 - Failure description + FAILED failing/test.py:failing_test2 - Failure description + https://github.com/orgs/vllm-project/projects/20 + https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml + FAILED failing/test.py:failing_test3 - Failure description ``` - **Attach logs** (collapsible section example): @@ -45,17 +45,17 @@ the failure? <summary>Logs:</summary> ```text - ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data + ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data --- Logging error --- Traceback (most recent call last): File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 203, in execute_model - return self.model_executor.execute_model(scheduler_output) + return self.model_executor.execute_model(scheduler_output) ... 
- FAILED failing/test.py:failing_test1 - Failure description - FAILED failing/test.py:failing_test2 - Failure description - FAILED failing/test.py:failing_test3 - Failure description + FAILED failing/test.py:failing_test1 - Failure description + FAILED failing/test.py:failing_test2 - Failure description + FAILED failing/test.py:failing_test3 - Failure description ``` - + </details> ## Logs Wrangling @@ -78,7 +78,7 @@ tail -525 ci_build.log | wl-copy ## Investigating a CI Test Failure -1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main) +1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main) 2. Bisect to find the first build that shows the issue. 3. Add your findings to the GitHub issue. 4. If you find a strong candidate PR, mention it in the issue and ping contributors. @@ -97,9 +97,9 @@ CI test failures may be flaky. Use a bash loop to run repeatedly: If you submit a PR to fix a CI failure: -- Link the PR to the issue: +- Link the PR to the issue: Add `Closes #12345` to the PR description. -- Add the `ci-failure` label: +- Add the `ci-failure` label: This helps track it in the [CI Failures GitHub Project](https://github.com/orgs/vllm-project/projects/20). ## Other Resources From 45877ef740e00cbb2dbe9fd7edc84638adc13037 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 7 Jul 2025 11:54:22 +0100 Subject: [PATCH 12/27] [Doc] Use `gh-pr` and `gh-issue` everywhere we can in the docs (#20564) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/ci/update_pytorch_version.md | 12 +++++------- docs/features/spec_decode.md | 6 +++--- docs/usage/troubleshooting.md | 4 ++-- docs/usage/v1_guide.md | 24 ++++++++++++------------ 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/docs/ci/update_pytorch_version.md b/docs/ci/update_pytorch_version.md index 69fdc82ef971..eb8f19455791 100644 --- a/docs/ci/update_pytorch_version.md +++ b/docs/ci/update_pytorch_version.md @@ -7,9 +7,8 @@ release in CI/CD. It is standard practice to submit a PR to update the PyTorch version as early as possible when a new [PyTorch stable release](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-cadence) becomes available. This process is non-trivial due to the gap between PyTorch -releases. Using [#16859](https://github.com/vllm-project/vllm/pull/16859) as -an example, this document outlines common steps to achieve this update along with -a list of potential issues and how to address them. +releases. Using <gh-pr:16859> as an example, this document outlines common steps to achieve this +update along with a list of potential issues and how to address them. ## Test PyTorch release candidates (RCs) @@ -68,7 +67,7 @@ and timeout. Additionally, since vLLM's fastcheck pipeline runs in read-only mod it doesn't populate the cache, so re-running it to warm up the cache is ineffective. -While ongoing efforts like [#17419](https://github.com/vllm-project/vllm/issues/17419) +While ongoing efforts like [#17419](gh-issue:17419) address the long build time at its source, the current workaround is to set VLLM_CI_BRANCH to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) when manually triggering a build on Buildkite. This branch accomplishes two things: @@ -129,6 +128,5 @@ to handle some platforms separately. The separation of requirements and Dockerfi for different platforms in vLLM CI/CD allows us to selectively choose which platforms to update. 
For instance, updating XPU requires the corresponding release from https://github.com/intel/intel-extension-for-pytorch by Intel. -While https://github.com/vllm-project/vllm/pull/16859 updated vLLM to PyTorch -2.7.0 on CPU, CUDA, and ROCm, https://github.com/vllm-project/vllm/pull/17444 -completed the update for XPU. +While <gh-pr:16859> updated vLLM to PyTorch 2.7.0 on CPU, CUDA, and ROCm, +<gh-pr:17444> completed the update for XPU. diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index abda7db53f91..f28a74ce2262 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -217,8 +217,8 @@ an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https A few important things to consider when using the EAGLE based draft models: 1. The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should - be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). - If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the + be able to be loaded and used directly by vLLM after <gh-pr:12304>. + If you are using vllm version before <gh-pr:12304>, please use the [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. @@ -228,7 +228,7 @@ A few important things to consider when using the EAGLE based draft models: 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under - investigation and tracked here: [https://github.com/vllm-project/vllm/issues/9565](https://github.com/vllm-project/vllm/issues/9565). + investigation and tracked here: <gh-issue:9565>. A variety of EAGLE draft models are available on the Hugging Face hub: diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 7f1f76ce3d2e..2b7abc7f46df 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -212,7 +212,7 @@ if __name__ == '__main__': ## `torch.compile` Error -vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](https://github.com/vllm-project/vllm/pull/10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script: +vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](gh-pr:10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script: ??? Code @@ -231,7 +231,7 @@ vLLM heavily depends on `torch.compile` to optimize the model for better perform print(f(x)) ``` -If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. 
See [this issue](https://github.com/vllm-project/vllm/issues/12219) for example. +If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See <gh-issue:12219> for example. ## Model failed to be inspected diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 82a2710d895c..f2a7679f5c51 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -2,7 +2,7 @@ !!! announcement - We have started the process of deprecating V0. Please read [RFC #18571](https://github.com/vllm-project/vllm/issues/18571) for more details. + We have started the process of deprecating V0. Please read [RFC #18571](gh-issue:18571) for more details. V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack). @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> | | **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> | | **Embedding Models** | <nobr>🟢 Functional</nobr> | -| **Mamba Models** | <nobr>🚧 WIP ([PR #19327](https://github.com/vllm-project/vllm/pull/19327))</nobr> | +| **Mamba Models** | <nobr>🚧 WIP (<gh-pr:19327>)</nobr> | | **Multimodal Models** | <nobr>🟢 Functional</nobr> | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. @@ -98,14 +98,14 @@ See below for the status of models that are not yet supported or have more featu The initial basic support is now functional. -Later, we will consider using [hidden states processor](https://github.com/vllm-project/vllm/issues/12249), -which is based on [global logits processor](https://github.com/vllm-project/vllm/pull/13360) +Later, we will consider using [hidden states processor](gh-issue:12249), +which is based on [global logits processor](gh-pr:13360) to enable simultaneous generation and embedding using the same engine instance in V1. #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention (e.g., `MambaForCausalLM`, `JambaForCausalLM`) -will be supported via [PR #19327](https://github.com/vllm-project/vllm/pull/19327). +will be supported via <gh-pr:19327>. #### Encoder-Decoder Models @@ -120,13 +120,13 @@ are not yet supported. 
| **Chunked Prefill** | <nobr>🚀 Optimized</nobr> | | **LoRA** | <nobr>🚀 Optimized</nobr> | | **Logprobs Calculation** | <nobr>🟢 Functional</nobr> | -| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices ([PR #15191](https://github.com/vllm-project/vllm/pull/15191))</nobr>| +| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices (<gh-pr:15191>)</nobr>| | **Spec Decode** | <nobr>🚀 Optimized</nobr> | -| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](https://github.com/vllm-project/vllm/issues/13414))</nobr>| +| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](gh-issue:13414))</nobr>| | **Structured Output Alternative Backends** | <nobr>🟢 Functional</nobr> | | **Request-level Structured Output Backend** | <nobr>🔴 Deprecated</nobr> | -| **best_of** | <nobr>🔴 Deprecated ([RFC #13361](https://github.com/vllm-project/vllm/issues/13361))</nobr>| -| **Per-Request Logits Processors** | <nobr>🔴 Deprecated ([RFC #13360](https://github.com/vllm-project/vllm/pull/13360))</nobr> | +| **best_of** | <nobr>🔴 Deprecated ([RFC #13361](gh-issue:13361))</nobr>| +| **Per-Request Logits Processors** | <nobr>🔴 Deprecated ([RFC #13360](gh-pr:13360))</nobr> | | **GPU <> CPU KV Cache Swapping** | <nobr>🔴 Deprecated</nobr> | !!! note @@ -153,7 +153,7 @@ Support for logprobs with post-sampling adjustments is in progress and will be a **Prompt Logprobs with Prefix Caching** -Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](https://github.com/vllm-project/vllm/issues/13414). +Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414). #### Deprecated Features @@ -161,11 +161,11 @@ As part of the major architectural rework in vLLM V1, several legacy features ha **Sampling features** -- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361). +- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](gh-issue:13361). - **Per-Request Logits Processors**: In V0, users could pass custom processing functions to adjust logits on a per-request basis. In vLLM V1, this feature has been deprecated. Instead, the design is moving toward supporting **global logits - processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](https://github.com/vllm-project/vllm/pull/13360). + processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](gh-pr:13360). 
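For reference, the per-request pattern referred to in the bullet above looked roughly like the following. This is a minimal sketch, assuming V0's `SamplingParams(logits_processors=...)` hook where each processor is a callable from the already-generated token ids and the current logits tensor to an adjusted logits tensor; it is not the planned global-processor API, which is still being designed.

```python
import torch

def ban_token_zero_after_ten_steps(generated_token_ids: list[int],
                                   logits: torch.Tensor) -> torch.Tensor:
    # Example per-request policy: once ten tokens have been generated,
    # forbid token id 0 by pushing its logit to -inf.
    if len(generated_token_ids) >= 10:
        logits[0] = float("-inf")
    return logits

# V0-style usage (the part being deprecated); assumes the V0 field name:
# sampling_params = SamplingParams(logits_processors=[ban_token_zero_after_ten_steps])
```

A global logits processor is instead expected to be registered once per engine and applied across requests, rather than being attached to each `SamplingParams` object.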
**KV Cache features** From 923147b5e8551887fd64a0fc242c361d5216e1d7 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 7 Jul 2025 12:15:50 +0100 Subject: [PATCH 13/27] [Doc] Fix internal links so they don't always point to latest (#20563) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/features/structured_outputs.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 614b0bfe9679..ea1d09644835 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -157,7 +157,7 @@ As an example, we can use to define a specific format of simplified SQL queries: print(completion.choices[0].message.content) ``` -See also: [full example](https://docs.vllm.ai/en/latest/examples/online_serving/structured_outputs.html) +See also: [full example](../examples/online_serving/structured_outputs.md) ## Reasoning Outputs @@ -200,7 +200,7 @@ Note that you can use reasoning with any provided structured outputs feature. Th print("content: ", completion.choices[0].message.content) ``` -See also: [full example](https://docs.vllm.ai/en/latest/examples/online_serving/structured_outputs.html) +See also: [full example](../examples/online_serving/structured_outputs.md) ## Experimental Automatic Parsing (OpenAI API) @@ -325,4 +325,4 @@ shown below: print(outputs[0].outputs[0].text) ``` -See also: [full example](https://docs.vllm.ai/en/latest/examples/online_serving/structured_outputs.html) +See also: [full example](../examples/online_serving/structured_outputs.md) From b8a498c9b2f3563666e830bf2ad7b9a888c184ed Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:43:26 +0100 Subject: [PATCH 14/27] [Doc] Add outline for content tabs (#20571) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/mkdocs/stylesheets/extra.css | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css index 892013c1cddf..5df9f1344012 100644 --- a/docs/mkdocs/stylesheets/extra.css +++ b/docs/mkdocs/stylesheets/extra.css @@ -143,3 +143,13 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . 
[data-md-color-scheme="slate"] .logo-light { display: none; } + +/* Outline for content tabs */ +.md-typeset .tabbed-set { + border: 0.075rem solid var(--md-default-fg-color); + border-radius: 0.2rem; +} + +.md-typeset .tabbed-content { + padding: 0 0.6em; +} \ No newline at end of file From 1ad69e8375e841095c2f682299be487fd9b8f47e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 7 Jul 2025 15:44:34 +0100 Subject: [PATCH 15/27] [Doc] Fix some MkDocs snippets used in the installation docs (#20572) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/getting_started/installation/cpu/apple.inc.md | 3 --- docs/getting_started/installation/cpu/arm.inc.md | 3 --- docs/getting_started/installation/cpu/s390x.inc.md | 3 --- docs/getting_started/installation/cpu/x86.inc.md | 3 --- docs/getting_started/installation/gpu.md | 4 ++-- docs/getting_started/installation/gpu/cuda.inc.md | 4 ---- docs/getting_started/installation/gpu/rocm.inc.md | 10 ++++++---- docs/getting_started/installation/gpu/xpu.inc.md | 6 ++---- 8 files changed, 10 insertions(+), 26 deletions(-) diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 1771213f5591..e17823b864ce 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -54,9 +54,6 @@ If the build has error like the following snippet where standard C++ headers can ``` # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] # --8<-- [end:pre-built-images] diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md index 6c05900cf45c..18112243c68f 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -28,9 +28,6 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. Testing has been conducted on AWS Graviton3 instances for compatibility. # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] # --8<-- [end:pre-built-images] diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md index 6c6c40baecec..67b96a8a04fa 100644 --- a/docs/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -56,9 +56,6 @@ Execute the following commands to build and install vLLM from the source. ``` # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] # --8<-- [end:pre-built-images] diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 0412d4ccef00..dc007dcff217 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -31,9 +31,6 @@ vLLM initially supports basic model inferencing and serving on x86 CPU platform, - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable `VLLM_CPU_AVX512BF16=1` before the building. 
# --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] See [https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo) diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md index 1be7557b79e5..e688cefea076 100644 --- a/docs/getting_started/installation/gpu.md +++ b/docs/getting_started/installation/gpu.md @@ -46,11 +46,11 @@ vLLM is a Python library that supports the following GPU variants. Select your G === "AMD ROCm" - There is no extra information on creating a new Python environment for this device. + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:set-up-using-python" === "Intel XPU" - There is no extra information on creating a new Python environment for this device. + --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:set-up-using-python" ### Pre-built wheels diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 0417a25f85ad..5ca5296d0a65 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -232,9 +232,6 @@ pip install -e . ``` # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for instructions on using the official Docker image. @@ -261,4 +258,3 @@ See [deployment-docker-build-image-from-source][deployment-docker-build-image-fr See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. # --8<-- [end:supported-features] -# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index aa4cacaf1aed..3765807ba21d 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -2,6 +2,9 @@ vLLM supports AMD GPUs with ROCm 6.3. +!!! tip + [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm. + !!! warning There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. @@ -14,6 +17,8 @@ vLLM supports AMD GPUs with ROCm 6.3. # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] +There is no extra information on creating a new Python environment for this device. + # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] @@ -123,9 +128,7 @@ Currently, there are no pre-built ROCm wheels. - For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level. For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). 
-## Set up using Docker (Recommended) - -# --8<-- [end:set-up-using-docker] +# --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] The [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized @@ -227,4 +230,3 @@ Where the `<path/to/model>` is the location where the model is stored, for examp See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. # --8<-- [end:supported-features] -# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md index 1514a0c2d3cd..b77c4e00cf0c 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -14,6 +14,8 @@ vLLM initially supports basic model inference and serving on Intel GPU platform. # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] +There is no extra information on creating a new Python environment for this device. + # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] @@ -43,9 +45,6 @@ VLLM_TARGET_DEVICE=xpu python setup.py install type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] Currently, there are no pre-built XPU images. @@ -86,4 +85,3 @@ By default, a ray instance will be launched automatically if no existing one is XPU platform uses **torch-ccl** for torch<2.8 and **xccl** for torch>=2.8 as distributed backend, since torch 2.8 supports **xccl** as built-in backend for XPU. # --8<-- [end:distributed-backend] -# --8<-- [end:extra-information] From 110df74332785ee749af47c5a3eb634d216b8f3b Mon Sep 17 00:00:00 2001 From: "wang.yuqi" <noooop@126.com> Date: Mon, 7 Jul 2025 22:46:04 +0800 Subject: [PATCH 16/27] [Model][Last/4] Automatic conversion of CrossEncoding model (#19675) Signed-off-by: wang.yuqi <noooop@126.com> --- docs/models/supported_models.md | 8 + .../convert_model_to_seq_cls.py | 134 +++++++++++++++++ tests/models/language/pooling/mteb_utils.py | 5 +- .../pooling/test_bge_reranker_v2_gemma.py | 140 ++++++++++++++++++ .../language/pooling/test_mxbai_rerank.py | 2 - tests/models/registry.py | 7 +- vllm/config.py | 6 + vllm/entrypoints/llm.py | 12 +- vllm/entrypoints/openai/serving_score.py | 18 ++- vllm/model_executor/models/adapters.py | 48 ++++++ vllm/model_executor/models/gemma.py | 4 + vllm/model_executor/models/registry.py | 3 +- 12 files changed, 373 insertions(+), 14 deletions(-) create mode 100644 examples/offline_inference/convert_model_to_seq_cls.py create mode 100644 tests/models/language/pooling/test_bge_reranker_v2_gemma.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 422c406d5f31..f427968c8258 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -481,11 +481,19 @@ Specified using `--task score`. | Architecture | Models | Example HF Models | [V1](gh-issue:8779) | |--------------|--------|-------------------|---------------------| | `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | +| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | | | `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. 
| ✅︎ | | `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | | `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | +!!! note + Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command. + + ```bash + vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' + ``` + !!! note Load the official original `mxbai-rerank-v2` by using the following command. diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/convert_model_to_seq_cls.py new file mode 100644 index 000000000000..72356020330f --- /dev/null +++ b/examples/offline_inference/convert_model_to_seq_cls.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import argparse +import json + +import torch +import transformers + +# Usage: +# for BAAI/bge-reranker-v2-gemma +# Caution: "Yes" and "yes" are two different tokens +# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls +# for mxbai-rerank-v2 +# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls +# for Qwen3-Reranker +# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls + + +def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device): + # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 + assert len(tokens) == 2 + + lm_head_weights = causal_lm.lm_head.weight + + false_id = tokenizer.convert_tokens_to_ids(tokens[0]) + true_id = tokenizer.convert_tokens_to_ids(tokens[1]) + + score_weight = lm_head_weights[true_id].to(device).to( + torch.float32 + ) - lm_head_weights[false_id].to(device).to(torch.float32) + + with torch.no_grad(): + seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0)) + if seq_cls_model.score.bias is not None: + seq_cls_model.score.bias.zero_() + + +def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device): + lm_head_weights = causal_lm.lm_head.weight + + token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] + + score_weight = lm_head_weights[token_ids].to(device) + + with torch.no_grad(): + seq_cls_model.score.weight.copy_(score_weight) + if seq_cls_model.score.bias is not None: + seq_cls_model.score.bias.zero_() + + +method_map = { + function.__name__: function for function in [from_2_way_softmax, no_post_processing] +} + + +def converting( + model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu" +): + assert method in method_map + + if method == "from_2_way_softmax": + assert len(classifier_from_tokens) == 2 + num_labels = 1 + else: + num_labels = len(classifier_from_tokens) + + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + causal_lm = transformers.AutoModelForCausalLM.from_pretrained( + model_name, device_map=device + ) + + seq_cls_model = 
transformers.AutoModelForSequenceClassification.from_pretrained( + model_name, + num_labels=num_labels, + ignore_mismatched_sizes=True, + device_map=device, + ) + + method_map[method]( + causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device + ) + + # `llm as reranker` defaults to not using pad_token + seq_cls_model.config.use_pad_token = use_pad_token + seq_cls_model.config.pad_token_id = tokenizer.pad_token_id + + seq_cls_model.save_pretrained(path) + tokenizer.save_pretrained(path) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Converting *ForCausalLM models to " + "*ForSequenceClassification models." + ) + parser.add_argument( + "--model_name", + type=str, + default="BAAI/bge-reranker-v2-gemma", + help="Model name", + ) + parser.add_argument( + "--classifier_from_tokens", + type=str, + default='["Yes"]', + help="classifier from tokens", + ) + parser.add_argument( + "--method", type=str, default="no_post_processing", help="Converting converting" + ) + parser.add_argument( + "--use-pad-token", action="store_true", help="Whether to use pad_token" + ) + parser.add_argument( + "--path", + type=str, + default="./bge-reranker-v2-gemma-seq-cls", + help="Path to save converted model", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + converting( + model_name=args.model_name, + classifier_from_tokens=json.loads(args.classifier_from_tokens), + method=args.method, + use_pad_token=args.use_pad_token, + path=args.path, + ) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index a83d25818584..59336c1f7906 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -267,7 +267,8 @@ def mteb_test_rerank_models(hf_runner, vllm_runner, model_info: RerankModelInfo, vllm_extra_kwargs=None, - hf_model_callback=None): + hf_model_callback=None, + vllm_mteb_encoder=VllmMtebEncoder): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. @@ -288,7 +289,7 @@ def mteb_test_rerank_models(hf_runner, assert (model_info.architecture in model_config.architectures) assert model_config.hf_config.num_labels == 1 - vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model), + vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py new file mode 100644 index 000000000000..7fa9485dbc7f --- /dev/null +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import numpy as np +import pytest +import torch + +from tests.conftest import HfRunner + +from .mteb_utils import (RerankModelInfo, VllmMtebEncoder, + mteb_test_rerank_models) + +RERANK_MODELS = [ + RerankModelInfo("BAAI/bge-reranker-v2-gemma", + architecture="GemmaForSequenceClassification"), +] + +PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." 
# noqa: E501 + + +class GemmaRerankerHfRunner(HfRunner): + + def __init__(self, + model_name: str, + dtype: str = "auto", + *args: Any, + **kwargs: Any) -> None: + from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, + padding_side='left') + self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes") + + @torch.no_grad() + def predict(self, prompts: list[list[str]], *args, + **kwargs) -> torch.Tensor: + + def get_inputs(pairs, tokenizer, prompt=None): + if prompt is None: + prompt = PROMPT + + sep = "\n" + prompt_inputs = tokenizer(prompt, + return_tensors=None, + add_special_tokens=False)["input_ids"] + sep_inputs = tokenizer(sep, + return_tensors=None, + add_special_tokens=False)["input_ids"] + inputs = [] + for query, passage in pairs: + query_inputs = tokenizer( + f"A: {query}", + return_tensors=None, + add_special_tokens=False, + truncation=True, + ) + passage_inputs = tokenizer( + f"B: {passage}", + return_tensors=None, + add_special_tokens=False, + truncation=True, + ) + item = tokenizer.prepare_for_model( + [tokenizer.bos_token_id] + query_inputs["input_ids"], + sep_inputs + passage_inputs["input_ids"], + truncation="only_second", + padding=False, + return_attention_mask=False, + return_token_type_ids=False, + add_special_tokens=False, + ) + item["input_ids"] = item[ + "input_ids"] + sep_inputs + prompt_inputs + item["attention_mask"] = [1] * len(item["input_ids"]) + inputs.append(item) + return tokenizer.pad( + inputs, + padding=True, + return_tensors="pt", + ) + + scores = [] + for query, doc, *_ in prompts: + pairs = [(query, doc)] + inputs = get_inputs(pairs, self.tokenizer) + inputs = inputs.to(self.model.device) + _n_tokens = inputs["input_ids"].shape[1] + logits = self.model(**inputs, return_dict=True).logits + _scores = (logits[:, -1, + self.yes_loc].view(-1, ).float().sigmoid()) + scores.append(_scores[0].item()) + return torch.Tensor(scores) + + +class GemmaMtebEncoder(VllmMtebEncoder): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt = PROMPT + self.query_template = "A: {query}\n" + self.document_template = "B: {doc}\n{prompt}" + + def predict( + self, + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ) -> np.ndarray: + + _sentences = [] + for query, corpus, prompt in sentences: + query = self.query_template.format(query=query) + corpus = self.document_template.format(doc=corpus, prompt=prompt) + _sentences.append((query, corpus, prompt)) + + return super().predict(_sentences, *args, **kwargs) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, + monkeypatch) -> None: + monkeypatch.setenv("VLLM_USE_V1", "0") + + assert model_info.architecture == "GemmaForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "hf_overrides": { + "architectures": ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": "no_post_processing", + } + } + + mteb_test_rerank_models(GemmaRerankerHfRunner, + vllm_runner, + model_info, + vllm_extra_kwargs, + vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index a1293a95bfd5..e74c58744dd2 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ 
b/tests/models/language/pooling/test_mxbai_rerank.py @@ -12,11 +12,9 @@ RERANK_MODELS = [ RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", - dtype="float32", enable_test=True), RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", - dtype="float32", enable_test=False) ] diff --git a/tests/models/registry.py b/tests/models/registry.py index aba01cefe993..48302f9d6648 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -319,9 +319,14 @@ def check_available_online( _CROSS_ENCODER_EXAMPLE_MODELS = { # [Text-only] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 + "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 + v0_only=True, + hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 + "classifier_from_token": ["Yes"], # noqa: E501 + "method": "no_post_processing"}), # noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 } _MULTIMODAL_EXAMPLE_MODELS = { diff --git a/vllm/config.py b/vllm/config.py index 724f69a3887f..b7ba434db917 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1449,6 +1449,12 @@ def is_matryoshka(self) -> bool: def matryoshka_dimensions(self): return getattr(self.hf_config, "matryoshka_dimensions", None) + @property + def use_pad_token(self) -> bool: + # cross_encoder models defaults to using pad_token. + # `llm as reranker` models defaults to not using pad_token. + return getattr(self.hf_config, "use_pad_token", True) + def get_and_verify_max_len(self, max_model_len: int): # For pooling models, the tokenizer's `model_max_length` is often a # reliable source for the maximum sequence length. However, for diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 6357c2a37c8f..16c051d61de3 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1205,7 +1205,6 @@ def _cross_encoding_score( input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] pooling_params = PoolingParams(use_cross_encoder=True) - tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) @@ -1213,9 +1212,14 @@ def _cross_encoding_score( parsed_prompts = [] for q, t in input_pairs: - prompt_inputs = tokenizer(text=q, - text_pair=t, - **tokenization_kwargs) + if self.llm_engine.model_config.use_pad_token: + # cross_encoder models defaults to using pad_token. + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + else: + # `llm as reranker` models defaults to not using pad_token. 
+ prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs) engine_prompt = TokensPrompt( prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=prompt_inputs.get("token_type_ids")) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 328d4ff0e6c0..8b2e3e507c4d 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -167,12 +167,22 @@ async def _cross_encoding_score( executor=self._tokenizer_executor) tokenization_kwargs = tokenization_kwargs or {} - tokenized_prompts = await asyncio.gather( - *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) - for t1, t2 in input_pairs)) + use_pad_token = self.model_config.use_pad_token + + if use_pad_token: + # cross_encoder models defaults to using pad_token. + tokenized_prompts = await asyncio.gather( + *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) + for t1, t2 in input_pairs)) + else: + # `llm as reranker` models defaults to not using pad_token. + tokenized_prompts = await asyncio.gather( + *(tokenize_async(text=t1 + t2, **tokenization_kwargs) + for t1, t2 in input_pairs)) for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): - sep_token = tokenizer.sep_token if tokenizer.sep_token else '' + sep_token = tokenizer.sep_token if (tokenizer.sep_token + and use_pad_token) else '' request_prompt = f"{t1}{sep_token}{t2}" input_ids = prompt_inputs["input_ids"] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 78d86f6f2044..6584c84436c2 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -312,6 +312,10 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: else: config.num_labels = len(tokens) + # `llm as reranker` defaults to not using pad_token + use_pad_token = getattr(config, "use_pad_token", False) + config.use_pad_token = use_pad_token + def load_weights_using_from_2_way_softmax( model, weights: Iterable[tuple[str, torch.Tensor]]): @@ -356,8 +360,49 @@ def load_weights_using_from_2_way_softmax( return loaded_weights +def load_weights_no_post_processing(model, + weights: Iterable[tuple[str, + torch.Tensor]]): + from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead) + from vllm.model_executor.models.utils import AutoWeightsLoader + + model_config = model.vllm_config.model_config + tokens = getattr(model.config, "classifier_from_token", []) + tokens = cast(list[int], tokens) + assert len(tokens) > 0 + + device = model.score.weight.device + + if model.config.tie_word_embeddings: + model.lm_head = model.model.embed_tokens + else: + model.lm_head = ParallelLMHead(model.config.vocab_size, + model.config.hidden_size, + quant_config=model.quant_config) + + loader = AutoWeightsLoader(model) + loaded_weights = loader.load_weights(weights) + + from vllm.transformers_utils.tokenizer import get_tokenizer + tokenizer = get_tokenizer(model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code) + + token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] + score_weight = model.lm_head.weight.data[token_ids].to(device) + model.score.weight.data.copy_(score_weight) + + del model.lm_head + loaded_weights.add("score.weight") + loaded_weights.discard("lm_head.weight") + return loaded_weights + + SEQ_CLS_LOAD_METHODS = { "from_2_way_softmax": load_weights_using_from_2_way_softmax, + 
"no_post_processing": load_weights_no_post_processing, } @@ -368,6 +413,9 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]): # - Qwen3-Reranker # - Qwen2ForCausalLM # - mxbai-rerank-v2 + # - no_post_processing: + # - GemmaForCausalLM + # - bge-reranker-v2-gemma config = model.vllm_config.model_config.hf_config method = getattr(config, "method", None) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 59c3102add4c..bc8179f886fd 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -43,6 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -425,3 +426,6 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) + + +GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b100fe77e377..27d476929855 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -179,8 +179,9 @@ "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] + "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501 "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 - "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 + "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 } _MULTIMODAL_MODELS = { From 16fc78d578e697acdad75b960e8968f143b828ff Mon Sep 17 00:00:00 2001 From: Amog Kamsetty <amogkamsetty@gmail.com> Date: Wed, 26 Mar 2025 21:48:28 +0000 Subject: [PATCH 17/27] wip kaiju Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> --- vllm/model_executor/layers/layernorm.py | 58 ++++ .../model_executor/layers/rotary_embedding.py | 24 ++ vllm/model_executor/models/kaiju.py | 291 ++++++++++++++++++ 3 files changed, 373 insertions(+) create mode 100644 vllm/model_executor/models/kaiju.py diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e8d1fd635505..a8f1aaa08f8e 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -285,3 +285,61 @@ def forward_cuda( self.forward_static) self._is_compiled = True return self.forward_native(x, residual) + + + +@CustomOp.register("kaiju_rms_norm") +class KaijuRMSNorm(CustomOp): + """RMS normalization for Kaiju. + + Differences from standard RMSNorm: + 1. No learnable weight parameter + 2. Clams output to be in range of [-4, 4] + 3. 
Calculation is done in fp32 and then converted to orig_dtype + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.variance_epsilon = eps + self.hidden_size = hidden_size + + @staticmethod + def forward_static( + variance_epsilon: float, + x: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + input_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + return x.to(input_dtype).clamp(-4, 4) + + def forward_native( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static(self.variance_epsilon, x) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> torch.Tensor: + if torch.compiler.is_compiling(): + return self.forward_native(x) + + if not getattr(self, "_is_compiled", False): + self.forward_static = torch.compile( # type: ignore + self.forward_static) + self._is_compiled = True + return self.forward_native(x, residual) + + def extra_repr(self) -> str: + s = f"hidden_size={self.hidden_size}" + s += f", eps={self.variance_epsilon}" + return s \ No newline at end of file diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index a4615132a518..a1b2c8a67b3f 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -31,6 +31,7 @@ import torch import torch.nn as nn from transformers import PretrainedConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform @@ -72,6 +73,29 @@ def _apply_rotary_emb_torch( else: return torch.stack((o1, o2), dim=-1).flatten(-2) +@CustomOp.register("kaiju_rotary_embedding") +class KaijuRotaryEmbedding(CustomOp): + def __init__( + self, + hf_config: PretrainedConfig, + dtype: torch.dtype + ) -> None: + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(hf_config, "rope_scaling") and hf_config.rope_scaling is not None: + self.rope_type = hf_config.rope_scaling.get("rope_type", hf_config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = hf_config.max_position_embeddings + self.original_max_seq_len = hf_config.max_position_embeddings + + self.config = hf_config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool) -> torch.Tensor: diff --git a/vllm/model_executor/models/kaiju.py b/vllm/model_executor/models/kaiju.py new file mode 100644 index 000000000000..2fdde0455095 --- /dev/null +++ b/vllm/model_executor/models/kaiju.py @@ -0,0 +1,291 @@ +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import ACT2FN + +from kaiju import KaijuTextConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import KaijuRMSNorm +from vllm.model_executor.layers.quantization import 
QuantizationConfig + +# from vllm.attention import Attention +# from vllm.compilation.decorators import support_torch_compile + +# from vllm.logger import init_logger +# from vllm.model_executor.layers.activation import GeluAndMul + +# from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +# QKVParallelLinear, +# RowParallelLinear) +# from vllm.model_executor.layers.logits_processor import LogitsProcessor +# from vllm.model_executor.layers.rotary_embedding import get_rope +# from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +# from vllm.model_executor.layers.vocab_parallel_embedding import ( +# VocabParallelEmbedding) +# from vllm.model_executor.model_loader.weight_utils import ( +# default_weight_loader, maybe_remap_kv_scale_name) +# from vllm.model_executor.sampling_metadata import SamplingMetadata +# from vllm.sequence import IntermediateTensors + +# from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + +class KaijuMLP(nn.Module): + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + rms_norm_eps: float, + ): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + self.residual_scale = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False) + self.pre_ffn_norm = KaijuRMSNorm(self.hidden_size, eps=rms_norm_eps) + + # TODO: Megatron style TP (MergedColumnParallelLinear then RowParallelLinear) + self.W_in = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.W_out = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + # WARNING: In whippet checkpoints, there is an `args["quantize"]["ffn_clamp_middle_output"]` + # It's only used in the backward pass in specific circumstances. + hidden_states = x + x = self.W_in(x) + x = clamp(x, 4) + x = self.act_fn(x) + x = self.W_out(x) + hidden_states *= self.residual_scale + return x + hidden_states + +@dataclass +class KaijuCache: + key_states : Optional[torch.Tensor] = None + value_states : Optional[torch.Tensor] = None + +class KaijuAttention(nn.Module): + def __init__(self, + config: KaijuTextConfig, + max_position_embeddings: int, + is_context_encoder: bool, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "" + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.is_context_encoder = is_context_encoder + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + # TODO: Combine into single proj matrix and use QKVParallelLinear + self.q_proj = nn.Linear( + self.hidden_size, self.q_size, bias=False + ) + if not self.is_context_encoder: + self.k_proj = nn.Linear( + self.hidden_size, self.kv_size, bias=False + ) + self.v_proj = nn.Linear( + self.hidden_size, self.kv_size, bias=False + ) + + # TODO: Use RowParallelLinear + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + self.pre_projection_norm = KaijuRMSNorm(self.config.hidden_size, eps=config.rms_norm_eps) + + layer_idx = extract_layer_index(prefix) + self.is_sliding = layer_idx not in self.config.global_attention_layer_schedule + if self.is_sliding: + self.sliding_window = 1024 + else: + self.sliding_window = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn" + ) + + def forward( + self, + positions_embeddings: Tuple[torch.Tensor, torch.Tensor], + hidden_states: torch.Tensor, + kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + ) -> torch.Tensor: + + processed_hidden_states = self.pre_projection_norm(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + cos, sin = position_embeddings + query_states = self.q_proj(processed_hidden_states).view(hidden_shape) + + if self.is_context_encoder: + assert kv_cache is None + key_states = kv_cache.key_states + value_states = kv_cache.value_states + else: + key_states = self.k_proj(processed_hidden_states).view(hidden_shape) + value_states = self.v_proj(processed_hidden_states).view(hidden_shape) + + if kv_cache is not None: + key_states = kv_cache.key_states + value_states = kv_cache.value_states + + + # We should probably cache the clamped values. + query_states = clamp(query_states, 4) + key_states = clamp(key_states, 4) + value_states = clamp(value_states, 4) + + # Should we cache post rope? 
+ query_states, key_states = apply_rotary_pos_emb_kaiju(query_states, key_states, cos, sin, unsqueeze_dim=2) + + # TODO: attention masking + attn_output = self.attn(query_states, key_states, value_states) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + hidden_states *= self.residual_scale + hidden_states += attn_output + + return hidden_states + +class KaijuDecoderLayer(nn.Module): + def __init__( + self, + config: KaijuTextConfig, + is_context_encoder: bool, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "" + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = KaijuAttention( + config=config, + max_position_embeddings=config.max_position_embeddings, + is_context_encoder=is_context_encoder, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=None, + prefix=f"{prefix}.self_attn" + ) + + self.mlp = KaijuMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + rms_norm_eps=config.rms_norm_eps, + ) + + def forward( + self, + positions_embeddings: Tuple[torch.Tensor, torch.Tensor], + hidden_states: torch.Tensor, + output_attentions: bool = False, + kv_cache: Optional[KaijuCache] = None + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + # Self Attention + # attention module handles the residual stream update. + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + kv_cache=kv_cache, + ) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states,) + # This isn't necessary for inference, we can consider writing a slow + # attention implementation for debugging purposes. + assert not output_attentions, "TODO: Support this" + + return outputs + +@support_torch_compile +class KaijuModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + self.layer_to_kv_group = list(range(config.num_hidden_layers)) + for layers in config.share_kv_schedule: + for layer_idx in layers: + self.layer_to_kv_group[layer_idx] = min(layers) + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Vocab parallel embedding + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + # TODO: Get rid of this scale by "compiling" it into the embedding weights, then + # when we convert the lm head/etc we can just adjust that scale. + self.embedding_scale = nn.Parameter(torch.FloatTensor([0]), requires_grad=False) + + self.start_layer, self.end_layer, self.layers = make_layers_with_idx( + config.num_hidden_layers, + lambda prefix, idx: KaijuDecoderLayer( + config, is_context_encoder=idx != self.layer_to_kv_group[idx], cache_config=cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers" + ) + + + + + + + + + + + + + \ No newline at end of file From c88780f6af8156a57ee7597c7f21c84779525830 Mon Sep 17 00:00:00 2001 From: Alex <alexwu@character.ai> Date: Tue, 6 May 2025 02:29:58 +0000 Subject: [PATCH 18/27] Add classifier head vllm-project 17688 This PR outputs classifier results for each decode token (somewhat similar to logprobs). 
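In terms of the API surface, a consumer would read the new per-token outputs roughly as sketched below. This assumes the fields introduced in this patch (`SamplingParams.additional_heads` and `CompletionOutput.additional_heads`, the latter holding one `list[float]` per generated token) are populated end to end; the model name is a placeholder and not part of the patch.

```python
from vllm import LLM, SamplingParams

# Placeholder model; in practice this must be a model that exposes
# compute_additional_head(), as wired up in transformers.py below.
llm = LLM(model="my-model-with-classifier-head")

params = SamplingParams(max_tokens=8, additional_heads=True)
outputs = llm.generate(["Hello"], params)

completion = outputs[0].outputs[0]
# additional_heads is aligned with token_ids: one list of classifier
# scores per generated token, or None if the head was not requested.
for token_id, head_scores in zip(completion.token_ids,
                                 completion.additional_heads or []):
    print(token_id, head_scores)
```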
The corresponding PR upstream is vllm-project#17688 (though I don't expect to merge that any time soon). Signed-off-by: Alex <alexwu@character.ai> --- vllm/model_executor/models/transformers.py | 17 +++++++++ vllm/outputs.py | 9 +++-- vllm/sampling_params.py | 3 ++ vllm/sequence.py | 1 + vllm/v1/engine/__init__.py | 4 ++- vllm/v1/engine/additional_heads.py | 42 ++++++++++++++++++++++ vllm/v1/engine/logprobs.py | 2 +- vllm/v1/engine/output_processor.py | 9 +++++ vllm/v1/outputs.py | 10 ++++++ vllm/v1/worker/gpu_input_batch.py | 4 +++ vllm/v1/worker/gpu_model_runner.py | 37 ++++++++++++++++++- 11 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 vllm/v1/engine/additional_heads.py diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 04ee3a454f9d..a43cc049f65e 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -404,6 +404,15 @@ def forward( return hidden_states + def compute_additional_head( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + if get_pp_group().is_last_rank and hasattr(self.model, + "compute_additional_head"): + return self.model.compute_additional_head(hidden_states) + return None + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) @@ -499,6 +508,14 @@ def compute_logits( sampling_metadata) return logits + def compute_additional_head( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + if hasattr(self.model, "compute_additional_head"): + return self.model.compute_additional_head(hidden_states) + return None + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( diff --git a/vllm/outputs.py b/vllm/outputs.py index 9784a8894472..3cce2304545f 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -14,8 +14,9 @@ from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind -from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceGroupBase, SequenceStatus) +from vllm.sequence import (AdditionalHeads, PromptLogprobs, RequestMetrics, + SampleLogprobs, SequenceGroup, SequenceGroupBase, + SequenceStatus) logger = init_logger(__name__) @@ -32,6 +33,8 @@ class CompletionOutput: output text. logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. + additional_heads: The additional head outputs of the generated output + text. finish_reason: The reason why the sequence is finished. 
stop_reason: The stop string or token id that caused the completion to stop, None if the completion finished for some other reason @@ -47,6 +50,7 @@ class CompletionOutput: finish_reason: Optional[str] = None stop_reason: Union[int, str, None] = None lora_request: Optional[LoRARequest] = None + additional_heads: Optional[AdditionalHeads] = None def finished(self) -> bool: return self.finish_reason is not None @@ -57,6 +61,7 @@ def __repr__(self) -> str: f"token_ids={self.token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"logprobs={self.logprobs}, " + f"additional_heads={self.additional_heads}, " f"finish_reason={self.finish_reason}, " f"stop_reason={self.stop_reason})") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a9a862384d11..135866d4992c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -248,6 +248,9 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None + # Fields used for additional heads (e.g. classifiers) + additional_heads: Optional[bool] = None + @staticmethod def from_optional( n: Optional[int] = 1, diff --git a/vllm/sequence.py b/vllm/sequence.py index ffe890eb2dab..c606560026ac 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -54,6 +54,7 @@ class Logprob: PromptLogprobs = list[Optional[dict[int, Logprob]]] # {token_id -> logprob} for each sequence group. SampleLogprobs = list[dict[int, Logprob]] +AdditionalHeads = list[list[float]] class SequenceStatus(enum.IntEnum): diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 921ccd708cdd..95fb705cffdc 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -15,7 +15,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.metrics.stats import SchedulerStats -from vllm.v1.outputs import LogprobsLists, LogprobsTensors +from vllm.v1.outputs import LogprobsLists, LogprobsTensors, AdditionalHeadOutputsPerRequest # These are possible values of RequestOutput.finish_reason, # so form part of the external API. @@ -107,6 +107,8 @@ class EngineCoreOutput( new_logprobs: Optional[LogprobsLists] = None new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None + new_additional_head_outputs: Optional[ + AdditionalHeadOutputsPerRequest] = None pooling_output: Optional[torch.Tensor] = None diff --git a/vllm/v1/engine/additional_heads.py b/vllm/v1/engine/additional_heads.py new file mode 100644 index 000000000000..cd436f9870c1 --- /dev/null +++ b/vllm/v1/engine/additional_heads.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from vllm.logger import init_logger +from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest + +logger = init_logger(__name__) + + +@dataclass +class AdditionalHeadsProcessor: + """Processor for additional head outputs from the model. + + This class handles storing and managing additional head outputs + for generated tokens, similar to how LogprobsProcessor handles logprobs. + """ + + # Additional head outputs for this request + additional_head_outputs: list[list[float]] + + @classmethod + def from_new_request( + cls, + request: EngineCoreRequest, + ) -> "AdditionalHeadsProcessor": + """Create a new AdditionalHeadsProcessor for a request. + + Args: + request: The engine core request to process additional heads for. 
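+
+        Returns:
+            An AdditionalHeadsProcessor with an empty output list; outputs
+            are appended as tokens are generated via update_from_output().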
+ """ + return cls(additional_head_outputs=[], ) + + def update_from_output(self, output: EngineCoreOutput) -> None: + """Update with additional head outputs from EngineCore. + + Args: + output: The engine core output containing new additional + head outputs. + """ + if output.new_additional_head_outputs is not None: + self.additional_head_outputs.append( + output.new_additional_head_outputs.additional_head_outputs) diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py index e95da0a5e5aa..b92067e344f3 100644 --- a/vllm/v1/engine/logprobs.py +++ b/vllm/v1/engine/logprobs.py @@ -197,4 +197,4 @@ def update_from_output(self, output: EngineCoreOutput) -> None: if output.new_logprobs is not None: self._update_sample_logprobs(output.new_logprobs) if output.new_prompt_logprobs_tensors is not None: - self._update_prompt_logprobs(output.new_prompt_logprobs_tensors) + self._update_prompt_logprobs(output.new_prompt_logprobs_tensors) \ No newline at end of file diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0aa..264171dd6ea1 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -14,6 +14,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason +from vllm.v1.engine.additional_heads import AdditionalHeadsProcessor from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest @@ -103,6 +104,7 @@ def __init__( self.prompt_token_ids = prompt_token_ids self.prompt_len = len(prompt_token_ids) self.logprobs_processor = logprobs_processor + self.additional_heads_processor = additional_heads_processor self.detokenizer = detokenizer self.max_tokens_param = max_tokens_param self.is_prefilling = True @@ -254,11 +256,18 @@ def _new_completion_output( if delta and logprobs: logprobs = logprobs[-len(token_ids):] + # Prepare additional heads, based on delta mode + additional_heads = ( + self.additional_heads_processor.additional_head_outputs or None) + if delta and additional_heads: + additional_heads = additional_heads[-len(token_ids):] + return CompletionOutput( index=self.request_index, text=text, token_ids=token_ids, logprobs=logprobs, + additional_heads=additional_heads, cumulative_logprob=self.logprobs_processor.cumulative_logprob, finish_reason=str(finish_reason) if finished else None, stop_reason=stop_reason if finished else None) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f78623f571b2..d0a862e1fb9d 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -7,6 +7,16 @@ import torch +class AdditionalHeadOutputsPerRequest(NamedTuple): + # num_additional_head_outputs + additional_head_outputs: list[float] + + +class AdditionalHeadOutputs(NamedTuple): + # num_generated_tokens x num_additional_head_outputs + additional_head_outputs: list[Optional[AdditionalHeadOutputsPerRequest]] + + class LogprobsLists(NamedTuple): # [num_reqs, max_num_logprobs + 1] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a9..981ae7f8dfc8 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -207,6 +207,9 @@ def __init__( # that are currently in the prefill phase. 
self.num_prompt_logprobs: dict[str, int] = {} + # req_idx -> bool + self.run_additional_heads: dict[int, bool] = {} + # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} @@ -412,6 +415,7 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) + self.run_additional_heads.pop(req_index, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) # LoRA diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a26e88db1f7..4635c1a63507 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -53,7 +53,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AdditionalHeadOutputs, + AdditionalHeadOutputsPerRequest, LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -1420,6 +1421,40 @@ def execute_model( assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] + additional_head_indices_mask = [ + (i in self.input_batch.run_additional_heads) + for i in range(self.input_batch.num_reqs) + ] + run_additional_heads = any(additional_head_indices_mask) + + if run_additional_heads: + assert hasattr(self.model, "compute_additional_head") + + # NOTE: In theory not all logit indices need additional + # head outputs and we could save some flops by masking. + # In practice, this is a small number of flops and this + # is simpler/introduces less overhead. + additional_heads_tensor = self.model.compute_additional_head( + sample_hidden_states, ) + + # Should be num_decode_tokens x additional_head_size + assert len(additional_heads_tensor.shape) == 2 + + # Don't return the additional head outputs where they aren't needed. 
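+            # NOTE: this zip assumes one sampled token per request, so the
+            # rows of additional_heads_tensor line up with
+            # additional_head_indices_mask (ordered by request index).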
+ additional_head_outputs = AdditionalHeadOutputs( + additional_head_outputs=[ + AdditionalHeadOutputsPerRequest( + additional_head_outputs= + additional_head_outputs_per_request, ) + if mask else None + for additional_head_outputs_per_request, mask in zip( + additional_heads_tensor.tolist(), + additional_head_indices_mask) + ], ) + + else: + additional_head_outputs = None + # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: self.apply_grammar_bitmask(scheduler_output, logits) From ca9f46bb49ebe6f22a23f318870a08c4c49f8835 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty <amogkamsetty@gmail.com> Date: Wed, 7 May 2025 22:03:32 +0000 Subject: [PATCH 19/27] support already quantized input into fp8 kernel Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> --- .../layers/quantization/utils/w8a8_utils.py | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index adc67aa64952..e0cbb457d61d 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -345,24 +345,8 @@ def apply( out_dtype = input.dtype # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if self.cutlass_fp8_supported: - assert input.dtype != current_platform.fp8_dtype( - ), "FP8 input to cutlass is not currently implemented" - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - scale_ub=input_scale_ub, - use_per_token_if_dynamic=use_per_token_if_dynamic) - else: - if input.dtype != current_platform.fp8_dtype(): - # Maybe apply padding to output, see comment in __init__ - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=self.output_padding, - use_per_token_if_dynamic=use_per_token_if_dynamic) - else: - qinput, x_scale = input_2d, input_scale + input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device) + qinput, x_scale = input_2d, input_scale per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) From 7b859f8b14dbfb39584e9bd645655ba6ae2f49ce Mon Sep 17 00:00:00 2001 From: rohingarg-c <rohin@character.ai> Date: Tue, 13 May 2025 15:42:24 -0700 Subject: [PATCH 20/27] added stream_n to v1/async_llm, created streaming_params (#1) * added stream_n to v1/async_llm, created streaming_params * Updated OpenAI compatible API to work with StreamingParams Signed-off-by: Rohin Garg <rohin@character.ai> --- requirements/test.txt | 20 +++++++- tests/v1/engine/test_async_llm.py | 6 ++- vllm/engine/async_llm_engine.py | 18 ++++++- vllm/engine/multiprocessing/client.py | 24 +++++++++- vllm/engine/protocol.py | 2 + vllm/entrypoints/openai/protocol.py | 21 ++++++++ vllm/entrypoints/openai/serving_chat.py | 3 ++ vllm/entrypoints/openai/serving_completion.py | 2 + vllm/streaming_params.py | 48 +++++++++++++++++++ vllm/v1/engine/async_llm.py | 16 ++++++- 10 files changed, 154 insertions(+), 6 deletions(-) create mode 100644 vllm/streaming_params.py diff --git a/requirements/test.txt b/requirements/test.txt index f6f599df758f..e1ee216f58dd 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -31,6 +31,10 @@ argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 # via isoduration +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -141,6 +145,11 @@ eval-type-backport==0.2.2 # via mteb 
evaluate==0.4.3 # via lm-eval +exceptiongroup==1.3.0 + # via + # anyio + # hypothesis + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -753,8 +762,13 @@ tokenizers==0.21.1 # via # -r requirements/test.in # transformers +toml==0.10.2 + # via datamodel-code-generator tomli==2.2.1 - # via schemathesis + # via + # black + # pytest + # schemathesis tomli-w==1.2.0 # via schemathesis torch==2.7.0+cu128 @@ -828,6 +842,9 @@ types-python-dateutil==2.9.0.20241206 # via arrow typing-extensions==4.12.2 # via + # anyio + # black + # exceptiongroup # huggingface-hub # librosa # mistral-common @@ -835,6 +852,7 @@ typing-extensions==4.12.2 # pqdm # pydantic # pydantic-core + # rich # torch # typer # typing-inspection diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index e137452f2625..8c851fbfe684 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -73,7 +73,8 @@ async def generate( ) async for out in engine.generate(request_id=request_id, prompt=prompt, - sampling_params=sampling_params): + sampling_params=sampling_params, + streaming_params=streaming_params): num_tokens = sum(len(output.token_ids) for output in out.outputs) if output_kind == RequestOutputKind.DELTA: @@ -244,7 +245,8 @@ async def test_finished_flag( out async for out in engine.generate(request_id="request-33", prompt=prompt, - sampling_params=sampling_params) + sampling_params=sampling_params, + streaming_params=streaming_params) ] # Assert only the last output has the finished flag set diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3d7d28055dd0..ad9aa62e5f3b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -32,6 +32,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest +from vllm.streaming_params import StreamingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import Device, weak_bind @@ -897,6 +898,7 @@ async def generate( self, prompt: PromptType, sampling_params: SamplingParams, + streaming_params: StreamingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, @@ -973,6 +975,8 @@ async def generate( >>> ... 
""" try: + buffer: Optional[RequestOutput] = None # buffer of output tokens + buffer_token_count = 0 async for output in await self.add_request( request_id, prompt, @@ -983,7 +987,19 @@ async def generate( priority=priority, data_parallel_rank=data_parallel_rank, ): - yield LLMEngine.validate_output(output, RequestOutput) + output = LLMEngine.validate_output(output, RequestOutput) + if buffer is None: + buffer = output + else: + buffer.add(output, aggregate=True) + + buffer_token_count += sum( + len(o.token_ids) for o in output.outputs) + if buffer_token_count >= streaming_params.stream_n: + yield buffer + buffer = None + buffer_token_count = 0 + except asyncio.CancelledError: await self.abort(request_id) raise diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 9e018ec7f344..895e23b94e75 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -47,6 +47,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams +from vllm.streaming_params import StreamingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import Device @@ -445,6 +446,7 @@ def generate( self, prompt: PromptType, sampling_params: SamplingParams, + streaming_params: StreamingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, @@ -519,6 +521,7 @@ async def _process_request( prompt: PromptType, params: Union[SamplingParams, PoolingParams], request_id: str, + streaming_params: Optional[StreamingParams] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -589,6 +592,8 @@ async def _process_request( # queue after pulling them from the zmq socket. finished = False try: + buffer = None # buffer of output tokens + buffer_token_count = 0 while not finished: request_output = await queue.get() @@ -596,7 +601,24 @@ async def _process_request( raise request_output finished = request_output.finished - yield request_output + if buffer is None: + buffer = request_output + else: + buffer.add(request_output, aggregate=True) + + if isinstance(request_output, RequestOutput): + buffer_token_count += sum( + len(o.token_ids) for o in request_output.outputs) + else: + buffer_token_count += 1 + if streaming_params is None or \ + buffer_token_count >= streaming_params.stream_n or \ + finished: + + yield buffer + buffer = None + buffer_token_count = 0 + finally: # Request was canceled by the client. 
if not finished and not self.errored: diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 8688fcc82cd9..c02261b181ff 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -18,6 +18,7 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.streaming_params import StreamingParams from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Device, collect_from_async_generator, random_uuid @@ -52,6 +53,7 @@ def generate( self, prompt: PromptType, sampling_params: SamplingParams, + streaming_params: StreamingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 14b2253d1dba..18074491b0a2 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -29,6 +29,7 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.sequence import Logprob +from vllm.streaming_params import StreamingParams from vllm.utils import random_uuid, resolve_obj_by_qualname logger = init_logger(__name__) @@ -160,6 +161,7 @@ class ResponseFormat(OpenAIBaseModel): class StreamOptions(OpenAIBaseModel): include_usage: Optional[bool] = True continuous_usage_stats: Optional[bool] = False + stream_n: Optional[int] = 1 class FunctionDefinition(OpenAIBaseModel): @@ -666,6 +668,13 @@ def to_sampling_params( extra_args=extra_args or None, ) + def to_streaming_params(self, ) -> StreamingParams: + stream_n = None + if self.stream_options is not None and \ + self.stream_options.stream_n is not None: + stream_n = self.stream_options.stream_n + return StreamingParams(stream_n=stream_n) + def _get_guided_json_from_tool( self) -> Optional[Union[str, dict, BaseModel]]: # user has chosen to not use any tool @@ -1104,6 +1113,13 @@ def to_sampling_params( extra_args=extra_args or None, ) + def to_streaming_params(self, ) -> StreamingParams: + stream_n = None + if self.stream_options is not None and \ + self.stream_options.stream_n is not None: + stream_n = self.stream_options.stream_n + return StreamingParams(stream_n=stream_n) + @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): @@ -2016,6 +2032,11 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, extra_args=self.vllm_xargs) + def to_streaming_params( + self, + ) -> StreamingParams: # stream_options not defined in transcription request + return StreamingParams(stream_n=None) + @model_validator(mode="before") @classmethod def validate_transcription_request(cls, data): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a802fbc3865f..90d24819dfdf 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -252,6 +252,8 @@ async def create_chat_completion( max_tokens, self.model_config.logits_processor_pattern, self.default_sampling_params) + streaming_params = request.to_streaming_params() + self._log_inputs(request_id, request_prompts[i], params=sampling_params, @@ -272,6 +274,7 @@ async def create_chat_completion( generator = self.engine_client.generate( engine_prompt, sampling_params, + streaming_params, request_id, lora_request=lora_request, trace_headers=trace_headers, diff --git 
a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6c9c29b71445..3101e28080b5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -180,6 +180,7 @@ async def create_completion( self.default_sampling_params) request_id_item = f"{request_id}-{i}" + streaming_params = request.to_streaming_params() self._log_inputs(request_id_item, request_prompts[i], @@ -206,6 +207,7 @@ async def create_completion( generator = self.engine_client.generate( engine_prompt, sampling_params, + streaming_params, request_id_item, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/streaming_params.py b/vllm/streaming_params.py new file mode 100644 index 000000000000..e10cc5c5e1c5 --- /dev/null +++ b/vllm/streaming_params.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Streaming parameters for token streaming during text generation.""" +from typing import Annotated + +import msgspec + + +class StreamValidationError(ValueError): + pass + + +class StreamDefaults: + STREAM_N_DEFAULT = 1 + + +class StreamLimits: + STREAM_N_MIN = 1 + STREAM_N_MAX = 1024 + + +class StreamingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + dict=True): # type: ignore[call-arg] + """Streaming parameters for token streaming during text generation. + + Args: + stream_n: Number of tokens to stream at a time. Must be an integer >= 1. + Defaults to 1. + """ + + stream_n: Annotated[int, msgspec.Meta( + ge=StreamLimits.STREAM_N_MIN)] = (StreamDefaults.STREAM_N_DEFAULT) + + def __post_init__(self) -> None: + if self.stream_n is None: + self.stream_n = StreamDefaults.STREAM_N_DEFAULT + if not isinstance(self.stream_n, int): + raise StreamValidationError( + f"stream_n must be an integer, got {type(self.stream_n)}.") + if not (StreamLimits.STREAM_N_MIN <= self.stream_n <= + StreamLimits.STREAM_N_MAX): + raise StreamValidationError( + f"stream_n must be between {StreamLimits.STREAM_N_MIN} and " + f"{StreamLimits.STREAM_N_MAX}, got {self.stream_n}.") + + def __repr__(self) -> str: + return f"StreamingParams(stream_n={self.stream_n})" diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaaa..c1ac6a46b354 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -278,6 +278,7 @@ async def generate( self, prompt: PromptType, sampling_params: SamplingParams, + streaming_params: StreamingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, @@ -320,6 +321,8 @@ async def generate( # The output_handler task pushes items into the queue. # This task pulls from the queue and yields to caller. finished = False + buffer: Optional[RequestOutput] = None # buffer of output tokens + buffer_token_count = 0 while not finished: # Note: drain queue without await if possible (avoids # task switching under load which helps performance). @@ -328,7 +331,18 @@ async def generate( # Note: both OutputProcessor and EngineCore handle their # own request cleanup based on finished. 
finished = out.finished - yield out + + if buffer is None: + buffer = out + else: + buffer.add(out, aggregate=True) + + buffer_token_count += sum( + len(o.token_ids) for o in out.outputs) + if buffer_token_count >= streaming_params.stream_n or finished: + yield buffer + buffer = None + buffer_token_count = 0 # If the request is disconnected by the client, generate() # is cancelled or the generator is garbage collected. So, From fcc73a594fd4531f3820259c2c1ea34bcd635314 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty <amogkam@users.noreply.github.com> Date: Thu, 5 Jun 2025 14:06:47 -0400 Subject: [PATCH 21/27] Fp8 cleanup (#4) Small cleanup to still keep the existing codepath if non fp8 input is passed in --------- Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> --- .../layers/quantization/utils/w8a8_utils.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index e0cbb457d61d..86acfe11b18a 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -345,8 +345,28 @@ def apply( out_dtype = input.dtype # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - input_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device) - qinput, x_scale = input_2d, input_scale + if self.cutlass_fp8_supported and input.dtype != current_platform.fp8_dtype(): + assert input.dtype != current_platform.fp8_dtype( + ), "FP8 input to cutlass is not currently implemented" + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + scale_ub=input_scale_ub, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + if input.dtype != current_platform.fp8_dtype(): + # Maybe apply padding to output, see comment in __init__ + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=self.output_padding, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + if x_scale is not None: + qinput, x_scale = input_2d, input_scale + else: + qinput = input_2d + x_scale = torch.tensor([1.0], dtype=torch.float32, device=input_2d.device) per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) From 2d9589e960859c1c2f432ecc3d63283120e884cf Mon Sep 17 00:00:00 2001 From: Amog Kamsetty <amogkam@users.noreply.github.com> Date: Thu, 5 Jun 2025 14:46:12 -0400 Subject: [PATCH 22/27] Update w8a8_utils.py (#6) --- vllm/model_executor/layers/quantization/utils/w8a8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 86acfe11b18a..ae6f204fefed 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -362,7 +362,7 @@ def apply( num_token_padding=self.output_padding, use_per_token_if_dynamic=use_per_token_if_dynamic) else: - if x_scale is not None: + if input_scale is not None: qinput, x_scale = input_2d, input_scale else: qinput = input_2d From 96ee8bbf76ea851b0458fbb0dac40d8281fae65b Mon Sep 17 00:00:00 2001 From: rohingarg-c <rohin@character.ai> Date: Fri, 27 Jun 2025 17:36:24 -0700 Subject: [PATCH 23/27] minor fixes to make classifier heads more usable (#10) Signed-off-by: Rohin Garg <rohin@character.ai> --- vllm/entrypoints/openai/protocol.py | 41 
+++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 18074491b0a2..41da13fdb277 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1436,6 +1436,47 @@ class PoolingResponse(OpenAIBaseModel): usage: UsageInfo +class ClassificationRequest(OpenAIBaseModel): + model: Optional[str] = None + input: Union[list[str], str] + truncate_prompt_tokens: Optional[int] = None + user: Optional[str] = None + + # --8<-- [start:classification-pooling-params] + additional_data: Optional[Any] = None + # --8<-- [end:classification-pooling-params] + + # --8<-- [start:classification-extra-params] + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + + # --8<-- [end:classification-extra-params] + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class ClassificationData(OpenAIBaseModel): + index: int + label: Optional[str] + probs: list[float] + num_classes: int + + +class ClassificationResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"classify-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: list[ClassificationData] + usage: UsageInfo + + class ScoreResponseData(OpenAIBaseModel): index: int object: str = "score" From a03a7c33c52aff2382a239ed257fc7af2d0c566a Mon Sep 17 00:00:00 2001 From: Your Name <you@example.com> Date: Wed, 2 Jul 2025 18:39:43 +0000 Subject: [PATCH 24/27] add support for accumulate in vllm --- vllm/entrypoints/openai/serving_completion.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 3101e28080b5..305aa0374581 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -315,6 +315,9 @@ async def completion_stream_generator( previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts num_prompt_tokens = [0] * num_prompts + accumulated_text = [""] * num_choices * num_prompts + accumulated_tokens = [[] * num_choices * num_prompts] + accumulated_logprobs = [[] * num_choices * num_prompts] stream_options = request.stream_options if stream_options: @@ -370,6 +373,16 @@ async def completion_stream_generator( *(output.logprobs or []), ] has_echoed[i] = True + elif request.accumulate: + i = output.index + prompt_idx * num_choices + # return the accumulated response + accumulated_text[i] += output.text + accumulated_tokens[i].extend(output.token_ids) + accumulated_logprobs[i].extend(output.logprobs or []) + + delta_text = accumulated_text[i] + delta_token_ids = accumulated_tokens[i] + out_logprobs = accumulated_logprobs[i] else: # return just the delta delta_text = output.text From 2e339f1fac881315e1b77e9349dc8fa393986340 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty <amogkam@users.noreply.github.com> Date: Wed, 2 Jul 2025 20:51:07 -0400 Subject: [PATCH 25/27] Upgrade to 0.9.1 & support for `cached_tokens` info in completions request (#12) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running 
llama4 models and unit test fix (#18100) Signed-off-by: Hongxia Yang <hongxia.yang@amd.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> * Prevent the cross-encoder logic from being applied to classification tasks (#18838) Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> * Add ability to use CUDAGraphs with use_inductor=False (#17345) Signed-off-by: rzou <zou3519@gmail.com> * [Bugfix][TPU] fix moe custom kernel import (#18853) Signed-off-by: Chengji Yao <chengjiyao@google.com> * [Doc][Neuron] Update documentation for Neuron (#18868) Signed-off-by: Elaine Zhao <elaineyz@amazon.com> * Skip device and quant Pydantic validation to make plugin device work (#18843) Signed-off-by: Yikun Jiang <yikunkero@gmail.com> * Fixes a dead link in nightly benchmark readme (#18856) Signed-off-by: Brent Salisbury <bsalisbu@redhat.com> * [Neuron] Add multi-LoRA support for Neuron. (#18284) Signed-off-by: Satyajith Chilappagari <satchill@amazon.com> * [LoRA] Add LoRA support for InternVL (#18842) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> * [Doc] Remove redundant spaces from compatibility_matrix.md (#18891) Signed-off-by: windsonsea <haifeng.yao@daocloud.io> * [doc] add CLI doc (#18871) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Bugfix] Fix misleading information in the documentation (#18845) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> * [Misc] Replace TODO in serving transcription (#18895) Signed-off-by: NickLucche <nlucches@redhat.com> * [Bugfix] Ensure tensors are contiguous during serialisation (#18860) Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> * [BugFix] Update pydantic to fix error on python 3.10 (#18852) Signed-off-by: luka <luka@neuralmagic.com> * Fix an error in dummy weight loading for quantization models (#18855) Signed-off-by: Chenyaaang <chenyangli@google.com> * [Misc][Tools][Benchmark] Add benchmark_serving supports for llama.cpp. 
(#18692) Signed-off-by: Duyi-Wang <duyi.wang@intel.com> * [Doc] Fix codeblocks formatting in LoRA adapters documentation (#18907) Signed-off-by: Zerohertz <ohg3417@gmail.com> * [Bugfix] Fix the failing gte embedding test (#18720) Signed-off-by: Isotr0py <2037008807@qq.com> * [Attention][V1] Toggle for v1 attention backend (#18275) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> * [ROCm][V0][Attention] Revert to the previous FA triton kernel (#18226) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> * [Deprecation] Disallow pos-args other than `model` when initializing `LLM` (#18802) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * [Misc] Remove duplicate init for self.vllm_config (#18896) Signed-off-by: googs1025 <googs1025@gmail.com> * [V1] Allocate kv_cache with stride order for V1 (#18775) Signed-off-by: nicklucche <nlucches@redhat.com> * [BugFix] Make DP work with connector-delayed new requests (#18559) Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Will Eaton <weaton@redhat.com> * [P/D] NixlConnector DP fixes (#18903) Signed-off-by: Will Eaton <weaton@redhat.com> * Use standalone_compile by default in torch >= 2.8.0 (#18846) Signed-off-by: rzou <zou3519@gmail.com> * [TPU] remove transpose ops in moe kernel (#18923) Signed-off-by: Chengji Yao <chengjiyao@google.com> * [Bugfix] Fix PP default fallback behavior for V1 (#18915) Signed-off-by: mgoin <mgoin64@gmail.com> * [Misc] Update type annotation for rotary embedding `base` (#18914) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * [TPU][CI/CD] Clean up docker for TPU tests. (#18926) Signed-off-by: Carol Zheng <cazheng@google.com> * improve the robustness of parsing vlms config in AutoRound (#18894) Signed-off-by: wenhuach21 <wenhua.cheng@intel.com> * [Bugfix] Consistent ascii handling in tool parsers (#18883) Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> * [Model] Use AutoWeightsLoader for mamba2 (#18918) Signed-off-by: iLeGend <824040212@qq.com> * [docs] fix: fix markdown syntax (#18927) * [ROCm] Remove unnecessary assertion of max_model_len in ROCM_AITER_MLA attention backend. 
(#18938) Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> * [Bugfix] Remove NVFP4 scales assertions to fix load_format=dummy (#18861) Signed-off-by: mgoin <mgoin64@gmail.com> * [Deprecation] Remove mean pooling default for `Qwen2EmbeddingModel` (#18913) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * [Misc]Fix benchmarks/README.md for speculative decoding (#18897) Signed-off-by: rabi <ramishra@redhat.com> * [doc] add mkdocs doc (#18930) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Model] Use in-place adds in SigLIP (#18922) Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> * [Bugfix][Failing Test] Fix test_vllm_port.py (#18618) Signed-off-by: rabi <ramishra@redhat.com> * [Misc]Fix typo (#18947) * [Bugfix][TPU] Fix tpu model runner testcase failure (#18810) Signed-off-by: Carol Zheng <cazheng@google.com> * [CI/Build] remove regex from build dependencies (#18945) Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> * [Feature] minicpm eagle support (#18943) Signed-off-by: huangyuxiang03 <huangyx0321@gmail.com> Co-authored-by: huangyuxiang03 <huangyx0321@gmail.com> * [doc] show the count for fork and watch (#18950) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Docs] Update SECURITY.md with link to our security guide (#18961) Signed-off-by: Russell Bryant <rbryant@redhat.com> * Improve "failed to get the hash of the compiled graph" error (#18956) Signed-off-by: rzou <zou3519@gmail.com> * [Perf] API-server scaleout with many-to-many server-engine comms (#17546) * Benchmark script for fp8 vs bf16 gemm (#17126) Signed-off-by: mgoin <mgoin64@gmail.com> * [VLM] Add PP support and fix GPTQ inference for Ovis models (#18958) Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <2037008807@qq.com> * [Misc] add group_size is -1 in awq quantization (#18910) Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> * Tool parser regex timeout handling (#18960) Signed-off-by: Will Eaton <weaton@redhat.com> * [Docs] Correct multiprocessing design doc (#18964) Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> * create util function for batched arange (#18937) * [Frontend] Add rerank support to run_batch endpoint (#16278) Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io> * [Misc] Fix estimated max model len msg (#18966) Signed-off-by: Yong Hoon Shin <yhshin@meta.com> * [Bugfix]: Fix the incompatibility issue with Structured Outputs when Thinking is disabled (#18879) Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> * fix security issue of logging llm output (#18980) Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> * [Neuron] Add Multi-Modal model support for Neuron (#18921) Signed-off-by: Satyajith Chilappagari <satchill@amazon.com> Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com> Co-authored-by: Rohith Nallamaddi <nalrohit@amazon.com> Co-authored-by: FeliciaLuo <luof@amazon.com> Co-authored-by: Elaine Zhao <elaineyz@amazon.com> * [doc] fix the list rendering issue - security.md (#18982) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [BugFix] Pydantic part 2 (#18911) Signed-off-by: luka <luka@neuralmagic.com> * [FEAT][ROCm] Add AITER grouped topk for DeepSeekV2 (#18825) Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> * [Bugfix] Fix for issue 17396 (#18773) Signed-off-by: Fred Reiss <frreiss@us.ibm.com> * 
[ROCm][Kernel] Add gfx950 support for skinny gemms (#18010) Signed-off-by: charlifu <charlifu@amd.com> * [P/D] NixlConnector use cache device index for memory registration (#18969) Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com> * [BugFix] Fix multi-node offline data-parallel (#18981) Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com> * [Misc] add return token strs for tokenize (#18941) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Misc][Benchmark] Add support for CustomDataset (#18511) * [Bugfix] Fix EAGLE3 broken logits (#18909) Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> * [Core] Rework dtype resolution (#18751) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * [LoRA] Support dynamically initialize `packed_modules_mapping` for VLM with arbitrary components (#18987) Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <2037008807@qq.com> * [doc] small fix - mkdocs (#18996) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * Let max_num_batched_tokens use human_readable_int for large numbers (#18968) Signed-off-by: mgoin <mgoin64@gmail.com> * [BugFix] fix data parallel construct ipv6 url addres (#18991) Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> * [BugFix] Fix incorrect metrics shutdown error log message (#18992) Signed-off-by: Nick Hill <nhill@redhat.com> * [doc] wrong output (#19000) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Misc] reuse num_tokens_across_dp of get_dp_padding to avoid unnecessary dp all reduce in set_forward_context (#18935) Signed-off-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> * [Bugfix][Nixl] Fix DP Metadata Handshake (#19008) Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> * [Core] Support inplace model weights loading (#18745) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> * [doc] add pytest tips (#19010) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Model] enable data parallel for Llama4 vision encoder (#18368) Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com> Co-authored-by: yZhen <yZhen@fb.com> Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com> * [Frontend] enable custom logging for the uvicorn server (OpenAI API server) (#18403) Signed-off-by: François Paupier <francois.paupier@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> * [Bugfix][Model] Attempt to fix eagle in V0. 
(#18978) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> * add an absolute path for run.sh (#18258) Signed-off-by: calvin chen <120380290@qq.com> * [Hardware][TPU] Initial support of model parallelism with single worker using SPMD (#18011) Signed-off-by: Siyuan Liu <lsiyuan@google.com> Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> * [Doc] Remove duplicate TOCs during MkDocs migration (#19021) Signed-off-by: Zerohertz <ohg3417@gmail.com> * [Bugfix][EP+DP] Use pplx-kernel internode instead of intranode (#19034) Signed-off-by: Tyler Michael Smith <tysmith@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> * Adding "LoRA Test %N" to AMD production tests (#18929) Signed-off-by: Yida Wu <yidawu@alumni.cmu.edu> * [CPU][CI] Re-enable the CPU CI tests (#19046) Signed-off-by: jiang.li <jiang1.li@intel.com> * [ROCm][Build] Clean up the ROCm build (#19040) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> * [V1] Support DP with Ray (#18779) * Add tarsier model support (#18985) Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> * [bugfix] small fix logic issue (#18999) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * Reduce logs in CLI scripts and plugin loader (#18970) Signed-off-by: mgoin <mgoin64@gmail.com> * [Bugfix] Use cmake 3.26.1 instead of 3.26 to avoid build failure (#19019) Signed-off-by: Lu Fang <lufang@fb.com> * [v1][KVCacheManager] Rename BlockHashType to BlockHash (#19015) Signed-off-by: Chen Zhang <zhangch99@outlook.com> * Update docker docs with ARM CUDA cross-compile (#19037) Signed-off-by: mgoin <michael@neuralmagic.com> * [Doc] Add InternVL LoRA support (#19055) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> * [Misc] Update `WeightsMapper` for qwen2-vl/qwen2.5-vl (#19054) Signed-off-by: Isotr0py <2037008807@qq.com> * [Doc] Update V1 user guide for embedding and enc-dec models (#19060) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * [doc] clarify windows support (#19088) Signed-off-by: youkaichao <youkaichao@gmail.com> * [CI/Build] Remove V0 LoRA test (#19066) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> * Fix underscores in dict keys passed via CLI (#19030) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> * [Bugfix] disable processor cache (#19068) Signed-off-by: raushan <raushan@huggingface.co> * [Doc] Improve the Pull Request template with key components (#19086) Signed-off-by: Lu Fang <lufang@fb.com> * [Misc] Add missing `_Backend` enums (#19081) Signed-off-by: nicklucche <nlucches@redhat.com> * [Misc] fix: add miss best_of param validation (#18555) Signed-off-by: googs1025 <googs1025@gmail.com> * [Misc] Add SPDX-FileCopyrightText (#19100) Signed-off-by: simon-mo <simon.mo@hey.com> * [Doc] Readme standardization (#18695) Co-authored-by: Soren Dreano <soren@numind.ai> * [doc] update docker version (#19074) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Kernel] DeepEP dispatch-combine kernel integration (#18434) Signed-off-by: Varun <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> * [V1] Support cross-layer KV sharing (#18212) Signed-off-by: Yong Hoon Shin <yhshin@meta.com> * [Perf] Tune `scaled_fp8_quant` by increasing vectorization (#18844) Signed-off-by: mgoin <mgoin64@gmail.com> * Fix interaction between `Optional` and `Annotated` in CLI typing (#19093) Signed-off-by: Harry Mellor 
<19981378+hmellor@users.noreply.github.com> Co-authored-by: Yikun Jiang <yikun@apache.org> * [v1] Re-init input batch for multiple kv cache groups (#18654) Signed-off-by: Chen Zhang <zhangch99@outlook.com> * [V1][Spec Decode][Ngram] 1.35x gain -> 1.95x gain on InstructCoder with prompt fix (#18971) * [Bugfix] get_num_blocks_to_allocate with null_block (#19031) Signed-off-by: Chen Zhang <zhangch99@outlook.com> * [Bugfix]: Fix the incompatibility issue with tool_choice 'required' when Thinking is enabled (#19075) Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> * [Bugfix][P/D] Fix Prefix Cache Bug (#18411) Signed-off-by: nicklucche <nlucches@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> * [Bugfix] Max concurrency estimation and check_enough_kv_cache_memory for models with sliding window layers (#19029) Signed-off-by: Chen Zhang <zhangch99@outlook.com> * feat: add data parallel rank to KVEventBatch (#18925) * [Misc] Fix path and python alias errors in disagg_prefill exmaples (#18919) * [Docs] Add developer doc about CI failures (#18782) Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> * [CPU] V1 support for the CPU backend (#16441) * [Core] Cast multimodal input in hf processor (#18862) Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> * [KERNEL] Sampler. CUDA kernel for applying repetition penalty (#18437) * [Cleanup][v1]:remote guided-decoding-backend for example (#19059) Signed-off-by: calvin chen <120380290@qq.com> * [NVIDIA] Add Cutlass MLA backend (#17625) * [Bugfix] Fix FA3 full cuda graph correctness (#19106) Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> * Fix #19130 (#19132) Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> * [TPU] Skip hanging tests (#19115) Signed-off-by: Siyuan Liu <lsiyuan@google.com> * Fix ValueError: Missing value for tag key(s): model_name,engine. 
(#19113) Signed-off-by: Seiji Eicher <seiji@anyscale.com> * [Misc] Add packages for benchmark as extra dependency (#19089) Signed-off-by: Isotr0py <2037008807@qq.com> * Improve the output precision of embedding models (#19092) * [CI/Build][Bugfix] Ensure compatibility with transformers 4.52 (#18678) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * Add DeepSeek-R1-0528 function call chat template (#18874) Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> * Sm100 blockwise fp8 swap ab (#18564) * [Doc] Update V1 Guide for embedding models (#19141) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * Allow AsyncLLMEngine.generate to target a specific DP rank (#19102) Signed-off-by: Jon Swenson <jmswen@gmail.com> * [Bugfix][EP+DP] Fix internode check (#19112) Signed-off-by: Tyler Michael Smith <tysmith@redhat.com> * [Perf] Tunings for SM100 FP8 CUTLASS kernel (#18778) Signed-off-by: mgoin <mgoin64@gmail.com> * [TPU] Update dynamo dump file name in compilation test (#19108) Signed-off-by: Siyuan Liu <lsiyuan@google.com> * [Bugfix] fix v1 cpu worker fails on macOS (#19121) * [Kernel] Integrate batched/masked deepgemm kernel (#19111) Signed-off-by: Varun <vsundarr@redhat.com> Co-authored-by: Varun <vsundarr@redhat.com> * [Misc] refactor: simplify EngineCoreClient.make_async_mp_client in AsyncLLM (#18817) Signed-off-by: googs1025 <googs1025@gmail.com> * [P/D] Heterogeneous TP (#18833) Signed-off-by: nicklucche <nlucches@redhat.com> * [doc] small fix (#19167) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Bugfix][Nixl] Fix full prefix cache hit bug (#18632) Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com> * [Bugfix] Fix port handling in make_zmq_path (#19117) * [Torch Nightly]add missing dependency (#18770) Signed-off-by: Yang Wang <elainewy@meta.com> * Handle non-serializable objects when dumping benchmark results (#19114) * [BugFix][Minor] Fix full cuda graph bug when max_num_seqs < 512 (#19171) Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> * [Bugfix]: Fix the incompatibility issue with stream when Thinking is disabled (#19135) Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> * [Build] Annotate wheel and container path for release workflow (#19162) Signed-off-by: simon-mo <simon.mo@hey.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * [Misc] Remove unnecessary fallback to prefill-decode attention (#19138) Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> * [Misc] Do not override NCCL_CUMEM_ENABLE if set explicitly (#19105) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> * [Frontend] improve vllm run-batch --help display (#19187) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Bugfix] properly catch PIL-related errors for vision models when incorrect data urls are provided (#19202) Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com> * [mistral_common] Add v11 tokenizer (#19193) Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> * Add H20-3e fused MoE kernel tuning configs for DeepSeek-R1/V3 (#19205) * [Hardware][NVIDIA] FP4 MoE kernel optimization (#19110) Signed-off-by: Chiyue Wei <chiyuew@nvidia.com> Co-authored-by: Chiyue Wei <chiyuew@nvidia.com> * [MISC][Bugfix] Use less CPU when message queue has been empty for some time (#16226) Signed-off-by: Povilas Kanapickas 
<povilas@radix.lt> * [P/D][NixlConnector] Enable FlashInfer backend (#19090) * [Quantization] Skip Fp4 Test for `compressed-tensors` (#19217) * [V1] Use FlashInfer by default on Blackwell GPUs (#19118) * [Model] NemotronH support (#18863) Signed-off-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com> Co-authored-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com> * Fix AOPerModuleConfig name changes (#18869) Signed-off-by: Jerry Zhang <jerryzh168@gmail.com> * [Bugfix] Fix EAGLE vocab embedding construction for Llama 70B (#19033) Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> * [v1] Hybrid Memory Allocator (#17996) Signed-off-by: Chen Zhang <zhangch99@outlook.com> * [TPU] update torch_xla pin (#19231) Signed-off-by: Chengji Yao <chengjiyao@google.com> * Support allowed_token_ids in ChatCompletionRequest (#19143) Signed-off-by: Xu Song <xusong.vip@gmail.com> * [Chore] update CODEOWNERS (#19247) Signed-off-by: Aaron Pham <contact@aarnphm.xyz> * [v1][P/D] Fix a edge case in kv cache schedule (#19182) Co-authored-by: jinghui <jinghui@fb.com> * [TPU] fix kv cache dtype in model runner (#19244) Signed-off-by: Chengji Yao <chengjiyao@google.com> * [Quantization] Bump compressed-tensors version; update NVFP4A16 test model (#19224) Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> * [Docs] Improve V1 KVConnector interface documentation (#19172) Signed-off-by: Nick Hill <nhill@redhat.com> * Fix CompilationConfig repr (#19091) Signed-off-by: rzou <zou3519@gmail.com> * Unit Test for run_dp_sharded_vision_model (#19103) Signed-off-by: Siqi Yan <siqi@meta.com> Co-authored-by: Siqi Yan <siqi@meta.com> * [Model] Optimize nemotron_h implementation (#19249) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> * [Core] Raise when non-multi-instance DP clients target a DP rank (#19227) Signed-off-by: Jon Swenson <jmswen@gmail.com> * improve logits bias (#19041) * Fixed ppc build when it runs on non-RHEL based linux distros (#18422) Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com> Signed-off-by: Md. Shafi Hussain <Md.Shafi.Hussain@ibm.com> Signed-off-by: npanpaliya <nishidha.panpaliya@partner.ibm.com> Co-authored-by: Md. Shafi Hussain <Md.Shafi.Hussain@ibm.com> * [BugFix] Fix MultiConnector test after HMA changes (#19291) Signed-off-by: Nick Hill <nhill@redhat.com> * [Bugfix][Core] Update cancellation logic in `generate()` to handle Generator exits (#19225) Co-authored-by: Adolfo Victoria <adovi@meta.com> * [Core] Fix abrupt request abort (#18485) Signed-off-by: nicklucche <nlucches@redhat.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com> * [BugFix] Fix tpu_model_runner block_id concatenation (#19228) Signed-off-by: Nick Hill <nhill@redhat.com> * [Misc][Tools][Benchmark] Fix and improve auto tune script (#19163) Signed-off-by: Chenyaaang <chenyangli@google.com> * [Build][ROCm] Update Dockerfile.rocm (#19296) Signed-off-by: Alexei V. 
Ivanov <alexei.ivanov@amd.com> * [Easy][Test] Simplify test_function_tool_use with multiple parametrizes (#19269) Signed-off-by: Lu Fang <lufang@fb.com> * [Kernel] Integrate CUTLASS MoE kernel with PPLX (#18762) Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> * [TPU][Test] Add script to run benchmark on TPU for buildkite (#19039) Signed-off-by: Qiliang Cui <derrhein@gmail.com> * [CI][PowerPC] Use a more appropriate way to select testcase in tests/models/language/pooling/test_embedding.py (#19253) Signed-off-by: Aaruni Aggarwal <aaruniagg@gmail.com> * Add FlexAttention to V1 (#16078) Signed-off-by: drisspg <drisspguessous@gmail.com> * [Misc] refactor context extension (#19246) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [CI/Build] Improve Llama GGUF test robustness (#19287) Signed-off-by: Isotr0py <2037008807@qq.com> * [Nit][Benchmark]Fix example in benchmark_serving_structured_output.py (#19311) Signed-off-by: Lifan Shen <lifans@meta.com> * [AMD] Update compatible packaging version (#19309) Signed-off-by: pramkuma <Pramendra.Kumar@amd.com> * [BugFix][V1] Fix memory profiling bug (#18974) Signed-off-by: luka <luka@neuralmagic.com> * [Bugfix]: Fix TypeError: 'float' object cannot be interpreted as an integer (#19283) Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> * [Bugfix] Re-enable use_cudagraph in vLLM v1 (#19299) Signed-off-by: Richard Zou <zou3519@gmail.com> * [Misc] Change tests/compile to use VLLM_V1 by default (#19302) Signed-off-by: rzou <zou3519@gmail.com> * Add H20-3e fused MoE kernel tuning configs for Qwen3-235B-A22B (#19315) Signed-off-by: Xu Wenqing <xuwq1993@qq.com> * [Hardware][POWER] Add IBM POWER11 Support to CPU Extension Detection (#19082) Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> * [Quantization] Add compressed-tensors NVFP4 support (#18312) * [Multi Modal] Add an env var for message queue max chunk bytes (#19242) Signed-off-by: yZhen <yZhen@fb.com> Co-authored-by: yZhen <yZhen@fb.com> * [Bugfix] model_max_length should consider max_model_len in tokenizer_config (#19201) * [Deprecation] Remove `inputs` arg fallback in Engine classes (#18799) Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> * [Misc] Add documentation update reminder to PR template (#19289) Signed-off-by: Isotr0py <2037008807@qq.com> * [Frontend] Remove unreachable code from llm.py (#19288) Signed-off-by: KsuParkhamchuk <k.parkhamchuk@gmail.com> * [Misc] Cleanup compilation tests (#19343) Signed-off-by: rzou <zou3519@gmail.com> * [doc] improve ci doc (#19307) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Doc] Fix description in the Automatic Prefix Caching design doc (#19333) Signed-off-by: cr7258 <chengzw258@163.com> * [CI/Build] Fix LoRA test (#19350) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> * [Fix] Allow kernel compilation for CUDA capability 8.7 (#19328) Signed-off-by: Conroy Cheers <conroy@corncheese.org> * [CI] Introduce rules for llama auto-label (#19323) Signed-off-by: Lu Fang <lufang@fb.com> * [Docs] Fix a bullet list in usage/security.md (#19358) Signed-off-by: windsonsea <haifeng.yao@daocloud.io> * [full_graph] Fix query_start_loc padding (#19321) Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai> * [v1] Add fp32 support to v1 engine through flex attn (#19319) 
Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> * [Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298) Signed-off-by: Varun <vsundarr@redhat.com> Co-authored-by: Varun <vsundarr@redhat.com> * [Bugfix][Core] Prevent token lengths exceeding `max_model_len` in V0 (#19348) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> * [Quantization] Bump compressed-tensors version (#19295) Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * [Frontend] Make TIMEOUT_KEEP_ALIVE configurable through env var (#18472) Signed-off-by: liusiqian <liusiqian@tal.com> * [TPU]Fix KV cache sharing tests (#19371) * [HOT-FIX] Add `kv_sharing_target_layer_name` argument to cutlass_mla backend (#19374) Signed-off-by: Pavani Majety <pmajety@nvidia.com> * [Misc] Fix a config typo in disable_hybrid_kv_cache_manager configuration (#19383) Signed-off-by: Siyuan Liu <lsiyuan@google.com> * [V1] Reuse V0's memory_profiling util for gpu worker memory profiling (#19312) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> * [Bugfix] Fix benchmark_moe.py (#19016) Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn> * Use xla flag to improve the quantized model performance (#19303) Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com> * Fix docs/mkdocs/hooks/remove_announcement.py (#19382) * [Frontend] Add tqdm_leave_pbar to control progress bar visibility (#19357) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * [Core] Use tuple for kv cache group block ids (#19175) Signed-off-by: Nick Hill <nhill@redhat.com> * [Bugfix] Fix modelscope token passed in (#19389) Signed-off-by: wangli <wangli858794774@gmail.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> * [Core] Batch multi modal input using pinned memory (#19169) Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> * Add security warning to bug report template (#19365) Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * [Misc] refactor neuron_multimodal and profiling (#19397) Signed-off-by: reidliu41 <reid201711@gmail.com> Co-authored-by: reidliu41 <reid201711@gmail.com> * Add clear documentation around the impact of debugging flag (#19369) Signed-off-by: Anna Pendleton <pendleton@google.com> * Automatically bind CPU OMP Threads of a rank to CPU ids of a NUMA node. 
(#17930) Signed-off-by: Tsai, Louie <louie.tsai@intel.com> Co-authored-by: Li, Jiang <bigpyj64@gmail.com> * Revert "[v1] Add fp32 support to v1 engine through flex attn" (#19404) * [BugFix][FlashInfer] Fix attention backend interface mismatch with unexpected keyword `use_irope` (#19134) Signed-off-by: Yunqiu Guo <guorachel@meta.com> * [BugFix][CPU] Fix CPU CI by ignore collecting test_pixtral (#19411) Signed-off-by: jiang.li <jiang1.li@intel.com> * Simplify ep kernels installation (#19412) Signed-off-by: youkaichao <youkaichao@gmail.com> * [Misc] Slight improvement of the BNB (#19418) Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> * fix Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> * update config Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> * add Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> --------- Signed-off-by: Hongxia Yang <hongxia.yang@amd.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: rzou <zou3519@gmail.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: Elaine Zhao <elaineyz@amazon.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Signed-off-by: Brent Salisbury <bsalisbu@redhat.com> Signed-off-by: Satyajith Chilappagari <satchill@amazon.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: reidliu41 <reid201711@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> Signed-off-by: luka <luka@neuralmagic.com> Signed-off-by: Chenyaaang <chenyangli@google.com> Signed-off-by: Duyi-Wang <duyi.wang@intel.com> Signed-off-by: Zerohertz <ohg3417@gmail.com> Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: googs1025 <googs1025@gmail.com> Signed-off-by: nicklucche <nlucches@redhat.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Will Eaton <weaton@redhat.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Carol Zheng <cazheng@google.com> Signed-off-by: wenhuach21 <wenhua.cheng@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: iLeGend <824040212@qq.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: rabi <ramishra@redhat.com> Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com> Signed-off-by: huangyuxiang03 <huangyx0321@gmail.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: isotr0py <2037008807@qq.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Fred Reiss <frreiss@us.ibm.com> Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Tyler Michael Smith <tysmith@redhat.com> Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com> Signed-off-by: François Paupier <francois.paupier@gmail.com> Signed-off-by: 
calvin chen <120380290@qq.com> Signed-off-by: Siyuan Liu <lsiyuan@google.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Yida Wu <yidawu@alumni.cmu.edu> Signed-off-by: jiang.li <jiang1.li@intel.com> Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com> Signed-off-by: Lu Fang <lufang@fb.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: raushan <raushan@huggingface.co> Signed-off-by: simon-mo <simon.mo@hey.com> Signed-off-by: Varun <vsundarr@redhat.com> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Seiji Eicher <seiji@anyscale.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Jon Swenson <jmswen@gmail.com> Signed-off-by: Yang Wang <elainewy@meta.com> Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com> Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Signed-off-by: Chiyue Wei <chiyuew@nvidia.com> Signed-off-by: Povilas Kanapickas <povilas@radix.lt> Signed-off-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com> Signed-off-by: Jerry Zhang <jerryzh168@gmail.com> Signed-off-by: Xu Song <xusong.vip@gmail.com> Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Siqi Yan <siqi@meta.com> Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com> Signed-off-by: Md. Shafi Hussain <Md.Shafi.Hussain@ibm.com> Signed-off-by: npanpaliya <nishidha.panpaliya@partner.ibm.com> Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Qiliang Cui <derrhein@gmail.com> Signed-off-by: Aaruni Aggarwal <aaruniagg@gmail.com> Signed-off-by: drisspg <drisspguessous@gmail.com> Signed-off-by: Lifan Shen <lifans@meta.com> Signed-off-by: pramkuma <Pramendra.Kumar@amd.com> Signed-off-by: Richard Zou <zou3519@gmail.com> Signed-off-by: Xu Wenqing <xuwq1993@qq.com> Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Signed-off-by: yZhen <yZhen@fb.com> Signed-off-by: KsuParkhamchuk <k.parkhamchuk@gmail.com> Signed-off-by: cr7258 <chengzw258@163.com> Signed-off-by: Conroy Cheers <conroy@corncheese.org> Signed-off-by: Yinghai Lu <yinghai@thinkingmachines.ai> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: liusiqian <liusiqian@tal.com> Signed-off-by: Pavani Majety <pmajety@nvidia.com> Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn> Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com> Signed-off-by: wangli <wangli858794774@gmail.com> Signed-off-by: Anna Pendleton <pendleton@google.com> Signed-off-by: Tsai, Louie <louie.tsai@intel.com> Signed-off-by: Yunqiu Guo <guorachel@meta.com> Signed-off-by: Amog Kamsetty <amogkamsetty@gmail.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Richard Zou <zou3519@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: aws-elaineyz <elaineyz@amazon.com> Co-authored-by: Yikun Jiang <yikunkero@gmail.com> Co-authored-by: Brent Salisbury <bsalisbu@redhat.com> Co-authored-by: Satyajith Chilappagari <satchill@amazon.com> 
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com> Co-authored-by: reidliu41 <reid201711@gmail.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com> Co-authored-by: Duyi-Wang <duyi.wang@intel.com> Co-authored-by: Hyogeun Oh (오효근) <ohg3417@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: CYJiang <86391540+googs1025@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Will Eaton <weaton@redhat.com> Co-authored-by: Will Eaton <wseaton@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Carol Zheng <cazheng@google.com> Co-authored-by: Wenhua Cheng <wenhua.cheng@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: iLeGend <youzhi.jin@intel.com> Co-authored-by: H <linhaibin.eric@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Rabi Mishra <ramishra@redhat.com> Co-authored-by: Always-Naive <97138029+Always-Naive@users.noreply.github.com> Co-authored-by: Daniele <36171005+dtrifiro@users.noreply.github.com> Co-authored-by: Shawn Huang <57223022+huangyuxiang03@users.noreply.github.com> Co-authored-by: huangyuxiang03 <huangyx0321@gmail.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Yu Guo <82124926+yuguo68@users.noreply.github.com> Co-authored-by: Pooya Davoodi <pooya.davoodi@parasail.io> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com> Co-authored-by: Rohith Nallamaddi <nalrohit@amazon.com> Co-authored-by: FeliciaLuo <luof@amazon.com> Co-authored-by: Fred Reiss <frreiss@us.ibm.com> Co-authored-by: Charlie Fu <charlifu@amd.com> Co-authored-by: ptarasiewiczNV <104908264+ptarasiewiczNV@users.noreply.github.com> Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Benjamin Chislett <benjamin.chislett@centml.ai> Co-authored-by: zhrrr <43847754+izhuhaoran@users.noreply.github.com> Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com> Co-authored-by: Tyler Michael Smith <tysmith@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: jennyyyyzhen <47012288+jennyyyyzhen@users.noreply.github.com> Co-authored-by: yZhen <yZhen@fb.com> Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com> Co-authored-by: Frαnçois <francois.paupier@gmail.com> Co-authored-by: Calvin Chen <45745657+calvin0327@users.noreply.github.com> Co-authored-by: Siyuan Liu <lsiyuan@google.com> Co-authored-by: Hossein Sarshar <hossein.sarshar@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Concurrensee <yidawu@alumni.cmu.edu> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Rui Qiao 
<161574667+ruisearch42@users.noreply.github.com> Co-authored-by: 汪志鹏 <wangzhipeng628@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: SorenDreano <71752785+SorenDreano@users.noreply.github.com> Co-authored-by: Soren Dreano <soren@numind.ai> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Yikun Jiang <yikun@apache.org> Co-authored-by: Yan Ru Pei <yanrpei@gmail.com> Co-authored-by: Jiaxin Shan <seedjeffwan@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Kaixi Hou <kaixih@nvidia.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Lain <fusiyuan2000@hotmail.com> Co-authored-by: jmswen <jmswen@users.noreply.github.com> Co-authored-by: Kebe <mail@kebe7jun.com> Co-authored-by: Yang Wang <elainewy@meta.com> Co-authored-by: Huy Do <huydhn@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Guillaume Calmettes <gcalmettes@scaleway.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Chiyue Wei <92623189+dubcyfor3@users.noreply.github.com> Co-authored-by: Chiyue Wei <chiyuew@nvidia.com> Co-authored-by: Povilas Kanapickas <povilas@radix.lt> Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com> Co-authored-by: Luis Vega <vegaluisjose@users.noreply.github.com> Co-authored-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com> Co-authored-by: Jerry Zhang <jerryzh168@gmail.com> Co-authored-by: Xu Song <xusong.vip@gmail.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Jinghui Zhang <jinghuizhang0804@gmail.com> Co-authored-by: jinghui <jinghui@fb.com> Co-authored-by: Siqi Yan <ysq0807@hotmail.com> Co-authored-by: Siqi Yan <siqi@meta.com> Co-authored-by: Nishidha <nishidha.panpaliya@partner.ibm.com> Co-authored-by: Md. 
Shafi Hussain <Md.Shafi.Hussain@ibm.com> Co-authored-by: Adolfo Victoria <adolfokarim@gmail.com> Co-authored-by: Adolfo Victoria <adovi@meta.com> Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: QiliangCui <derrhein@gmail.com> Co-authored-by: Aaruni Aggarwal <47731267+AaruniAggarwal@users.noreply.github.com> Co-authored-by: Driss Guessous <32754868+drisspg@users.noreply.github.com> Co-authored-by: Lifans <draftbks@gmail.com> Co-authored-by: pramenku <7664080+pramenku@users.noreply.github.com> Co-authored-by: Akash kaothalkar <61960177+Akashcodes732@users.noreply.github.com> Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com> Co-authored-by: Kseniya Parkhamchuk <43078183+KsuParkhamchuk@users.noreply.github.com> Co-authored-by: Se7en <chengzw258@163.com> Co-authored-by: Conroy Cheers <conroy@corncheese.org> Co-authored-by: Yinghai Lu <yinghai@thinkingmachines.ai> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: liusiqian-tal <141730978+liusiqian-tal@users.noreply.github.com> Co-authored-by: Pavani Majety <pmajety@nvidia.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn> Co-authored-by: XiongfeiWei <isaacwxf23@gmail.com> Co-authored-by: Li Wang <wangli858794774@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Anna Pendleton <pendleton@google.com> Co-authored-by: Louie Tsai <louie.tsai@intel.com> Co-authored-by: Li, Jiang <bigpyj64@gmail.com> Co-authored-by: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Co-authored-by: Isotr0py <2037008807@qq.com> --- csrc/cpu/torch_bindings.cpp | 27 ++ .../cutlass_w8a8/scaled_mm_entry.cu | 23 ++ .../installation/ai_accelerator.md | 117 ++++++ .../installation/ai_accelerator/neuron.inc.md | 154 ++++++++ .../installation/intel_gaudi.md | 25 +- tests/compile/piecewise/test_simple.py | 8 + tests/compile/piecewise/test_toy_llama.py | 8 + .../openai/correctness/test_mteb.py | 43 +++ tests/kernels/moe/deepep_utils.py | 191 ++++++++++ .../models/language/generation/test_common.py | 3 + .../generation/test_granitemoehybrid.py | 42 +++ .../models/language/generation/test_hybrid.py | 3 + tests/models/registry.py | 8 +- tests/pplx_utils.py | 123 ++++++ tests/test_config.py | 12 + tests/v1/engine/test_engine_core_client.py | 1 + tests/v1/sample/test_topk_topp_sampler.py | 1 + tests/v1/sample/utils.py | 2 + vllm/benchmarks/datasets.py | 93 +++++ vllm/benchmarks/serve.py | 48 +++ .../kv_transfer/kv_connector/utils.py | 1 + vllm/engine/arg_utils.py | 23 +- vllm/entrypoints/chat_utils.py | 3 +- vllm/entrypoints/cli/serve.py | 1 + vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 10 + vllm/entrypoints/openai/serving_engine.py | 11 + .../layers/fused_moe/fused_moe.py | 14 + vllm/model_executor/layers/fused_moe/layer.py | 218 +++++++++++ .../compressed_tensors_moe.py | 35 ++ vllm/model_executor/models/bert.py | 13 +- vllm/platforms/cpu.py | 19 + vllm/platforms/tpu.py | 16 + vllm/v1/attention/backends/flash_attn.py | 2 + vllm/v1/core/kv_cache_manager.py | 6 + vllm/v1/core/sched/scheduler.py | 30 ++ vllm/v1/engine/core.py | 1 + vllm/v1/request.py | 6 +- vllm/v1/utils.py | 353 ++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 24 ++ vllm/v1/worker/tpu_model_runner.py | 11 + 41 files changed, 1712 insertions(+), 19 deletions(-) create mode 100644 
docs/getting_started/installation/ai_accelerator.md create mode 100644 docs/getting_started/installation/ai_accelerator/neuron.inc.md create mode 100644 tests/entrypoints/openai/correctness/test_mteb.py create mode 100644 tests/kernels/moe/deepep_utils.py create mode 100644 tests/models/language/generation/test_granitemoehybrid.py create mode 100644 tests/pplx_utils.py diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index ebfc81f85836..12928585923a 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -191,6 +191,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); +#elif defined(__powerpc64__) + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + // Compute int8 quantized tensor and scaling factor ops.def( "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 31b60488dfb7..c79a085a1b09 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -294,6 +294,29 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, version_num, ". Required capability: 90 or 100"); } +void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, const int64_t n, + const int64_t k) { + // This function currently gets compiled only if we have a valid cutlass moe + // mm to run it for. + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, + problem_sizes2, expert_num_tokens, + num_local_experts, padded_m, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " + "for CUDA device capability: ", + version_num, ". 
Required capability: 90"); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/docs/getting_started/installation/ai_accelerator.md b/docs/getting_started/installation/ai_accelerator.md new file mode 100644 index 000000000000..a4f136a172fe --- /dev/null +++ b/docs/getting_started/installation/ai_accelerator.md @@ -0,0 +1,117 @@ +# Other AI accelerators + +vLLM is a Python library that supports the following AI accelerators. Select your AI accelerator type to see vendor specific instructions: + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:installation" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:installation" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:installation" + +## Requirements + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:requirements" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:requirements" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:requirements" + +## Configure a new environment + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:configure-a-new-environment" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:configure-a-new-environment" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:configure-a-new-environment" + +## Set up using Python + +### Pre-built wheels + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:pre-built-wheels" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:pre-built-wheels" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:pre-built-wheels" + +### Build wheel from source + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:build-wheel-from-source" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:build-wheel-from-source" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:build-wheel-from-source" + +## Set up using Docker + +### Pre-built images + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:pre-built-images" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:pre-built-images" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:pre-built-images" + +### Build image from source + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:build-image-from-source" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:build-image-from-source" + +=== "AWS Neuron" + + --8<-- "docs/getting_started/installation/ai_accelerator/neuron.inc.md:build-image-from-source" + +## Extra information + +=== "Google TPU" + + --8<-- "docs/getting_started/installation/ai_accelerator/tpu.inc.md:extra-information" + +=== "Intel Gaudi" + + --8<-- "docs/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md:extra-information" + +=== "AWS Neuron" + + --8<-- 
"docs/getting_started/installation/ai_accelerator/neuron.inc.md:extra-information" diff --git a/docs/getting_started/installation/ai_accelerator/neuron.inc.md b/docs/getting_started/installation/ai_accelerator/neuron.inc.md new file mode 100644 index 000000000000..86c12472fb36 --- /dev/null +++ b/docs/getting_started/installation/ai_accelerator/neuron.inc.md @@ -0,0 +1,154 @@ +# --8<-- [start:installation] + +[AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/) is the software development kit (SDK) used to run deep learning and + generative AI workloads on AWS Inferentia and AWS Trainium powered Amazon EC2 instances and UltraServers (Inf1, Inf2, Trn1, Trn2, + and Trn2 UltraServer). Both Trainium and Inferentia are powered by fully-independent heterogeneous compute-units called NeuronCores. + This tab describes how to set up your environment to run vLLM on Neuron. + +!!! warning + There are no pre-built wheels or images for this device, so you must build vLLM from source. + +# --8<-- [end:installation] +# --8<-- [start:requirements] + +- OS: Linux +- Python: 3.9 or newer +- Pytorch 2.5/2.6 +- Accelerator: NeuronCore-v2 (in trn1/inf2 chips) or NeuronCore-v3 (in trn2 chips) +- AWS Neuron SDK 2.23 + +## Configure a new environment + +### Launch a Trn1/Trn2/Inf2 instance and verify Neuron dependencies + +The easiest way to launch a Trainium or Inferentia instance with pre-installed Neuron dependencies is to follow this +[quick start guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/neuron-setup/multiframework/multi-framework-ubuntu22-neuron-dlami.html#setup-ubuntu22-multi-framework-dlami) using the Neuron Deep Learning AMI (Amazon machine image). + +- After launching the instance, follow the instructions in [Connect to your instance](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AccessingInstancesLinux.html) to connect to the instance +- Once inside your instance, activate the pre-installed virtual environment for inference by running +```console +source /opt/aws_neuronx_venv_pytorch_2_6_nxd_inference/bin/activate +``` + +Refer to the [NxD Inference Setup Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/nxdi-setup.html) +for alternative setup instructions including using Docker and manually installing dependencies. + +!!! note + NxD Inference is the default recommended backend to run inference on Neuron. If you are looking to use the legacy [transformers-neuronx](https://github.com/aws-neuron/transformers-neuronx) + library, refer to [Transformers NeuronX Setup](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/setup/index.html). + +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] + +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] + +Currently, there are no pre-built Neuron wheels. + +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] + +#### Install vLLM from source + +Install vllm as follows: + +```console +git clone https://github.com/vllm-project/vllm.git +cd vllm +pip install -U -r requirements/neuron.txt +VLLM_TARGET_DEVICE="neuron" pip install -e . 
+``` + +AWS Neuron maintains a [Github fork of vLLM](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2) at + [https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2), which contains several features in addition to what's + available on vLLM V0. Please utilize the AWS Fork for the following features: + +- Llama-3.2 multi-modal support +- Multi-node distributed inference + +Refer to [vLLM User Guide for NxD Inference](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/vllm-user-guide.html) + for more details and usage examples. + +To install the AWS Neuron fork, run the following: + +```console +git clone -b neuron-2.23-vllm-v0.7.2 https://github.com/aws-neuron/upstreaming-to-vllm.git +cd upstreaming-to-vllm +pip install -r requirements/neuron.txt +VLLM_TARGET_DEVICE="neuron" pip install -e . +``` + +Note that the AWS Neuron fork is only intended to support Neuron hardware; compatibility with other hardwares is not tested. + +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] + +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] + +Currently, there are no pre-built Neuron images. + +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] + +See [deployment-docker-build-image-from-source][deployment-docker-build-image-from-source] for instructions on building the Docker image. + +Make sure to use <gh-file:docker/Dockerfile.neuron> in place of the default Dockerfile. + +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] + +[](){ #feature-support-through-nxd-inference-backend } +### Feature support through NxD Inference backend + +The current vLLM and Neuron integration relies on either the `neuronx-distributed-inference` (preferred) or `transformers-neuronx` backend + to perform most of the heavy lifting which includes PyTorch model initialization, compilation, and runtime execution. Therefore, most + [features supported on Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html) are also available via the vLLM integration. + +To configure NxD Inference features through the vLLM entrypoint, use the `override_neuron_config` setting. Provide the configs you want to override +as a dictionary (or JSON object when starting vLLM from the CLI). For example, to disable auto bucketing, include +```console +override_neuron_config={ + "enable_bucketing":False, +} +``` +or when launching vLLM from the CLI, pass +```console +--override-neuron-config "{\"enable_bucketing\":false}" +``` + +Alternatively, users can directly call the NxDI library to trace and compile your model, then load the pre-compiled artifacts +(via `NEURON_COMPILED_ARTIFACTS` environment variable) in vLLM to run inference workloads. + +### Known limitations + +- EAGLE speculative decoding: NxD Inference requires the EAGLE draft checkpoint to include the LM head weights from the target model. Refer to this + [guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility) + for how to convert pretrained EAGLE model checkpoints to be compatible for NxDI. +- Quantization: the native quantization flow in vLLM is not well supported on NxD Inference. 
It is recommended to follow this + [Neuron quantization guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/custom-quantization.html) + to quantize and compile your model using NxD Inference, and then load the compiled artifacts into vLLM. +- Multi-LoRA serving: NxD Inference only supports loading of LoRA adapters at server startup. Dynamic loading of LoRA adapters at + runtime is not currently supported. Refer to [multi-lora example](https://github.com/aws-neuron/upstreaming-to-vllm/blob/neuron-2.23-vllm-v0.7.2/examples/offline_inference/neuron_multi_lora.py) +- Multi-modal support: multi-modal support is only available through the AWS Neuron fork. This feature has not been upstreamed + to vLLM main because NxD Inference currently relies on certain adaptations to the core vLLM logic to support this feature. +- Multi-node support: distributed inference across multiple Trainium/Inferentia instances is only supported on the AWS Neuron fork. Refer + to this [multi-node example](https://github.com/aws-neuron/upstreaming-to-vllm/tree/neuron-2.23-vllm-v0.7.2/examples/neuron/multi_node) + to run. Note that tensor parallelism (distributed inference across NeuronCores) is available in vLLM main. +- Known edge case bug in speculative decoding: An edge case failure may occur in speculative decoding when sequence length approaches + max model length (e.g. when requesting max tokens up to the max model length and ignoring eos). In this scenario, vLLM may attempt + to allocate an additional block to ensure there is enough memory for number of lookahead slots, but since we do not have good support + for paged attention, there isn't another Neuron block for vLLM to allocate. A workaround fix (to terminate 1 iteration early) is + implemented in the AWS Neuron fork but is not upstreamed to vLLM main as it modifies core vLLM logic. + + +### Environment variables +- `NEURON_COMPILED_ARTIFACTS`: set this environment variable to point to your pre-compiled model artifacts directory to avoid + compilation time upon server initialization. If this variable is not set, the Neuron module will perform compilation and save the + artifacts under `neuron-compiled-artifacts/{unique_hash}/` sub-directory in the model path. If this environment variable is set, + but the directory does not exist, or the contents are invalid, Neuron will also fallback to a new compilation and store the artifacts + under this specified path. +- `NEURON_CONTEXT_LENGTH_BUCKETS`: Bucket sizes for context encoding. (Only applicable to `transformers-neuronx` backend). +- `NEURON_TOKEN_GEN_BUCKETS`: Bucket sizes for token generation. (Only applicable to `transformers-neuronx` backend). + +# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md index 7a7a5a51c24c..7cf20cd171cf 100644 --- a/docs/getting_started/installation/intel_gaudi.md +++ b/docs/getting_started/installation/intel_gaudi.md @@ -5,7 +5,8 @@ This page provides instructions on running vLLM with Intel Gaudi devices. !!! warning There are no pre-built wheels or images for this device, so you must build vLLM from source. 
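As a usage sketch for the `override_neuron_config` setting and the `NEURON_COMPILED_ARTIFACTS` variable described in the Neuron guide above (an editorial illustration, not part of the patch): the model name, artifact path, and config values are placeholders, and a Neuron-capable vLLM build is assumed.

```python
import os

from vllm import LLM, SamplingParams

# Optionally point vLLM at pre-compiled artifacts so startup skips compilation
# (see the NEURON_COMPILED_ARTIFACTS notes above). The path is a placeholder.
os.environ["NEURON_COMPILED_ARTIFACTS"] = "/opt/models/llama/neuron-compiled-artifacts"

# `override_neuron_config` forwards backend options to NxD Inference, e.g.
# disabling auto bucketing as shown in the guide above.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    max_model_len=2048,
    override_neuron_config={"enable_bucketing": False},
)

outputs = llm.generate(["Hello from Neuron!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

The same override can be supplied on the command line via `--override-neuron-config`, as the guide shows.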
-## Requirements +# --8<-- [end:installation] +# --8<-- [start:requirements] - OS: Ubuntu 22.04 LTS - Python: 3.10 @@ -55,13 +56,16 @@ docker run \ vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest ``` -## Set up using Python +# --8<-- [end:requirements] +# --8<-- [start:set-up-using-python] -### Pre-built wheels +# --8<-- [end:set-up-using-python] +# --8<-- [start:pre-built-wheels] Currently, there are no pre-built Intel Gaudi wheels. -### Build wheel from source +# --8<-- [end:pre-built-wheels] +# --8<-- [start:build-wheel-from-source] To build and install vLLM from source, run: @@ -82,13 +86,16 @@ pip install -r requirements/hpu.txt python setup.py develop ``` -## Set up using Docker +# --8<-- [end:build-wheel-from-source] +# --8<-- [start:set-up-using-docker] -### Pre-built images +# --8<-- [end:set-up-using-docker] +# --8<-- [start:pre-built-images] Currently, there are no pre-built Intel Gaudi images. -### Build image from source +# --8<-- [end:pre-built-images] +# --8<-- [start:build-image-from-source] ```bash docker build -f docker/Dockerfile.hpu -t vllm-hpu-env . @@ -105,7 +112,8 @@ docker run \ !!! tip If you're observing the following error: `docker: Error response from daemon: Unknown runtime specified habana.`, please refer to "Install Using Containers" section of [Intel Gaudi Software Stack and Driver Installation](https://docs.habana.ai/en/v1.18.0/Installation_Guide/Bare_Metal_Fresh_OS.html). Make sure you have `habana-container-runtime` package installed and that `habana` container runtime is registered. -## Extra information +# --8<-- [end:build-image-from-source] +# --8<-- [start:extra-information] ### Supported features @@ -401,3 +409,4 @@ the below: higher batches. You can do that by adding `--enforce-eager` flag to server (for online serving), or by passing `enforce_eager=True` argument to LLM constructor (for offline inference). 
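The closing tip of the Gaudi guide above translates into the following minimal offline-inference sketch (editorial, not part of the patch); the model name is a placeholder and a Gaudi-enabled vLLM build is assumed.

```python
from vllm import LLM, SamplingParams

# enforce_eager=True is the offline counterpart of the --enforce-eager server
# flag mentioned above: the model runs eagerly instead of using captured graphs.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enforce_eager=True)

print(llm.generate(["Hello, Gaudi!"], SamplingParams(max_tokens=16))[0].outputs[0].text)
```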
+# --8<-- [end:extra-information] diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 06ac3527e1fb..eb589f52f9d9 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -114,3 +114,11 @@ def test_simple_piecewise_compile(use_inductor): output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + + +def test_simple_piecewise_compile_inductor(): + _test_simple_piecewise_compile(use_inductor=True) + + +def test_simple_piecewise_compile_no_inductor(): + _test_simple_piecewise_compile(use_inductor=False) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index b7ed8353b3ce..145b8c16d081 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -384,6 +384,14 @@ def test_toy_llama(use_inductor: bool): assert torch.allclose(outputs[0], outputs[i]) +def test_toy_llama_inductor(): + _test_toy_llama(use_inductor=True) + + +def test_toy_no_inductor(): + _test_toy_llama(use_inductor=False) + + @torch.inference_mode def benchmark(): from triton.testing import do_bench diff --git a/tests/entrypoints/openai/correctness/test_mteb.py b/tests/entrypoints/openai/correctness/test_mteb.py new file mode 100644 index 000000000000..437c48511352 --- /dev/null +++ b/tests/entrypoints/openai/correctness/test_mteb.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +import pytest + +from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS, + MTEB_EMBED_TOL, + OpenAIClientMtebEncoder, + run_mteb_embed_task, + run_mteb_embed_task_st) +from tests.utils import RemoteOpenAIServer + +os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" + +MODEL_NAME = "BAAI/bge-m3" +DTYPE = "float16" +MAIN_SCORE = 0.7873427091972599 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", "embed", "--dtype", DTYPE, "--enforce-eager", + "--max-model-len", "512" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def test_mteb(server): + client = server.get_client() + encoder = OpenAIClientMtebEncoder(MODEL_NAME, client) + vllm_main_score = run_mteb_embed_task(encoder, MTEB_EMBED_TASKS) + st_main_score = MAIN_SCORE or run_mteb_embed_task_st( + MODEL_NAME, MTEB_EMBED_TASKS) + + print("VLLM main score: ", vllm_main_score) + print("SentenceTransformer main score: ", st_main_score) + print("Difference: ", st_main_score - vllm_main_score) + + assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL) diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/deepep_utils.py new file mode 100644 index 000000000000..117f1babdf62 --- /dev/null +++ b/tests/kernels/moe/deepep_utils.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +DeepEP test utilities +""" +import dataclasses +import importlib +import traceback +from typing import Callable, Optional + +import torch +from torch.distributed import ProcessGroup +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +has_deep_ep = importlib.util.find_spec("deep_ep") is not None +if has_deep_ep: + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize 
import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + +## Parallel Processes Utils + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +## DeepEP specific utils + + +@dataclasses.dataclass +class DeepEPHTArgs: + num_local_experts: int + + +@dataclasses.dataclass +class DeepEPLLArgs: + max_tokens_per_rank: int + hidden_size: int + num_experts: int + use_fp8_dispatch: bool + + +def make_deepep_ht_a2a(pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + ht_args: DeepEPHTArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + + import deep_ep + + # high throughput a2a + num_nvl_bytes = 1024 * 1024 * 1024 # 1GB + num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 + buffer = deep_ep.Buffer(group=pg, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=low_latency_mode, + num_qps_per_rank=num_qps_per_rank) + return DeepEPHTPrepareAndFinalize(buffer=buffer, + world_size=pgi.world_size, + rank=pgi.rank, + dp_size=dp_size, + rank_expert_offset=pgi.rank * + ht_args.num_local_experts, + quant_dtype=q_dtype, + block_shape=block_shape) + + +def make_deepep_ll_a2a(pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ll_args: DeepEPLLArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + + import deep_ep + + # low-latency a2a + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, + pgi.world_size, deepep_ll_args.num_experts) + + buffer = deep_ep.Buffer(group=pg, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=deepep_ll_args.num_experts // + pgi.world_size) + + return DeepEPLLPrepareAndFinalize( + buffer=buffer, + world_size=pgi.world_size, + dp_size=dp_size, + max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, + quant_dtype=q_dtype, + block_shape=block_shape, + use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, + ) + + +def make_deepep_a2a(pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ht_args: Optional[DeepEPHTArgs], + 
deepep_ll_args: Optional[DeepEPLLArgs], + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + if deepep_ht_args is not None: + assert deepep_ll_args is None + return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, + block_shape) + + assert deepep_ll_args is not None + return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype, + block_shape) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 7d7a62eec118..6df9ddda33f1 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -83,6 +83,9 @@ pytest.param( "Qwen/Qwen3-8B", # qwen (text-only) ), + pytest.param( + "Qwen/Qwen3-8B", # qwen (text-only) + ), pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py new file mode 100644 index 000000000000..952449f28415 --- /dev/null +++ b/tests/models/language/generation/test_granitemoehybrid.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from ...utils import check_logprobs_close + +# Path of the checkpoints +MODELS = [ + "ibm-granite/granite-4.0-tiny-preview", +] + + +@pytest.mark.skip( + reason="Granite 4.0 is not yet available in huggingface transformers") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_model_equivalence_to_hf_greedy( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +): + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index ecaae3ec1fc4..2554785fe2e3 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -25,6 +25,9 @@ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", + # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as + # it is not yet available in huggingface transformers + # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. 
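To make the intent of the DeepEP test helpers above (`parallel_launch`, `DeepEPLLArgs`, `make_deepep_a2a`) concrete, here is a hypothetical worker sketch that is not part of the diff; it assumes a multi-GPU host with the optional `deep_ep` package installed and the repository root on `PYTHONPATH`, and all sizes are illustrative.

```python
import torch

from tests.kernels.moe.deepep_utils import (DeepEPLLArgs, ProcessGroupInfo,
                                            make_deepep_a2a, parallel_launch)


def _low_latency_worker(pgi: ProcessGroupInfo, dp_size: int) -> None:
    # parallel_launch has already initialized torch.distributed, so a group
    # spanning all ranks can be created and handed to the DeepEP buffers.
    pg = torch.distributed.new_group(list(range(pgi.world_size)))
    ll_args = DeepEPLLArgs(max_tokens_per_rank=128,
                           hidden_size=1024,
                           num_experts=8,
                           use_fp8_dispatch=False)
    a2a = make_deepep_a2a(pg, pgi, dp_size,
                          deepep_ht_args=None,
                          deepep_ll_args=ll_args)
    # ... wire `a2a` into a fused-MoE layer and run the actual test body ...
    assert a2a is not None


if __name__ == "__main__":
    # Two ranks on one node, dp_size=1 (illustrative values).
    parallel_launch(2, _low_latency_worker, 1)
```

Passing a `DeepEPHTArgs` instance (with `deepep_ll_args=None`) selects the high-throughput buffer path in the same way.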
diff --git a/tests/models/registry.py b/tests/models/registry.py index 48302f9d6648..e8dc137e0a34 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -240,7 +240,7 @@ def check_available_online( "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat", trust_remote_code=True), "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), - "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), + "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), # Blocksparse attention not supported in V1 yet "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", @@ -334,7 +334,8 @@ def check_available_online( "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501 "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501 - extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501 + extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501 + v0_only=True), "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501 "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501 extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501 @@ -362,7 +363,8 @@ def check_available_online( trust_remote_code=True), "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 - trust_remote_code=True), + trust_remote_code=True, + v0_only=True), "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501 max_model_len=10240), "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf", diff --git a/tests/pplx_utils.py b/tests/pplx_utils.py new file mode 100644 index 000000000000..2d5d5be80c3f --- /dev/null +++ b/tests/pplx_utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import os +import traceback +from typing import Callable + +import torch +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + 
torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + args, + nprocs=world_local_size, + join=True, + ) diff --git a/tests/test_config.py b/tests/test_config.py index 6ed7ef9e6a40..0806da378e30 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -31,6 +31,18 @@ class _TestConfigFields: c: str = "default" +def test_compile_config_repr_succeeds(): + # setup: VllmBackend mutates the config object + config = VllmConfig() + backend = VllmBackend(config) + backend.configure_post_pass() + + # test that repr(config) succeeds + val = repr(config) + assert 'VllmConfig' in val + assert 'inductor_passes' in val + + def test_get_field(): with pytest.raises(ValueError): get_field(_TestConfigFields, "a") diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 65f1da803fb2..8b43173020a2 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -28,6 +28,7 @@ SyncMPClient) from vllm.v1.engine.utils import CoreEngineProcManager from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import CoreEngineProcManager from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py index ccf38c31d39e..801530582cab 100644 --- a/tests/v1/sample/test_topk_topp_sampler.py +++ b/tests/v1/sample/test_topk_topp_sampler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest import torch +from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs from torch import Generator from vllm.platforms import current_platform diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index e33efb413d02..de33a97f457f 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -8,6 +8,8 @@ import regex as re import torch +import regex as re + from vllm import CompletionOutput from vllm.utils import make_tensor_with_pad from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index b3688d2340e4..fa9b1fef1507 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -719,6 
+719,99 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: # ----------------------------------------------------------------------------- +class CustomDataset(BenchmarkDataset): + """ + Implements the Custom dataset. Loads data from a JSONL file and generates + sample requests based on conversation turns. E.g., + ``` + {"prompt": "What is the capital of India?"} + {"prompt": "What is the capital of Iran?"} + {"prompt": "What is the capital of China?"} + ``` + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.load_data() + + def load_data(self) -> None: + if self.dataset_path is None: + raise ValueError("dataset_path must be provided for loading data.") + + # self.data will be a list of dictionaries + # e.g., [{"prompt": "What is the capital of India?"}, ...] + # This will be the standardized format which load_data() + # has to convert into depending on the filetype of dataset_path. + # sample() will assume this standardized format of self.data + self.data = [] + + # Load the JSONL file + if self.dataset_path.endswith(".jsonl"): + jsonl_data = pd.read_json(path_or_buf=self.dataset_path, + lines=True) + + # check if the JSONL file has a 'prompt' column + if "prompt" not in jsonl_data.columns: + raise ValueError("JSONL file must contain a 'prompt' column.") + + # Convert each row to a dictionary and append to self.data + # This will convert the DataFrame to a list of dictionaries + # where each dictionary corresponds to a row in the DataFrame. + # This is the standardized format we want for self.data + for _, row in jsonl_data.iterrows(): + self.data.append(row.to_dict()) + else: + raise NotImplementedError( + "Only JSONL format is supported for CustomDataset.") + + random.seed(self.random_seed) + random.shuffle(self.data) + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + lora_path: Optional[str] = None, + max_loras: Optional[int] = None, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + skip_chat_template: bool = False, + **kwargs, + ) -> list: + sampled_requests = [] + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["prompt"] + + # apply template + if not skip_chat_template: + prompt = tokenizer.apply_chat_template( + [{ + "role": "user", + "content": prompt + }], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + )) + self.maybe_oversample_requests(sampled_requests, num_requests) + + return sampled_requests + + +# ----------------------------------------------------------------------------- +# Custom Dataset Implementation +# ----------------------------------------------------------------------------- + + class CustomDataset(BenchmarkDataset): """ Implements the Custom dataset. Loads data from a JSONL file and generates diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 8b16fea9e3d3..25969b67d0f4 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -830,6 +830,54 @@ def add_cli_args(parser: argparse.ArgumentParser): "decoding (i.e. 
temperature==0.0).", ) + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) + + sampling_group = parser.add_argument_group("sampling parameters") + sampling_group.add_argument( + "--top-p", + type=float, + default=None, + help="Top-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--top-k", + type=int, + default=None, + help="Top-k sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--min-p", + type=float, + default=None, + help="Min-p sampling parameter. Only has effect on " + "openai-compatible backends.", + ) + sampling_group.add_argument( + "--temperature", + type=float, + default=None, + help="Temperature sampling parameter. Only has effect on " + "openai-compatible backends. If not specified, default to greedy " + "decoding (i.e. temperature==0.0).", + ) + parser.add_argument( '--tokenizer-mode', type=str, diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 5cbc8ca31752..78e95903593b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,6 +3,7 @@ """ KV cache helper for store. """ + import torch import vllm.envs as envs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cf94b6a64281..c1b859c776d8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -654,6 +654,27 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["tensor_parallel_size"]) parallel_group.add_argument("--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]) + parallel_group.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + help='Number of data parallel replicas ' + 'to run on this node.') + parallel_group.add_argument('--data-parallel-address', + '-dpa', + type=str, + help='Address of data parallel cluster ' + 'head-node.') + parallel_group.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + help='Port for data parallel RPC ' + 'communication.') + parallel_group.add_argument('--data-parallel-backend', + '-dpb', + type=str, + default='mp', + help='Backend for data parallel, either ' + '"mp" or "ray".') parallel_group.add_argument( '--data-parallel-rank', '-dpn', @@ -1354,7 +1375,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # Skip this check if we are running on a non-GPU platform, # or if the device capability is not available # (e.g. in a Ray actor without GPUs). 
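For readers skimming the `_is_v1_supported_oracle` hunk just below: the capability gate it touches can be summarized by the following minimal sketch, assuming only the `vllm.platforms` helpers that already appear in the diff (the standalone function name is illustrative, not part of the patch).

```python
from vllm.platforms import current_platform


def blocks_v1_for_old_gpus() -> bool:
    """Illustrative summary of the capability gate in this hunk."""
    # Skip the check on non-GPU platforms, or when the capability is
    # unavailable (e.g. in a Ray actor without GPUs).
    if not current_platform.is_cuda():
        return False
    capability = current_platform.get_device_capability()
    if capability is None:
        return False
    # Pre-Ampere GPUs (compute capability < 8.0) take the fallback path.
    return capability.major < 8
```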
- from vllm.platforms import current_platform + from vllm.platforms import CpuArchEnum, current_platform if (current_platform.is_cuda() and current_platform.get_device_capability() and current_platform.get_device_capability().major < 8): diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 012ea1d75f44..f9ff65e9a1fb 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -488,7 +488,8 @@ def resolve_chat_template_content_format( detected_format=detected_format, ) - return detected_format + return detected_format if given_format == "auto" else given_format + diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 9e24b31e1aae..5788613293bd 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -8,6 +8,7 @@ from typing import Optional import uvloop +import zmq import vllm import vllm.envs as envs diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 90d24819dfdf..177193f47157 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -913,7 +913,7 @@ async def chat_completion_stream_generator( total_tokens=num_prompt_tokens + completion_tokens, ) - data = chunk.model_dump_json(exclude_unset=True) + data = chunk.model_dump_json(exclude_none=True) yield f"data: {data}\n\n" # once the final token is handled, if stream_options.include_usage diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 305aa0374581..85934d1b66c1 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -23,6 +23,7 @@ CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, + PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo) from vllm.entrypoints.openai.serving_engine import ( @@ -315,6 +316,7 @@ async def completion_stream_generator( previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts num_prompt_tokens = [0] * num_prompts + num_cached_tokens = [0] * num_prompts accumulated_text = [""] * num_choices * num_prompts accumulated_tokens = [[] * num_choices * num_prompts] accumulated_logprobs = [[] * num_choices * num_prompts] @@ -441,10 +443,15 @@ async def completion_stream_generator( total_prompt_tokens = sum(num_prompt_tokens) total_completion_tokens = sum(previous_num_tokens) + total_cached_tokens = sum(num_cached_tokens) final_usage_info = UsageInfo( prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, total_tokens=total_prompt_tokens + total_completion_tokens) + if self.enable_prompt_tokens_details and total_cached_tokens: + final_usage_info.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=total_cached_tokens + ) if include_usage: final_usage_chunk = CompletionStreamResponse( @@ -549,6 +556,9 @@ def request_output_to_completion_response( completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, ) + if self.enable_prompt_tokens_details and final_res_batch[0].num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res_batch[0].num_cached_tokens) request_metadata.final_usage_info = usage diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c4ebb7141d09..a15c3f7dbaff 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -106,6 
+106,17 @@ ScoreResponse, ] +AnyResponse = Union[ + CompletionResponse, + ChatCompletionResponse, + EmbeddingResponse, + TranscriptionResponse, + TokenizeResponse, + PoolingResponse, + ClassificationResponse, + ScoreResponse, +] + class TextTokensPrompt(TypedDict): prompt: str diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fbbccbb34d90..b0786c78b2ac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -987,6 +987,20 @@ def get_config_dtype_str( return None +# TODO (bnell): use scalar_type instead of bools? +def get_config_qtype( + use_fp8_w8a8: bool, + use_int8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, +) -> Optional[torch.dtype]: + if use_fp8_w8a8: + return torch.float8_e4m3fn + elif use_int8_w8a8: + return torch.int8 + return None + + def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36ac75a8df4b..c9dcdcb6f7ae 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib from abc import abstractmethod from collections.abc import Iterable from enum import Enum @@ -8,6 +9,9 @@ import torch import torch.nn.functional as F +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -35,6 +39,9 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +has_pplx = importlib.util.find_spec("pplx_kernels") is not None +has_deepep = importlib.util.find_spec("deep_ep") is not None + if current_platform.is_cuda_alike(): from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts @@ -63,6 +70,205 @@ logger = init_logger(__name__) +# Note: this limit is somewhat arbitrary and might be changed later. +# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim. +MOE_DP_CHUNK_SIZE = 256 + + +@dataclass +class FusedMoEParallelConfig: + tp_size: int + dp_size: int + ep_size: int + tp_rank: int + dp_rank: int + ep_rank: int + + use_ep: bool # whether to use EP or not + + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + + @property + def use_pplx_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + + @staticmethod + def make(tp_size_: int, dp_size_: int, + vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": + """ + Determine MoE parallel configuration. Based on the input tp_size_, + dp_size_, ep_size_ and vllm's parallel config, determine what + level's of parallelism to use in the fused moe layer. + + Args: + tp_size_ (int): tp_size passed into the FusedMoE constructor. + dp_size_ (int): dp_size passed into the FusedMoE constructor. 
+ ep_size_ (int): ep_size passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vllm's parallel config + object. + + Examples: + When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, + we simply return the sizes unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either dp_size_ or tp_size_ + is non trivial. + + When TP = 2, DP = 1 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // + legend : {size, rank} + - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} + - Comment : Tensors are sharded across 2 devices. + + When TP = 1, DP = 2 and EP = False, the configuration on different + devices, + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} + - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 2 decvices. + + When TP = 2, DP = 2 and EP = False, the configuration on different + devices, + - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} + - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} + - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} + - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} + - Comment: There are 2 engine instances and the tensors are sharded + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} + - Comment: The experts are split between the 2 devices. + + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} + - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} + - Comment: There are 2 engine instances and the experts are split + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices, + - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} + - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} + - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} + - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} + - Comment: There are 2 engine instances and the experts are split + between the 4 devices. + """ + + def flatten_tp_across_dp(dp_rank: int): + tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank() + # There are actually dp_size_ * tp_size_ devices. Update tp_size + # and tp_rank so we shard across all devices. + tp_size = dp_size_ * tp_size_ + tp_rank = dp_rank * tp_size_ + tp_rank + return tp_size, tp_rank + + use_ep = (dp_size_ * tp_size_ > 1 + and vllm_parallel_config.enable_expert_parallel) + + dp_size = dp_size_ + dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0 + tp_size, tp_rank = flatten_tp_across_dp(dp_rank) + + if not use_ep: + return FusedMoEParallelConfig(tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=1, + ep_rank=0, + use_ep=False) + # DP + EP / TP + EP / DP + TP + EP + assert use_ep + # In EP, each device owns a set of experts fully. There is no tensor + # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that. 
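To make the `{size, rank}` tables in the docstring above easier to verify, here is a small self-contained sketch (pure Python, no distributed initialization, not part of the patch) that reproduces the same flattening logic; the helper name is an assumption.

```python
def moe_parallel_layout(tp_size: int, dp_size: int, use_ep: bool):
    """Per-device {size, rank} pairs, mirroring FusedMoEParallelConfig.make."""
    layout = []
    for dp_rank in range(dp_size):
        for local_tp_rank in range(tp_size):
            # Flatten TP across DP: dp_size * tp_size devices in total.
            flat_tp_size = dp_size * tp_size
            flat_tp_rank = dp_rank * tp_size + local_tp_rank
            if use_ep:
                # With EP, each device owns whole experts: TP collapses to 1
                # and the flattened TP ranks become EP ranks.
                layout.append({"TP": (1, 0), "DP": (dp_size, dp_rank),
                               "EP": (flat_tp_size, flat_tp_rank)})
            else:
                layout.append({"TP": (flat_tp_size, flat_tp_rank),
                               "DP": (dp_size, dp_rank), "EP": (1, 0)})
    return layout


# TP=2, DP=2, EP=True -> four devices with TP = {1, 0} and EP = {4, 0..3}.
for device, cfg in enumerate(moe_parallel_layout(2, 2, True)):
    print(device, cfg)
```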
+ ep_size = tp_size + ep_rank = tp_rank + return FusedMoEParallelConfig(tp_size=1, + tp_rank=0, + dp_size=dp_size, + dp_rank=dp_rank, + ep_size=ep_size, + ep_rank=ep_rank, + use_ep=True) + + +# Adapted from pplx-kernels tests/all_to_all_utils.py +@dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + + num_local_experts: int + moe_parallel_config: FusedMoEParallelConfig + + in_dtype: torch.dtype # The activation type. + quant_dtype: torch.dtype = None + + # TODO: add more quantization params, blocked, per-token, etc. + block_size: int = 128 + + max_num_tokens: int = MOE_DP_CHUNK_SIZE + + @property + def tp_size(self): + return self.moe_parallel_config.tp_size + + @property + def dp_size(self): + return self.moe_parallel_config.dp_size + + @property + def ep_size(self): + return self.moe_parallel_config.ep_size + + @property + def tp_rank(self): + return self.moe_parallel_config.tp_rank + + @property + def dp_rank(self): + return self.moe_parallel_config.dp_rank + + @property + def ep_rank(self): + return self.moe_parallel_config.ep_rank + + @property + def use_ep(self): + return self.moe_parallel_config.use_ep + + @property + def use_pplx_kernels(self): + return self.moe_parallel_config.use_pplx_kernels + + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -71,6 +277,18 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" +def get_quant_config_input_activations( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get( + "input_activations") + else: + return None + + class FusedMoEMethodBase(QuantizeMethodBase): moe: FusedMoEConfig diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ef67cc0eda46..df305a3f8e45 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum +import importlib from enum import Enum from typing import Callable, Optional @@ -687,6 +688,40 @@ def apply( assert self.fused_experts_func is not None + if self.rocm_aiter_moe_enabled: + return self.rocm_aiter_fused_experts_func( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy == + QuantizationStrategy.CHANNEL, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + if self.use_marlin: + assert activation == "silu", ( + f"{activation} not supported for Marlin MoE.") + assert not apply_router_weight_on_input, ( + "Apply router weight on input not supported for Marlin MoE.") + return torch.ops.vllm.fused_marlin_moe( + x, + layer.w13_weight, 
+ layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=scalar_types.float8_e4m3fn.id, + global_num_experts=global_num_experts, + expert_map=expert_map) + return self.fused_experts_func( hidden_states=x, w1=layer.w13_weight, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6e955e1c5121..36c5db49ea56 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -412,10 +412,15 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.model(input_ids=input_ids, - position_ids=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) + hidden_states = self.model(input_ids=input_ids, + position_ids=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + # convert the embedding output to float32, + # otherwise precision will be lost significantly + hidden_states = hidden_states.to(torch.float32) + return hidden_states def pooler( self, diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 676a440a79db..14592378e0f1 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -39,6 +39,20 @@ class CpuPlatform(Platform): dispatch_key: str = "CPU" dist_backend: str = "gloo" + @property + def supported_dtypes(self) -> list[torch.dtype]: + if self.get_cpu_architecture() == CpuArchEnum.POWERPC: + return [torch.bfloat16, torch.float32] + elif sys.platform.startswith( + "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: + # TODO: change this condition to check if the platform support bf16 + # instead of checking the OS. For instance M2 shall supports bf16 + # already. But we need to modify `cpu_extension.cmake` to activate + # the feature in the build. + return [torch.float16, torch.float32] + # x86/aarch64 CPU has supported both bf16 and fp16 natively. 
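A hedged illustration of how the `supported_dtypes` property being added here might be consulted by a caller validating a requested dtype on CPU; the fallback policy shown is an assumption, not taken from this patch.

```python
import torch

from vllm.platforms import current_platform

requested_dtype = torch.bfloat16
if requested_dtype not in current_platform.supported_dtypes:
    # Fall back to the first dtype the architecture supports, e.g. float16
    # on macOS/ARM until bf16 is enabled in cpu_extension.cmake.
    requested_dtype = current_platform.supported_dtypes[0]
print(requested_dtype)
```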
+ return [torch.bfloat16, torch.float16, torch.float32] + @property def supported_dtypes(self) -> list[torch.dtype]: if self.get_cpu_architecture() == CpuArchEnum.POWERPC: @@ -91,6 +105,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: model_config.disable_cascade_attn = True + model_config.disable_cascade_attn = True + cache_config = vllm_config.cache_config ipex_available = find_spec("intel_extension_for_pytorch") is not None @@ -205,6 +221,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + # Share the cpusets list among ranks by spawning process instead + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + # Intel OpenMP setting ld_prealod_str = os.getenv("LD_PRELOAD", "") if "libiomp5.so" in ld_prealod_str: diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 6810944c848d..bf153fd9f6b6 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -84,6 +84,22 @@ def can_update_inplace(cls): def get_lora_vocab_padding_size(cls) -> int: return 1 + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + + @classmethod + def can_update_inplace(cls): + return False + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + return 1 + @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fbc13c06c65a..f499996816ef 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -23,6 +23,8 @@ reshape_and_cache_flash) from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.utils import ( + get_kv_connector_cache_layout) from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6937455e7d85..3317e3c757c8 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -84,6 +84,12 @@ def __init__( self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None + assert len( + set(g.kv_cache_spec.block_size + for g in kv_cache_config.kv_cache_groups) + ) == 1, "Only one block size is supported for now" + self.block_size = kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size self.block_size: Optional[int] = None if self.enable_caching: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fe552db74e2f..6d74d2fcbfeb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -351,6 +351,19 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.prepend_request(request) continue + # KVTransfer: skip request if still waiting for remote kvs. 
+ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.popleft() + skipped_waiting_requests.appendleft(request) + continue + # Skip request if the structured output request is still waiting # for FSM compilation. if request.status == RequestStatus.WAITING_FOR_FSM: @@ -991,6 +1004,23 @@ def _free_request(self, request: Request) -> Optional[dict[str, Any]]: def _free_blocks(self, request: Request): assert request.is_finished() + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) + request_id = request.request_id + self._cached_reqs_data.pop(request_id, None) + self.finished_req_ids.add(request_id) + if self.finished_req_ids_dict is not None: + self.finished_req_ids_dict[request.client_index].add(request_id) + + if not delay_free_blocks: + self._free_blocks(request) + + return kv_xfer_params + + def _free_blocks(self, request: Request): + assert request.is_finished() + assert request.request_id not in self._cached_reqs_data self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11c..0061c801f65d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -42,6 +42,7 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager +from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9b96f4599f92..f00498f00fc3 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -88,6 +88,11 @@ def __init__( self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # P/D: Connector-specific KV transfer parameters. 
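Context for the `kv_transfer_params` plumbing just below: a P/D connector is expected to pass its transfer metadata through `SamplingParams.extra_args`, which `Request` then surfaces on the request object. A hedged producer-side sketch follows; the keys inside the inner dict are assumptions specific to whichever connector is used, not part of this patch.

```python
from vllm import SamplingParams

params = SamplingParams(
    max_tokens=32,
    extra_args={
        "kv_transfer_params": {
            # Connector-specific fields; the names here are illustrative.
            "remote_host": "10.0.0.2",
            "remote_port": 14579,
        }
    },
)
```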
+ kv_params = (None if sampling_params.extra_args is None else + sampling_params.extra_args.get("kv_transfer_params")) + self.kv_transfer_params: Optional[dict[str, Any]] = kv_params + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: @@ -124,7 +129,6 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": sampling_params=request.sampling_params, pooling_params=request.pooling_params, eos_token_id=request.eos_token_id, - arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params) \ diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 6b40cf6fd36d..29181a8956cc 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -11,8 +11,11 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union, overload) +import msgspec import torch +import zmq +from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, @@ -21,6 +24,8 @@ kill_process_tree) if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + from vllm.attention.layer import Attention from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.utils import (CoreEngineActorManager, @@ -30,6 +35,8 @@ T = TypeVar("T") +STARTUP_POLL_PERIOD_MS = 10000 + class ConstantList(Generic[T], Sequence): @@ -178,6 +185,352 @@ def __init__( def close(self) -> None: self._finalizer() + def join_first(self): + """Wait for any process to exit.""" + connection.wait(proc.sentinel for proc in self.processes) + + def sentinels(self) -> list: + return [proc.sentinel for proc in self.processes] + + def finished_procs(self) -> dict[str, int]: + """Returns dict of proc name -> exit code for any finished procs.""" + return { + proc.name: proc.exitcode + for proc in self.processes if proc.exitcode is not None + } + + +class CoreEngineActorManager: + """ + Utility class to handle creation, readiness, and shutdown + of core engine Ray actors used by the AsyncLLM and LLMEngine. + + Different from CoreEngineProcManager, this class manages + core engines for both local and remote nodes. + """ + + def __init__( + self, + vllm_config: VllmConfig, + addresses: EngineZmqAddresses, + executor_class: type[Executor], + log_stats: bool, + placement_groups: Optional[list["PlacementGroup"]] = None, + local_dp_ranks: Optional[list[int]] = None, + ): + import copy + + import ray + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + + from vllm.v1.engine.core import DPEngineCoreActor + + self.local_engine_actors: list[ray.ActorHandle] = [] + self.remote_engine_actors: list[ray.ActorHandle] = [] + dp_size = vllm_config.parallel_config.data_parallel_size + local_engine_count = \ + vllm_config.parallel_config.data_parallel_size_local + world_size = vllm_config.parallel_config.world_size + + if ray.is_initialized(): + logger.info( + "Ray is already initialized. 
Skipping Ray initialization.") + else: + ray.init() + + if placement_groups is not None: + assert local_dp_ranks is not None, ( + "local_dp_ranks must be provided if " + "placement_groups is provided") + assert len(placement_groups) == len(local_dp_ranks), ( + "placement_groups and local_dp_ranks must " + "have the same length") + logger.info("Using provided placement groups") + # TODO(rui): validate passed-in placement groups + self.created_placement_groups = [] + else: + placement_groups, local_dp_ranks = \ + CoreEngineActorManager.create_dp_placement_groups(vllm_config) + self.created_placement_groups = placement_groups + assert len(placement_groups) == dp_size, ( + "Number of placement groups must match data parallel size") + + refs = [] + for index in range(dp_size): + local_index = local_dp_ranks[index] + dp_vllm_config = copy.deepcopy(vllm_config) + pg = placement_groups[index] + dp_vllm_config.parallel_config.placement_group = pg + on_head_node = index < local_engine_count + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + )).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + on_head_node=on_head_node, + addresses=addresses, + dp_rank=index, + local_dp_rank=local_index) + if on_head_node: + self.local_engine_actors.append(actor) + else: + self.remote_engine_actors.append(actor) + refs.append(actor.wait_for_init.remote()) + + ray.get(refs) + self.run_refs = [] + for actor in self.local_engine_actors + self.remote_engine_actors: + self.run_refs.append(actor.run.remote()) + + @staticmethod + def create_dp_placement_groups( + vllm_config: VllmConfig + ) -> tuple[list["PlacementGroup"], list[int]]: + + import ray + from ray._private.state import available_resources_per_node + from ray.util.state import list_nodes + + logger.info("Creating placement groups for data parallel") + dp_master_ip = \ + vllm_config.parallel_config.data_parallel_master_ip + dp_size = vllm_config.parallel_config.data_parallel_size + local_engine_count = \ + vllm_config.parallel_config.data_parallel_size_local + + nodes = list_nodes() + nodes = sorted(list_nodes(), + key=lambda node: node.node_ip != dp_master_ip) + assert nodes[0].node_ip == dp_master_ip, ( + "The first node must be the head node") + assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( + "There can only be one head node") + + available_resources = available_resources_per_node() + world_size = vllm_config.parallel_config.world_size + placement_groups: list[PlacementGroup] = [] + local_dp_ranks: list[int] = [] + + for node in nodes: + node_ip = node.node_ip + node_resources = available_resources[node.node_id] + # For now, each DP rank can only be assigned to one node + # TODO(rui): support allocating a single DP rank + # to multiple nodes + available_engine_count = int(node_resources["GPU"]) // world_size + if node_ip == dp_master_ip: + assert available_engine_count >= local_engine_count, ( + "Not enough resources to allocate DP ranks " + f"on DP master node {node_ip}") + for i in range(local_engine_count): + bundles = [{ + "GPU": 1.0, + "node:" + dp_master_ip: 0.001 + }] * world_size + [{ + "CPU": 1.0 + }] + pg = ray.util.placement_group( + name=f"dp_rank_{len(placement_groups)}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + local_dp_ranks.append(i) + else: + for i in range(available_engine_count): + if len(placement_groups) == dp_size: + break 
+ bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + pg = ray.util.placement_group( + name=f"dp_rank_{len(placement_groups)}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + local_dp_ranks.append(i) + return placement_groups, local_dp_ranks + + def get_run_refs(self): + return self.run_refs + + def close(self): + import ray + for actor in self.local_engine_actors + self.remote_engine_actors: + ray.kill(actor) + for pg in self.created_placement_groups: + ray.util.remove_placement_group(pg) + + +def wait_for_engine_startup( + handshake_socket: zmq.Socket, + addresses: EngineZmqAddresses, + core_engines: list[CoreEngine], + parallel_config: ParallelConfig, + cache_config: CacheConfig, + proc_manager: Optional[CoreEngineProcManager], + coord_process: Optional[Process], +): + + # Wait for engine core process(es) to send ready messages. + local_count = parallel_config.data_parallel_size_local + remote_count = len(core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() + poller.register(handshake_socket, zmq.POLLIN) + + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + if coord_process is not None: + poller.register(coord_process.sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): + events = poller.poll(STARTUP_POLL_PERIOD_MS) + if not events: + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) + continue + if len(events) > 1 or events[0][0] != handshake_socket: + # One of the local core processes exited. + finished = proc_manager.finished_procs() if proc_manager else {} + if coord_process is not None and coord_process.exitcode is not None: + finished[coord_process.name] = coord_process.exitcode + raise RuntimeError("Engine core initialization failed. " + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = handshake_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, "little") + engine = next((e for e in core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. 
+ init_message = msgspec.msgpack.encode( + EngineHandshakeMetadata( + addresses=addresses, + parallel_config={ + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + })) + handshake_socket.send_multipart((eng_identity, init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. + num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg["num_gpu_blocks"] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) + + +def wait_for_completion_or_failure( + api_server_manager: APIServerProcessManager, + engine_manager: Optional[Union[CoreEngineProcManager, + CoreEngineActorManager]] = None, + coordinator: Optional["DPCoordinator"] = None) -> None: + """Wait for all processes to complete or detect if any fail. + + Raises an exception if any process exits with a non-zero status. + + Args: + api_server_manager: The manager for API servers. + engine_manager: The manager for engine processes. + If CoreEngineProcManager, it manages local engines; + if CoreEngineActorManager, it manages all engines. + coordinator: The coordinator for data parallel. 
+ """ + + try: + logger.info("Waiting for API servers to complete ...") + # Create a mapping of sentinels to their corresponding processes + # for efficient lookup + sentinel_to_proc: dict[Any, BaseProcess] = { + proc.sentinel: proc + for proc in api_server_manager.processes + } + + if coordinator: + sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc + + actor_run_refs = [] + if isinstance(engine_manager, CoreEngineProcManager): + for proc in engine_manager.processes: + sentinel_to_proc[proc.sentinel] = proc + elif isinstance(engine_manager, CoreEngineActorManager): + actor_run_refs = engine_manager.get_run_refs() + + # Check if any process terminates + while sentinel_to_proc or actor_run_refs: + # Wait for any process to terminate + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, + timeout=5) + + # Process any terminated processes + for sentinel in ready_sentinels: + proc = sentinel_to_proc.pop(sentinel) + + # Check if process exited with error + if proc.exitcode != 0: + raise RuntimeError( + f"Process {proc.name} (PID: {proc.pid}) " + f"died with exit code {proc.exitcode}") + + if actor_run_refs: + import ray + _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) + + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, shutting down API servers...") + except Exception as e: + logger.exception("Exception occurred while running API servers: %s", + str(e)) + raise + finally: + logger.info("Terminating remaining processes ...") + api_server_manager.close() + if coordinator: + coordinator.close() + if engine_manager: + engine_manager.close() + def wait_for_completion_or_failure( api_server_manager: APIServerProcessManager, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4635c1a63507..da033bd83cde 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -576,6 +576,26 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange + def _get_cumsum_and_arange( + self, + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -2014,6 +2034,10 @@ def _dummy_run( num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad + # Padding for DP + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. 
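The `_get_cumsum_and_arange` helper added above documents its trick only in comments. The following self-contained NumPy snippet (not part of the patch) reproduces the same computation for the documented example and can be run as a sanity check.

```python
import numpy as np


def cumsum_and_arange(num_tokens: np.ndarray):
    # [2, 5, 3] -> [2, 7, 10]
    cu_num_tokens = np.cumsum(num_tokens)
    # [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    # [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    arange = np.arange(cu_num_tokens[-1]) - offsets
    return cu_num_tokens, arange


# cu = [2 7 10], arange = [0 1 0 1 2 3 4 0 1 2]
print(cumsum_and_arange(np.array([2, 5, 3])))
```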
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f5f26d8fff98..8d4767645d79 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -739,6 +739,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) + attn_metadata = PallasMetadata( slot_mapping=slot_mapping_metadata, block_tables=block_tables, From 155b547b431bd96672a51b44fc71e321cd7ce79d Mon Sep 17 00:00:00 2001 From: Your Name <you@example.com> Date: Thu, 3 Jul 2025 00:45:26 +0000 Subject: [PATCH 26/27] yield last chunk if it's usage --- vllm/entrypoints/openai/serving_completion.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 85934d1b66c1..2c00fa25f483 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -330,6 +330,7 @@ async def completion_stream_generator( else: include_usage, include_continuous_usage = False, False + chunk = None try: async for prompt_idx, res in result_generator: prompt_token_ids = res.prompt_token_ids @@ -461,6 +462,12 @@ async def completion_stream_generator( choices=[], usage=final_usage_info, ) + + # if accumulate, send the usage info attached to last chunk instead + if request.accumulate: + chunk.usage = final_usage_info + final_usage_chunk = chunk + final_usage_data = (final_usage_chunk.model_dump_json( exclude_unset=False, exclude_none=True)) yield f"data: {final_usage_data}\n\n" From a05d08c7bdf0cec25388e4aa68df84266dd593dd Mon Sep 17 00:00:00 2001 From: Your Name <you@example.com> Date: Thu, 3 Jul 2025 00:46:12 +0000 Subject: [PATCH 27/27] chunk --- vllm/entrypoints/openai/serving_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 2c00fa25f483..341006cae6ba 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -464,7 +464,7 @@ async def completion_stream_generator( ) # if accumulate, send the usage info attached to last chunk instead - if request.accumulate: + if request.accumulate and chunk is not None: chunk.usage = final_usage_info final_usage_chunk = chunk
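The last two patches attach the usage payload to the final streamed content chunk when the vLLM-specific `accumulate` flag is set, instead of emitting a separate usage-only chunk. A hedged client-side sketch of reading that usage from the stream, using plain `requests` against an assumed local server (the model name, port, and exact request field layout are placeholders):

```python
import json

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "Qwen/Qwen3-0.6B",
        "prompt": "Hello",
        "stream": True,
        "stream_options": {"include_usage": True},
        "accumulate": True,  # assumed request field enabling this behavior
    },
    stream=True,
)
last_payload = None
for line in resp.iter_lines():
    if line.startswith(b"data: ") and line != b"data: [DONE]":
        last_payload = json.loads(line[len(b"data: "):])
# With accumulate, the usage rides on the last content chunk.
print(last_payload.get("usage") if last_payload else None)
```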