diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index a23abdc1ed6c..7589b48b584d 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -27,6 +27,8 @@ docker run \ "${image_name}" \ sh -c ' VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp cd tests pytest -v -s v1/core ' diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md index 33584fdd5d40..5ac80fa66bf2 100644 --- a/docs/contributing/incremental_build.md +++ b/docs/contributing/incremental_build.md @@ -84,6 +84,7 @@ Below is an example of what the generated `CMakeUserPresets.json` might look lik ``` **What do the various configurations mean?** + - `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically. - `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default. - `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable. diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 13a8386a2971..c68b3aef5828 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -268,10 +268,10 @@ Flags: `--tool-call-parser hermes` Supported models: -* `MiniMaxAi/MiniMax-M1-40k` (use with ) -* `MiniMaxAi/MiniMax-M1-80k` (use with ) +* `MiniMaxAi/MiniMax-M1-40k` (use with ) +* `MiniMaxAi/MiniMax-M1-80k` (use with ) -Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax.jinja` +Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja` ### DeepSeek-V3 Models (`deepseek_v3`) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index e003a3e31717..576372ac0609 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -316,7 +316,8 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. 
| | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -332,7 +333,7 @@ Specified using `--task generate`. | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -345,7 +346,7 @@ Specified using `--task generate`. | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | @@ -357,14 +358,14 @@ Specified using `--task generate`. | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | -| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | +| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | @@ -389,7 +390,7 @@ Specified using `--task generate`. | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | | | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. @@ -738,4 +739,4 @@ We have the following levels of testing for models: 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. 
\ No newline at end of file diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 8b50802e6a8e..459ea2d676c1 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | 🚀 Optimized | | **Encoder-Decoder Models** | 🟠 Delayed | | **Embedding Models** | 🟢 Functional | -| **Mamba Models** | 🚧 WIP () | +| **Mamba Models** | 🟢 (Mamba-2), 🟡 (Mamba-1) | | **Multimodal Models** | 🟢 Functional | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. @@ -104,8 +104,16 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models -Models using selective state-space mechanisms instead of standard transformer attention (e.g., `MambaForCausalLM`, `JambaForCausalLM`) -will be supported via . +Models using selective state-space mechanisms instead of standard transformer attention are partially supported. +Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers +(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet suported. Please note that these models currently require +enforcing eager mode and disabling prefix caching in V1. + +Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that +these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention +backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible +using the `vllm serve` CLI yet). 
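+
+For example, a minimal offline-inference sketch under these constraints might look as follows (the model
+name, block size, and environment settings are illustrative only; the required non-standard block size
+depends on the model):
+
+```python
+import os
+
+# Select the V1 engine and the FlashInfer attention backend before vLLM is
+# imported (V1 may already be the default in your build).
+os.environ["VLLM_USE_V1"] = "1"
+os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
+
+from vllm import LLM, SamplingParams
+
+llm = LLM(
+    model="ibm-ai-platform/Bamba-9B",  # any hybrid Mamba-2 + attention model
+    enforce_eager=True,                # CUDA graphs are not supported for these models yet
+    enable_prefix_caching=False,       # prefix caching must be disabled
+    block_size=576,                    # illustrative non-standard attention block size
+)
+
+outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
+print(outputs[0].outputs[0].text)
+```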
#### Encoder-Decoder Models diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 84ad7a09165a..799648d3992e 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -145,7 +145,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "enable_thinking": enable_thinking } }) - + if enable_thinking: + assert chat_completion.choices[0].message.\ + reasoning_content is not None + assert chat_completion.choices[0].message.\ + reasoning_content != "" assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 else: diff --git a/tests/models/registry.py b/tests/models/registry.py index 48302f9d6648..10da077e5b5a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -141,6 +141,8 @@ def check_available_online( trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", @@ -412,7 +414,8 @@ def check_available_online( hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 + max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), @@ -500,4 +503,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: raise ValueError(f"No example model defined for {model_id}") -HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) \ No newline at end of file +HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/tests/utils.py b/tests/utils.py index a37872830dad..f4317e6bdb40 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -818,14 +818,15 @@ def create_new_process_for_each_test( Args: method: The process creation method. Can be either "spawn" or "fork". - If not specified, - it defaults to "spawn" on ROCm platforms and "fork" otherwise. + If not specified, it defaults to "spawn" on ROCm and XPU + platforms and "fork" otherwise. Returns: A decorator to run test functions in separate processes. 
""" if method is None: - method = "spawn" if current_platform.is_rocm() else "fork" + use_spawn = current_platform.is_rocm() or current_platform.is_xpu() + method = "spawn" if use_spawn else "fork" assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index 161bcd4d3ef9..f2f460513605 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -5,10 +5,10 @@ from vllm import LLM, SamplingParams -from ...utils import fork_new_process_for_each_test +from ...utils import create_new_process_for_each_test -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) def test_cascade_attention(example_system_message, monkeypatch, attn_backend): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a802fbc3865f..451241d3f9f7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1049,6 +1049,7 @@ async def chat_completion_full_generator( message = ChatMessage( role=role, content="", + reasoning_content=reasoning_content, tool_calls=[ tool_call_class(function=FunctionCall( name=tool_call.name, diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 6f11dcd19e9c..dec32f8e50fa 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -62,7 +62,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1 and not current_platform.is_xpu(): + if envs.VLLM_USE_V1: # V1 uses SPMD worker and compiled DAG os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d771a7a54cfc..de588d512739 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, return out.to(dtype=out_dtype) -def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, - w1: torch.Tensor, +def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor, w2: torch.Tensor) -> bool: - def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): - return M >= 128 and N % 128 == 0 and K % 128 == 0 + def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): + return N % 128 == 0 and K % 128 == 0 - m = hidden_states.size(0) _, K, N = w2.size() - if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): + if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K): logger.debug( "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") return False diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fbbccbb34d90..d0ff44a38a4a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1180,7 +1180,7 @@ def fused_experts( apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + and 
_valid_cutlass_block_scaled_grouped_gemm(w1, w2)): assert apply_router_weight_on_input is False return run_cutlass_block_scaled_fused_experts( a=hidden_states, diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py new file mode 100644 index 000000000000..a70c7300d29e --- /dev/null +++ b/vllm/model_executor/models/bailing_moe.py @@ -0,0 +1,540 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py +# Copyright 2023 The vLLM team. +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BailingMoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +KVCache = tuple[torch.Tensor, torch.Tensor] + + +class BailingAttention(nn.Module): + + def __init__( + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = 
config.num_attention_heads + self.total_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.total_num_heads % tp_size == 0 + assert self.total_kv_heads % tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = config.head_dim or (self.hidden_size // + self.total_num_heads) + self.q_size_per_rank = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // tp_size + self.kv_size_per_rank = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_kv_heads, + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + is_neox_style=True, + rope_scaling=config.rope_scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([ + self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank + ], + dim=-1) + + q, k = self.rotary_emb(position_ids, q, k) + + context_layer = self.attn( + q, + k, + v, + ) + + attn_output, _ = self.dense(context_layer) + return attn_output + + +class BailingMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BailingMoE(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_expert_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + if self.num_shared_experts > 0: + intermediate_size = (config.moe_intermediate_size * + self.num_shared_experts) + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + +class BailingMoeBlock(nn.Module): + + def __init__( + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.attention = BailingAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.post_attention_layernorm = RMSNorm(hidden_size, + eps=config.rms_norm_eps) + self.mlp = BailingMoE(intermediate_size, + config, + quant_config, + True, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.attention( + hidden_states=hidden_states, + position_ids=position_ids, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BailingMoeModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, self.embed_dim) + else: + self.word_embeddings = PPMissingLayer() + + self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BailingMoeBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + 
prefix=prefix, + ), + prefix=f"{prefix}.layers") + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + hidden_states, + position_ids, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + self.max_position_embeddings = config.max_position_embeddings + self.model = BailingMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = (self.word_embeddings if config.tie_word_embeddings + else ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config)) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = 
self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (("v_head" in name) or ("inv_freq" in name) or + (self.config.tie_word_embeddings and "lm_head" in name)): + continue + if self.config.norm_head and "lm_head.weight" in name: + import torch.nn.functional as F + loaded_weight = F.normalize(loaded_weight, + dim=0, + p=2, + eps=1e-7) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 27d476929855..867fd0bddc2f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -41,6 +41,7 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 734f1e09d0fd..68aa187a13b9 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -30,6 +31,7 @@ from 
vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ + "BailingMoeConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", diff --git a/vllm/transformers_utils/configs/bailing_moe.py b/vllm/transformers_utils/configs/bailing_moe.py new file mode 100644 index 000000000000..60315dc950be --- /dev/null +++ b/vllm/transformers_utils/configs/bailing_moe.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/configuration_bailing_moe.py +from transformers.configuration_utils import PretrainedConfig + + +class BailingMoeConfig(PretrainedConfig): + model_type = "bailing_moe" + + def __init__( + self, + vocab_size=30592, + hidden_size=1024, + intermediate_size=None, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=0, + hidden_act="silu", + use_qkv_bias=False, # bailing only + use_bias=True, # bailing only + rms_norm_eps=1e-05, + norm_head=False, # bailing only + tie_word_embeddings=False, # PretrainedConfig key, + # here change default value. + embedding_dropout=0.1, + attention_dropout=0.1, + output_dropout=0.1, + initializer_range=0.02, + max_position_embeddings=16384, + rope_theta=10000.0, + use_cache=True, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + rope_scaling=None, + pad_token_id=126081, + num_experts=16, + num_shared_experts=0, + num_experts_per_tok=2, + norm_topk_prob=True, + moe_intermediate_size=None, + first_k_dense_replace=0, + head_dim=None, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.use_qkv_bias = use_qkv_bias + self.use_bias = use_bias + self.norm_head = norm_head + self.rms_norm_eps = rms_norm_eps + self.embedding_dropout = embedding_dropout + self.attention_dropout = attention_dropout + self.output_dropout = output_dropout + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.use_cache = use_cache + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.head_dim = (head_dim if head_dim is not None else + self.hidden_size // self.num_attention_heads) + self.rope_scaling = rope_scaling + + # MoE configs + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + self.moe_intermediate_size = moe_intermediate_size + self.first_k_dense_replace = first_k_dense_replace + + super().__init__(pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index bfdbd682464a..cf7320a19e4d 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1535,6 +1535,13 @@ def cuda_is_initialized() -> bool: return torch.cuda.is_initialized() +def xpu_is_initialized() -> bool: + """Check if XPU is initialized.""" + if not torch.xpu._is_compiled(): + return False + return torch.xpu.is_initialized() + + def cuda_get_device_properties(device, names: Sequence[str], init_cuda=False) -> tuple[Any, ...]: @@ -2848,6 +2855,8 @@ def _maybe_force_spawn(): reason = None 
if cuda_is_initialized(): reason = "CUDA is initialized" + elif xpu_is_initialized(): + reason = "XPU is initialized" elif is_in_ray_actor(): # even if we choose to spawn, we need to pass the ray address # to the subprocess so that it knows how to connect to the ray cluster. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8658d7d916f0..ef03626cf14d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2219,8 +2219,8 @@ def profile_run(self) -> None: encoder_budget = min(self.max_num_encoder_input_tokens, self.encoder_cache_size) - max_num_mm_items_encoder_budget = cdiv(encoder_budget, - max_tokens_per_mm_item) + max_num_mm_items_encoder_budget = encoder_budget // \ + max_tokens_per_mm_item # Check how many items of this modality can be supported by # the decoder budget. @@ -2233,8 +2233,10 @@ def profile_run(self) -> None: max_num_mm_items_decoder_budget = self.max_num_reqs * \ max_mm_items_per_req - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) + max_num_mm_items = max( + 1, + min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget)) logger.info( "Encoder cache will be initialized with a budget of %s tokens," @@ -2244,7 +2246,7 @@ def profile_run(self) -> None: # Create dummy batch of multimodal inputs. dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=max_tokens_per_mm_item, mm_counts={ dummy_data_modality: 1 },
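To illustrate the profiling-budget arithmetic changed in the final hunk above (all numbers are hypothetical): switching from `cdiv` to floor division avoids over-counting encoder cache items, but it can round down to zero when a single multimodal item needs more tokens than the encoder budget, so the result is clamped to at least one dummy item.

```python
# Hypothetical numbers: one multimodal item needs more encoder tokens than the budget.
encoder_budget = 8192
max_tokens_per_mm_item = 16384
max_num_reqs = 256
max_mm_items_per_req = 1

# Floor division (instead of cdiv) no longer rounds up, but it can hit zero.
max_num_mm_items_encoder_budget = encoder_budget // max_tokens_per_mm_item   # 0
max_num_mm_items_decoder_budget = max_num_reqs * max_mm_items_per_req        # 256

# Clamp to at least one item so profiling still exercises the encoder path.
max_num_mm_items = max(1, min(max_num_mm_items_encoder_budget,
                              max_num_mm_items_decoder_budget))
print(max_num_mm_items)  # 1
```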