diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md new file mode 100644 index 00000000000..a66199a5397 --- /dev/null +++ b/docs/models/supported_models.md @@ -0,0 +1,700 @@ +--- +title: Supported Models +--- +[](){ #supported-models } + +vLLM supports [generative](./generative_models.md) and [pooling](./pooling_models.md) models across various tasks. +If a model supports more than one task, you can set the task via the `--task` argument. + +For each task, we list the model architectures that have been implemented in vLLM. +Alongside each architecture, we include some popular models that use it. + +## Model Implementation + +### vLLM + +If vLLM natively supports a model, its implementation can be found in . + +These models are what we list in [supported-text-models][supported-text-models] and [supported-mm-models][supported-mm-models]. + +[](){ #transformers-backend } + +### Transformers + +vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned! + +To check if the modeling backend is Transformers, you can simply do this: + +```python +from vllm import LLM +llm = LLM(model=..., task="generate") # Name or path of your model +llm.apply_model(lambda model: print(type(model))) +``` + +If it is `TransformersForCausalLM` then it means it's based on Transformers! + +!!! tip + You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][openai-compatible-server]. + +!!! note + vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. + +#### Custom models + +If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! + +For a model to be compatible with the Transformers backend for vLLM it must: + +- be a Transformers compatible custom model (see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)): + * The model directory must have the correct structure (e.g. `config.json` is present). + * `config.json` must contain `auto_map.AutoModel`. +- be a Transformers backend for vLLM compatible model (see [writing-custom-models][writing-custom-models]): + * Customisation should be done in the base model (e.g. in `MyModel`, not `MyModelForCausalLM`). + +If the compatible model is: + +- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for [offline-inference][offline-inference] or `--trust-remote-code` for the [openai-compatible-server][openai-compatible-server]. +- in a local directory, simply pass directory path to `model=` for [offline-inference][offline-inference] or `vllm serve ` for the [openai-compatible-server][openai-compatible-server]. + +This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! + +[](){ #writing-custom-models } + +#### Writing custom models + +This section details the necessary modifications to make to a Transformers compatible custom model that make it compatible with the Transformers backend for vLLM. (We assume that a Transformers compatible custom model has already been created, see [Transformers - Customizing models](https://huggingface.co/docs/transformers/en/custom_models)). + +To make your model compatible with the Transformers backend, it needs: + +1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`. +2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention. +3. `MyModel` must contain `_supports_attention_backend = True`. + +```python title="modeling_my_model.py" + +from transformers import PreTrainedModel +from torch import nn + +class MyAttention(nn.Module): + + def forward(self, hidden_states, **kwargs): + ... + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) + ... + +class MyModel(PreTrainedModel): + _supports_attention_backend = True +``` + +Here is what happens in the background when this model is loaded: + +1. The config is loaded. +2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. +3. `MyModel` is loaded into `TransformersForCausalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. + +That's it! + +For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class: + +```python title="configuration_my_model.py" + +from transformers import PretrainedConfig + +class MyConfig(PretrainedConfig): + base_model_tp_plan = { + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } +``` + +- `base_model_tp_plan` is a `dict` that maps fully qualified layer name patterns to tensor parallel styles (currently only `"colwise"` and `"rowwise"` are supported). +- `base_model_pp_plan` is a `dict` that maps direct child layer names to `tuple`s of `list`s of `str`s: + * You only need to do this for layers which are not present on all pipeline stages + * vLLM assumes that there will be only one `nn.ModuleList`, which is distributed across the pipeline stages + * The `list` in the first element of the `tuple` contains the names of the input arguments + * The `list` in the last element of the `tuple` contains the names of the variables the layer outputs to in your modeling code + +## Loading a Model + +### Hugging Face Hub + +By default, vLLM loads models from [Hugging Face (HF) Hub](https://huggingface.co/models). To change the download path for models, you can set the `HF_HOME` environment variable; for more details, refer to [their official documentation](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhome). + +To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository. +If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. + +Models do not _need_ to be natively supported to be used in vLLM. +The [Transformers backend][transformers-backend] enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). + +!!! tip + The easiest way to check if your model is really supported at runtime is to run the program below: + + ```python + from vllm import LLM + + # For generative models (task=generate) only + llm = LLM(model=..., task="generate") # Name or path of your model + output = llm.generate("Hello, my name is") + print(output) + + # For pooling models (task={embed,classify,reward,score}) only + llm = LLM(model=..., task="embed") # Name or path of your model + output = llm.encode("Hello, my name is") + print(output) + ``` + + If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported. + +Otherwise, please refer to [Adding a New Model][new-model] for instructions on how to implement your model in vLLM. +Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support. + +#### Download a model + +If you prefer, you can use the Hugging Face CLI to [download a model](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-download) or specific files from a model repository: + +```console +# Download a model +huggingface-cli download HuggingFaceH4/zephyr-7b-beta + +# Specify a custom cache directory +huggingface-cli download HuggingFaceH4/zephyr-7b-beta --cache-dir ./path/to/cache + +# Download a specific file from a model repo +huggingface-cli download HuggingFaceH4/zephyr-7b-beta eval_results.json +``` + +#### List the downloaded models + +Use the Hugging Face CLI to [manage models](https://huggingface.co/docs/huggingface_hub/guides/manage-cache#scan-your-cache) stored in local cache: + +```console +# List cached models +huggingface-cli scan-cache + +# Show detailed (verbose) output +huggingface-cli scan-cache -v + +# Specify a custom cache directory +huggingface-cli scan-cache --dir ~/.cache/huggingface/hub +``` + +#### Delete a cached model + +Use the Hugging Face CLI to interactively [delete downloaded model](https://huggingface.co/docs/huggingface_hub/guides/manage-cache#clean-your-cache) from the cache: + +```console +# The `delete-cache` command requires extra dependencies to work with the TUI. +# Please run `pip install huggingface_hub[cli]` to install them. + +# Launch the interactive TUI to select models to delete +$ huggingface-cli delete-cache +? Select revisions to delete: 1 revisions selected counting for 438.9M. + ○ None of the following (if selected, nothing will be deleted). +Model BAAI/bge-base-en-v1.5 (438.9M, used 1 week ago) +❯ ◉ a5beb1e3: main # modified 1 week ago + +Model BAAI/bge-large-en-v1.5 (1.3G, used 1 week ago) + ○ d4aa6901: main # modified 1 week ago + +Model BAAI/bge-reranker-base (1.1G, used 4 weeks ago) + ○ 2cfc18c9: main # modified 4 weeks ago + +Press to select, to validate and to quit without modification. + +# Need to confirm after selected +? Select revisions to delete: 1 revision(s) selected. +? 1 revisions selected counting for 438.9M. Confirm deletion ? Yes +Start deletion. +Done. Deleted 1 repo(s) and 0 revision(s) for a total of 438.9M. +``` + +#### Using a proxy + +Here are some tips for loading/downloading models from Hugging Face using a proxy: + +- Set the proxy globally for your session (or set it in the profile file): + +```shell +export http_proxy=http://your.proxy.server:port +export https_proxy=http://your.proxy.server:port +``` + +- Set the proxy for just the current command: + +```shell +https_proxy=http://your.proxy.server:port huggingface-cli download + +# or use vllm cmd directly +https_proxy=http://your.proxy.server:port vllm serve --disable-log-requests +``` + +- Set the proxy in Python interpreter: + +```python +import os + +os.environ['http_proxy'] = 'http://your.proxy.server:port' +os.environ['https_proxy'] = 'http://your.proxy.server:port' +``` + +### ModelScope + +To use models from [ModelScope](https://www.modelscope.cn) instead of Hugging Face Hub, set an environment variable: + +```shell +export VLLM_USE_MODELSCOPE=True +``` + +And use with `trust_remote_code=True`. + +```python +from vllm import LLM + +llm = LLM(model=..., revision=..., task=..., trust_remote_code=True) + +# For generative models (task=generate) only +output = llm.generate("Hello, my name is") +print(output) + +# For pooling models (task={embed,classify,reward,score}) only +output = llm.encode("Hello, my name is") +print(output) +``` + +[](){ #feature-status-legend } + +## Feature Status Legend + +- ✅︎ indicates that the feature is supported for the model. + +- 🚧 indicates that the feature is planned but not yet supported for the model. + +- ⚠️ indicates that the feature is available but may have known issues or limitations. + +[](){ #supported-text-models } + +## List of Text-only Language Models + +### Generative Models + +See [this page][generative-models] for more information on how to use generative models. + +#### Text Generation + +Specified using `--task generate`. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|---------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------| +| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | +| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | +| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | +| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | +| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | +| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | +| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅︎ | +| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | +| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | +| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | +| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | +| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | +| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | +| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | +| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | +| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | +| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | +| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | +| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | +| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | +| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | +| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | +| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | +| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | +| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | +| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | +| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | +| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | +| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | +| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | +| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | +| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | +| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | +| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | +| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | +| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | +| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | +| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | +| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | +| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | +| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | +| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | +| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ | +| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | +| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | +| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | +| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | +| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | +| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | + +!!! note + Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. + +### Pooling Models + +See [this page](./pooling_models.md) for more information on how to use pooling models. + +!!! warning + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +#### Text Embedding + +Specified using `--task embed`. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------| +| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | +| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | +| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | | +| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ | +| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ | +| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ | +| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | +| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | +| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | + +!!! note + `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. + You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`. + +!!! note + For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded. + See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882). + +!!! note + `jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights. + +!!! note + The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture. + +If your model is not in the above list, we will try to automatically convert the model using +[as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. By default, the embeddings +of the whole prompt are extracted from the normalized hidden state corresponding to the last token. + +#### Reward Modeling + +Specified using `--task reward`. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|---------------------------|-----------------|------------------------------------------------------------------------|------------------------|-----------------------------| +| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | +| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | + +If your model is not in the above list, we will try to automatically convert the model using +[as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. + +!!! warning + For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, + e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + +#### Classification + +Specified using `--task classify`. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|----------------------------------|----------|----------------------------------------|------------------------|-----------------------------| +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | + +If your model is not in the above list, we will try to automatically convert the model using +[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. + +#### Sentence Pair Scoring + +Specified using `--task score`. + +| Architecture | Models | Example HF Models | +|---------------------------------------|-------------------|--------------------------------------------------------------------------------------| +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | +| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | + +!!! note + Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: . + + ```bash + vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' + ``` +[](){ #supported-mm-models } + +## List of Multimodal Language Models + +The following modalities are supported depending on the model: + +- **T**ext +- **I**mage +- **V**ideo +- **A**udio + +Any combination of modalities joined by `+` are supported. + +- e.g.: `T + I` means that the model supports text-only, image-only, and text-with-image inputs. + +On the other hand, modalities separated by `/` are mutually exclusive. + +- e.g.: `T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. + +See [this page][multimodal-inputs] on how to pass multi-modal inputs to the model. + +!!! warning + **To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference) + or `--limit-mm-per-prompt` (online serving). For example, to enable passing up to 4 images per text prompt: + + Offline inference: + + ```python + from vllm import LLM + + llm = LLM( + model="Qwen/Qwen2-VL-7B-Instruct", + limit_mm_per_prompt={"image": 4}, + ) + ``` + + Online serving: + + ```bash + vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}' + ``` + + **This is no longer required if you are using vLLM V1.** + +!!! note + vLLM currently only supports adding LoRA to the language backbone of multimodal models. + +### Generative Models + +See [this page][generative-models] for more information on how to use generative models. + +#### Text Generation + +Specified using `--task generate`. + +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|----------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| +| `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | | ✅︎ | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | +| `Blip2ForConditionalGeneration` | BLIP-2 | T + IE | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | +| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | | ✅︎ | ✅︎ | +| `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | | ✅︎ | ✅︎ | +| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | | +| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | +| `GLM4VForCausalLM`^ | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | +| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | +| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ | +| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | +| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + IE+ | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | | ✅︎ | ✅︎ | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + IE+ | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | +| `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | +| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | | +| `Mistral3ForConditionalGeneration` | Mistral3 | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | +| `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | +| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | +| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | +| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `PixtralForConditionalGeneration` | Pixtral | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | +| `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | +| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* | +| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | +| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | +| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | + +^ You need to set the architecture name via `--hf-overrides` to match the one in vLLM. +    • For example, to use DeepSeek-VL2 series models: +      `--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` +E Pre-computed embeddings can be inputted for this modality. ++ Multiple items can be inputted per text prompt for this modality. + +!!! warning + Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. + However, there are differences in how they handle text + image inputs: + + V0 correctly implements the model's attention pattern: + - Uses bidirectional attention between the image tokens corresponding to the same image + - Uses causal attention for other tokens + - Implemented via (naive) PyTorch SDPA with masking tensors + - Note: May use significant memory for long prompts with image + + V1 currently uses a simplified attention pattern: + - Uses causal attention for all tokens, including image tokens + - Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": true}` + - Will be updated in the future to support the correct behavior + + This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends. + +!!! note + Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently. + +!!! note + `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80. + +!!! note + To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. + +!!! warning + The output quality of `AllenAI/Molmo-7B-D-0924` (especially in object localization tasks) has deteriorated in recent updates. + + For the best results, we recommend using the following dependency versions (tested on A10 and L40): + + ```text + # Core vLLM-compatible dependencies with Molmo accuracy setup (tested on L40) + torch==2.5.1 + torchvision==0.20.1 + transformers==4.48.1 + tokenizers==0.21.0 + tiktoken==0.7.0 + vllm==0.7.0 + + # Optional but recommended for improved performance and stability + triton==3.1.0 + xformers==0.0.28.post3 + uvloop==0.21.0 + protobuf==5.29.3 + openai==1.60.2 + opencv-python-headless==4.11.0.86 + pillow==10.4.0 + + # Installed FlashAttention (for float16 only) + flash-attn>=2.5.6 # Not used in float32, but should be documented + ``` + + **Note:** Make sure you understand the security implications of using outdated packages. + +!!! note + The official `openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (`HwwwH/MiniCPM-V-2`) for now. + For more details, please see: + +!!! warning + Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1. + +!!! note + To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via + `pip install git+https://github.com/huggingface/transformers.git`. + + Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1. + `--mm-processor-kwargs '{"use_audio_in_video": true}'`. + +### Pooling Models + +See [this page](./pooling_models.md) for more information on how to use pooling models. + +!!! warning + Since some model architectures support both generative and pooling tasks, + you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. + +#### Text Embedding + +Specified using `--task embed`. + +Any text generation model can be converted into an embedding model by passing `--task embed`. + +!!! note + To get the best results, you should use pooling models that are specifically trained as such. + +The following table lists those that are tested in vLLM. + +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------| +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | +| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | + +#### Transcription + +Specified using `--task transcription`. + +Speech2Text models trained specifically for Automatic Speech Recognition. + +| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | +|----------------|----------|---------------------|------------------------|-----------------------------| + +--- + +## Model Support Policy + +At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support: + +1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated! + +2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results. + + !!! tip + When comparing the output of `model.generate` from Hugging Face Transformers with the output of `llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., [generation_config.json](https://github.com/huggingface/transformers/blob/19dabe96362803fb0a9ae7073d03533966598b17/src/transformers/generation/utils.py#L1945)) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs. + +3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback. + +4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use. + +5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement. + +Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem. + +Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard. + +We have the following levels of testing for models: + +1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. +2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. +3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py new file mode 100644 index 00000000000..27c4071bf09 --- /dev/null +++ b/examples/offline_inference/qwen3_reranker.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +from vllm import LLM + +model_name = "Qwen/Qwen3-Reranker-0.6B" + +# What is the difference between the official original version and one +# that has been converted into a sequence classification model? +# Qwen3-Reranker is a language model that doing reranker by using the +# logits of "no" and "yes" tokens. +# It needs to computing 151669 tokens logits, making this method extremely +# inefficient, not to mention incompatible with the vllm score API. +# A method for converting the original model into a sequence classification +# model was proposed. See:https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 +# Models converted offline using this method can not only be more efficient +# and support the vllm score API, but also make the init parameters more +# concise, for example. +# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score") + +# If you want to load the official original version, the init parameters are +# as follows. + +model = LLM( + model=model_name, + task="score", + hf_overrides={ + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, + }, +) + +# Why do we need hf_overrides for the official original version: +# vllm converts it to Qwen3ForSequenceClassification when loaded for +# better performance. +# - Firstly, we need using `"architectures": ["Qwen3ForSequenceClassification"],` +# to manually route to Qwen3ForSequenceClassification. +# - Then, we will extract the vector corresponding to classifier_from_token +# from lm_head using `"classifier_from_token": ["no", "yes"]`. +# - Third, we will convert these two vectors into one vector. The use of +# conversion logic is controlled by `using "is_original_qwen3_reranker": True`. + +# Please use the query_template and document_template to format the query and +# document for better reranker results. + +prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n' +suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + +query_template = "{prefix}: {instruction}\n: {query}\n" +document_template = ": {doc}{suffix}" + +if __name__ == "__main__": + instruction = ( + "Given a web search query, retrieve relevant passages that answer the query" + ) + + queries = [ + "What is the capital of China?", + "Explain gravity", + ] + + documents = [ + "The capital of China is Beijing.", + "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.", + ] + + queries = [ + query_template.format(prefix=prefix, instruction=instruction, query=query) + for query in queries + ] + documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] + + outputs = model.score(queries, documents) + + print([output.outputs.score for output in outputs]) diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py new file mode 100644 index 00000000000..6a3a0f150b6 --- /dev/null +++ b/tests/models/language/pooling/test_gte.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import pytest + +from .embed_utils import EmbedModelInfo, correctness_test_embed_models +from .mteb_utils import mteb_test_embed_models + +MODELS = [ + ########## BertModel + EmbedModelInfo("thenlper/gte-large", + architecture="BertModel", + enable_test=True), + EmbedModelInfo("thenlper/gte-base", + architecture="BertModel", + enable_test=False), + EmbedModelInfo("thenlper/gte-small", + architecture="BertModel", + enable_test=False), + EmbedModelInfo("thenlper/gte-large-zh", + architecture="BertModel", + enable_test=False), + EmbedModelInfo("thenlper/gte-base-zh", + architecture="BertModel", + enable_test=False), + EmbedModelInfo("thenlper/gte-small-zh", + architecture="BertModel", + enable_test=False), + ########### NewModel + EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + enable_test=True), + ########### Qwen2ForCausalLM + EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + architecture="Qwen2ForCausalLM", + enable_test=True), + ########## ModernBertModel + EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + architecture="ModernBertModel", + enable_test=True), + ########## Qwen3ForCausalLM + EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=True), + EmbedModelInfo("Qwen/Qwen3-Embedding-4B", + architecture="Qwen3ForCausalLM", + dtype="float32", + enable_test=False), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.architecture == "GteNewModel": + vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} + + mteb_test_embed_models(hf_runner, vllm_runner, model_info, + vllm_extra_kwargs) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_embed_models_correctness(hf_runner, vllm_runner, + model_info: EmbedModelInfo, + example_prompts) -> None: + + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.architecture == "GteNewModel": + vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} + + correctness_test_embed_models(hf_runner, vllm_runner, model_info, + example_prompts, vllm_extra_kwargs) diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py new file mode 100644 index 00000000000..63b37d9a077 --- /dev/null +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +model_name = "Qwen/Qwen3-Reranker-4B" + +text_1 = "What is the capital of France?" +texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", +] + + +def vllm_reranker(model_name): + from vllm import LLM + + model = LLM(model=model_name, + task="score", + hf_overrides={ + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, + }, + dtype="float32") + + text_1 = "What is the capital of France?" + texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", + ] + + outputs = model.score(text_1, texts_2) + + return [output.outputs.score for output in outputs] + + +def hf_reranker(model_name): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left') + model = AutoModelForCausalLM.from_pretrained(model_name).eval() + + token_false_id = tokenizer.convert_tokens_to_ids("no") + token_true_id = tokenizer.convert_tokens_to_ids("yes") + + max_length = 8192 + + def process_inputs(pairs): + inputs = tokenizer(pairs, + padding=False, + truncation='longest_first', + return_attention_mask=False, + max_length=max_length) + for i, ele in enumerate(inputs['input_ids']): + inputs['input_ids'][i] = ele + inputs = tokenizer.pad(inputs, + padding=True, + return_tensors="pt", + max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs, **kwargs): + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])] + inputs = process_inputs(pairs) + scores = compute_logits(inputs) + + return scores + + +@pytest.mark.parametrize("model_name", [model_name]) +def test_model(model_name): + hf_outputs = hf_reranker(model_name) + vllm_outputs = vllm_reranker(model_name) + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py b/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py new file mode 100644 index 00000000000..ee07f6ff9dc --- /dev/null +++ b/tests/models/language/pooling/test_qwen3_reranker_seq_cls.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls" + +text_1 = "What is the capital of France?" +texts_2 = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", +] + + +def vllm_reranker(model_name): + from vllm import LLM + + model = LLM(model=model_name, task="score") + outputs = model.score(text_1, texts_2) + + return [output.outputs.score for output in outputs] + + +def hf_reranker(model_name): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left') + model = AutoModelForCausalLM.from_pretrained(model_name).eval() + + token_false_id = tokenizer.convert_tokens_to_ids("no") + token_true_id = tokenizer.convert_tokens_to_ids("yes") + + max_length = 8192 + + def process_inputs(pairs): + inputs = tokenizer(pairs, + padding=False, + truncation='longest_first', + return_attention_mask=False, + max_length=max_length) + for i, ele in enumerate(inputs['input_ids']): + inputs['input_ids'][i] = ele + inputs = tokenizer.pad(inputs, + padding=True, + return_tensors="pt", + max_length=max_length) + for key in inputs: + inputs[key] = inputs[key].to(model.device) + return inputs + + @torch.no_grad() + def compute_logits(inputs, **kwargs): + batch_scores = model(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, token_true_id] + false_vector = batch_scores[:, token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])] + inputs = process_inputs(pairs) + scores = compute_logits(inputs) + + return scores + + +@pytest.mark.parametrize("model_name", [model_name]) +def test_model(model_name): + hf_outputs = hf_reranker(model_name) + vllm_outputs = vllm_reranker(model_name) + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/tests/models/registry.py b/tests/models/registry.py index cd5e1dab0a4..78ea8953402 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -225,8 +225,8 @@ def check_available_online( "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), - "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b", - is_available_online=False), + "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 + "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501 is_available_online=False), "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index aa5e3675405..e6145b325fb 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -238,6 +238,13 @@ def __init__(self, *, normalize: bool, softmax: bool) -> None: def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + # Using float32 in PoolerHead + if isinstance(pooled_data, list): + for i in range(len(pooled_data)): + pooled_data[i] = pooled_data[i].to(torch.float32) + else: + pooled_data = pooled_data.to(torch.float32) + dimensions_list = [ pooling_param.dimensions for _, pooling_param in pooling_metadata.seq_groups @@ -261,9 +268,16 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], if self.softmax: if isinstance(pooled_data, list): - pooled_data = [F.softmax(data, dim=-1) for data in pooled_data] + pooled_data = [ + F.softmax(data, dim=-1) + if data.shape[-1] >= 2 else F.sigmoid(data) + for data in pooled_data + ] else: - pooled_data = F.softmax(pooled_data, dim=-1) + if pooled_data.shape[-1] >= 2: + pooled_data = F.softmax(pooled_data, dim=-1) + else: + pooled_data = F.sigmoid(pooled_data) return pooled_data diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index c22015d156d..72fae00eb68 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -35,13 +35,15 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix @@ -316,3 +318,122 @@ def load_weights(self, weights: Iterable[Tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) + + +class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, + SupportsCrossEncoding): + + def __init__( + self, + vllm_config: "VllmConfig", + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + pooler_config = vllm_config.model_config.pooler_config + + self.vllm_config = vllm_config + self.config = config + self.quant_config = quant_config + self.prefix = prefix + self.model = Qwen3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.score = RowParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix(prefix, "score")) + + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + hidden_states = self._pooler.extract_states(hidden_states, + pooling_metadata) + logits, _ = self.score(hidden_states) + pooled_data = self._pooler.head(logits, pooling_metadata) + pooled_outputs = [ + self._pooler.build_output(data.squeeze(-1)) for data in pooled_data + ] + return PoolerOutput(outputs=pooled_outputs) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + is_original_qwen3_reranker = getattr(self.config, + "is_original_qwen3_reranker", + False) + + if not is_original_qwen3_reranker: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + return self.load_weights_from_original_qwen3_reranker(weights) + + def load_weights_from_original_qwen3_reranker( + self, weights: Iterable[tuple[str, torch.Tensor]]): + tokens = getattr(self.config, "classifier_from_token", None) + assert tokens is not None and len(tokens) == 2, \ + ("Try loading the original Qwen3 Reranker?, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") + + self.config.num_labels = 1 + model_config = self.vllm_config.model_config + + device = self.score.weight.device + self.score = RowParallelLinear(self.config.hidden_size, + self.config.num_labels, + quant_config=self.quant_config, + input_is_parallel=False, + bias=False, + prefix=maybe_prefix( + self.prefix, "score")).to(device) + + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix( + self.prefix, "lm_head")) + + loader = AutoWeightsLoader(self) + loaded_weights = loader.load_weights(weights) + + from vllm.transformers_utils.tokenizer import get_tokenizer + tokenizer = get_tokenizer( + model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code) + + a = tokenizer.convert_tokens_to_ids(tokens[0]) + b = tokenizer.convert_tokens_to_ids(tokens[1]) + weight = self.lm_head.weight.data[b].to( + device) - self.lm_head.weight.data[a].to(device) + self.score.weight.data.copy_(weight) + + del self.lm_head + loaded_weights.add("classifier.weight") + loaded_weights.discard("lm_head.weight") diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 19153efd8e1..c8bc594da2f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -166,6 +166,7 @@ "RobertaForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), + "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 } _MULTIMODAL_MODELS = { diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 6d748232160..a5240013004 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -130,7 +130,7 @@ def forward_hpu( device=inputs_embeds.device) valid_input_mask = expected_pos < seq_len expected_pos = expected_pos * valid_input_mask - assert torch.equal(positions, expected_pos) + position_ids[index] = create_position_ids_from_input_ids_hpu( tokens, self.padding_idx, seq_len) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f2310fcee3b..124d601ce1e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -398,11 +398,11 @@ def _merge_multimodal_embeddings( flattened = multimodal_embeddings.reshape(-1, hidden_size) inputs_embeds[is_multimodal] = flattened inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, - hidden_size) + hidden_size) else: flattened = _flatten_embeddings(multimodal_embeddings) - inputs_embeds[is_multimodal] = flattened - + inputs_embeds[is_multimodal] = flattened + return inputs_embeds num_expected_tokens = is_multimodal.sum().item() @@ -713,6 +713,7 @@ def extract_layer_index(layer_name: str) -> int: " only contain one integer") return int_vals[0] + def get_input_mask(hidden_states: torch.Tensor, valid_len: torch.Tensor) -> torch.Tensor: """ @@ -727,6 +728,7 @@ def get_input_mask(hidden_states: torch.Tensor, mask = mask.to(hidden_states.dtype) return mask + def cast_overflow_tensors( tensors: torch.Tensor, offset: float = 1000, diff --git a/vllm/worker/hpu_pooling_model_runner.py b/vllm/worker/hpu_pooling_model_runner.py index f9f58e5307e..8b8196d112e 100755 --- a/vllm/worker/hpu_pooling_model_runner.py +++ b/vllm/worker/hpu_pooling_model_runner.py @@ -130,7 +130,7 @@ def execute_model( return [] return [ - self.model.model._pooler( + self.model.model.pooler( hidden_states=hidden_states, pooling_metadata=model_input.pooling_metadata) ]