From 8bbd7e56173ac43255e195449c38fddfb83f67d3 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Fri, 4 Jul 2025 07:21:51 +0000 Subject: [PATCH 01/17] Add Bailing_moe Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 520 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + .../transformers_utils/configs/bailing_moe.py | 76 +++ 4 files changed, 599 insertions(+) create mode 100644 vllm/model_executor/models/bailing_moe.py create mode 100644 vllm/transformers_utils/configs/bailing_moe.py diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py new file mode 100644 index 000000000000..7c0603983760 --- /dev/null +++ b/vllm/model_executor/models/bailing_moe.py @@ -0,0 +1,520 @@ +# coding=utf-8 +""" PyTorch Bailing model. """ + +from typing import Iterable, Optional, Tuple, Union, Set + +import torch +from torch import nn + +from vllm.model_executor.layers.activation import get_act_fn, SiluAndMul +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import fused_moe, FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.config import LoRAConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class BailingAttention(nn.Module): + + def __init__( + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.total_num_heads % tp_size == 0 + assert self.total_kv_heads % tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) + self.q_size_per_rank = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // tp_size + self.kv_size_per_rank = self.num_kv_heads * self.head_dim + + self.scale = self.head_dim 
** -0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_kv_heads, + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + + self.dense = RowParallelLinear(self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.dense",) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + is_neox_style=True, + rope_scaling=config.rope_scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split( + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], + dim=-1 + ) + + + q, k = self.rotary_emb(position_ids, q, k) + + context_layer = self.attn( + q, + k, + v, + ) + + attn_output, _ = self.dense(context_layer) + return attn_output + + +class BailingMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + +class BailingMoE(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_expert_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE( + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts" + ) + + if self.num_shared_experts > 0: + intermediate_size = (config.moe_intermediate_size * + self.num_shared_experts) + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts" + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + +class BailingMoeBlock(nn.Module): + + def __init__( + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.attention = BailingAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.mlp = BailingMoE(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.attention( + hidden_states=hidden_states, + position_ids=position_ids, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BailingMoeModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.word_embeddings = VocabParallelEmbedding(self.vocab_size, self.embed_dim) + else: + self.word_embeddings = PPMissingLayer() + + self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BailingMoeBlock( + config=config, + cache_config=cache_config, + quant_config=quant_config, + 
prefix=prefix, + ), + prefix=f"{prefix}.layers" + ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + hidden_states, + position_ids, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + self.max_position_embeddings = config.max_position_embeddings + self.model = BailingMoeModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = self.word_embeddings if config.tie_word_embeddings \ + else ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + 
sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if (("v_head" in name) or ("inv_freq" in name) or + (self.config.tie_word_embeddings and "lm_head" in name)): + continue + if self.config.norm_head and "lm_head.weight" in name: + import torch.nn.functional as F + loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b100fe77e377..c274df1b30aa 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -43,6 +43,7 @@ "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"), diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 734f1e09d0fd..458555f22fad 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -28,6 +28,7 @@ from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig __all__ = [ "ChatGLMConfig", @@ 
-54,4 +55,5 @@ "SolarConfig", "Telechat2Config", "UltravoxConfig", + "BailingMoeConfig", ] diff --git a/vllm/transformers_utils/configs/bailing_moe.py b/vllm/transformers_utils/configs/bailing_moe.py new file mode 100644 index 000000000000..8eed67147a93 --- /dev/null +++ b/vllm/transformers_utils/configs/bailing_moe.py @@ -0,0 +1,76 @@ +""" Bailing MoE model configuration """ + +from transformers.configuration_utils import PretrainedConfig + + +class BailingMoeConfig(PretrainedConfig): + model_type = "bailing_moe" + + def __init__( + self, + vocab_size=30592, + hidden_size=1024, + intermediate_size=None, + num_hidden_layers=24, + num_attention_heads=16, + num_key_value_heads=0, + hidden_act="silu", + use_qkv_bias=False, # bailing only + use_bias=True, # bailing only + rms_norm_eps=1e-05, + norm_head=False, # bailing only + tie_word_embeddings=False, # PretrainedConfig key, here change default value. + embedding_dropout=0.1, + attention_dropout=0.1, + output_dropout=0.1, + initializer_range=0.02, + max_position_embeddings=16384, + rope_theta=10000.0, + use_cache=True, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=28, + rope_scaling=None, + pad_token_id=126081, + num_experts=16, + num_shared_experts=0, + num_experts_per_tok=2, + norm_topk_prob=True, + moe_intermediate_size=None, + first_k_dense_replace=0, + head_dim=None, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.use_qkv_bias = use_qkv_bias + self.use_bias = use_bias + self.norm_head = norm_head + self.rms_norm_eps = rms_norm_eps + self.embedding_dropout = embedding_dropout + self.attention_dropout = attention_dropout + self.output_dropout = output_dropout + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.use_cache = use_cache + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.rope_scaling = rope_scaling + + # MoE configs + self.num_experts = num_experts + self.num_shared_experts = num_shared_experts + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + self.moe_intermediate_size = moe_intermediate_size + self.first_k_dense_replace = first_k_dense_replace + + super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) \ No newline at end of file From 42f5f31a69f524472033eb1bbd3a44e633f477d8 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Mon, 7 Jul 2025 08:56:07 +0000 Subject: [PATCH 02/17] fix based on response Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 161 ++++++++++-------- vllm/model_executor/models/registry.py | 2 +- vllm/transformers_utils/configs/__init__.py | 4 +- .../transformers_utils/configs/bailing_moe.py | 14 +- 4 files changed, 103 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 7c0603983760..ab6e7b9dd827 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -1,25 +1,46 @@ -# coding=utf-8 -""" PyTorch Bailing model. 
""" - -from typing import Iterable, Optional, Tuple, Union, Set +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py +# Copyright 2023 The vLLM team. +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BailingMoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union import torch from torch import nn -from vllm.model_executor.layers.activation import get_act_fn, SiluAndMul -from vllm.attention import Attention, AttentionMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.model_executor.layers.fused_moe import fused_moe, FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ReplicatedLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.distributed import (get_pp_group, @@ -28,12 +49,10 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.config import LoRAConfig from .interfaces import SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, @@ -42,17 +61,17 @@ make_layers, maybe_prefix) -KVCache = Tuple[torch.Tensor, torch.Tensor] +KVCache = tuple[torch.Tensor, torch.Tensor] class BailingAttention(nn.Module): def __init__( - self, - config: BailingMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -107,9 +126,9 @@ def __init__( ) def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) @@ -134,16 +153,17 @@ def forward( class BailingMLP(nn.Module): def __init__( - self, - intermediate_size: int, - config: BailingMoeConfig, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: Optional[bool] = True, - prefix: str = "", + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, [intermediate_size] * 2, + config.hidden_size, + [intermediate_size] * 2, bias=config.use_bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", @@ -167,12 +187,12 @@ def forward(self, x): class BailingMoE(nn.Module): def __init__( - self, - intermediate_size: int, - config: BailingMoeConfig, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: Optional[bool] = True, - prefix: str = "", + self, + intermediate_size: int, + config: BailingMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", ): super().__init__() @@ -234,11 +254,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BailingMoeBlock(nn.Module): def __init__( - self, - config: BailingMoeConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: BailingMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -252,10 +272,10 @@ def __init__( self.mlp = BailingMoE(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp") def forward( - self, - hidden_states: torch.Tensor, - position_ids: torch.Tensor, - residual: Optional[torch.Tensor], + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + residual: Optional[torch.Tensor], ) -> torch.Tensor: if residual is None: residual = hidden_states @@ -278,10 +298,10 @@ def forward( class BailingMoeModel(nn.Module): def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", + self, + *, + vllm_config: VllmConfig, + prefix: str = "", ): super().__init__() config = vllm_config.model_config.hf_config @@ -326,11 +346,11 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.word_embeddings(input_ids) def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -365,7 +385,6 @@ class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "query_key_value": ["query_key_value"], - "dense_h_to_4h": ["dense_h_to_4h"], "gate_up_proj": [ "gate_proj", "up_proj", @@ -376,8 +395,6 @@ class 
BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): supported_lora_modules = [ "query_key_value", "dense", - "dense_h_to_4h", - "dense_4h_to_h", "gate_up_proj", "down_proj", ] @@ -385,10 +402,10 @@ class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_padding_modules = [] def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", + self, + *, + vllm_config: VllmConfig, + prefix: str = "", ) -> None: super().__init__() @@ -420,33 +437,33 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, intermediate_tensors, inputs_embeds) return model_output def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -459,7 +476,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: num_experts=self.config.num_experts) params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() + loaded_params: set[str] = set() for name, loaded_weight in weights: if (("v_head" in name) or ("inv_freq" in name) or (self.config.tie_word_embeddings and "lm_head" in name)): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c274df1b30aa..2e6728ce2c69 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -41,9 +41,9 @@ "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), - "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), "CohereForCausalLM": ("commandr", "CohereForCausalLM"), diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 458555f22fad..68aa187a13b9 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from 
vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.cohere2 import Cohere2Config from vllm.transformers_utils.configs.dbrx import DbrxConfig @@ -28,9 +29,9 @@ from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.telechat2 import Telechat2Config from vllm.transformers_utils.configs.ultravox import UltravoxConfig -from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig __all__ = [ + "BailingMoeConfig", "ChatGLMConfig", "Cohere2Config", "DbrxConfig", @@ -55,5 +56,4 @@ "SolarConfig", "Telechat2Config", "UltravoxConfig", - "BailingMoeConfig", ] diff --git a/vllm/transformers_utils/configs/bailing_moe.py b/vllm/transformers_utils/configs/bailing_moe.py index 8eed67147a93..fd6260989d31 100644 --- a/vllm/transformers_utils/configs/bailing_moe.py +++ b/vllm/transformers_utils/configs/bailing_moe.py @@ -1,5 +1,8 @@ -""" Bailing MoE model configuration """ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/configuration_bailing_moe.py from transformers.configuration_utils import PretrainedConfig @@ -19,7 +22,8 @@ def __init__( use_bias=True, # bailing only rms_norm_eps=1e-05, norm_head=False, # bailing only - tie_word_embeddings=False, # PretrainedConfig key, here change default value. + tie_word_embeddings=False, # PretrainedConfig key, + # here change default value. embedding_dropout=0.1, attention_dropout=0.1, output_dropout=0.1, @@ -62,7 +66,11 @@ def __init__( self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window self.max_window_layers = max_window_layers - self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads + self.head_dim = ( + head_dim + if head_dim is not None + else self.hidden_size // self.num_attention_heads + ) self.rope_scaling = rope_scaling # MoE configs From 128f4cf6371f5a8a421ede48b833f58490972dc3 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Mon, 7 Jul 2025 09:25:54 +0000 Subject: [PATCH 03/17] Fix the response: E501 line too long Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index ab6e7b9dd827..94439f4b384a 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -422,8 +422,15 @@ def __init__( prefix=maybe_prefix(prefix, "model") ) if get_pp_group().is_last_rank: - self.lm_head = self.word_embeddings if config.tie_word_embeddings \ - else ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + self.lm_head = ( + self.word_embeddings + if config.tie_word_embeddings + else ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config + ) + ) self.logits_processor = LogitsProcessor(config.vocab_size) else: self.lm_head = PPMissingLayer() From 3173bc28bc43a1e2ce9db3a40de286b7bf1960af Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Mon, 7 Jul 2025 12:01:48 +0000 Subject: [PATCH 04/17] Adjust import order Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 38 +++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py 
b/vllm/model_executor/models/bailing_moe.py index 94439f4b384a..039c805ecd59 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -29,30 +29,30 @@ import torch from torch import nn -from vllm.model_executor.layers.activation import SiluAndMul from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - ReplicatedLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig -from vllm.model_executor.layers.logits_processor import LogitsProcessor from .interfaces import SupportsLoRA, SupportsPP from .utils import (PPMissingLayer, @@ -84,7 +84,8 @@ def __init__( assert self.total_num_heads >= self.total_kv_heads self.num_heads = self.total_num_heads // tp_size - self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) + self.head_dim = config.head_dim or (self.hidden_size // + self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads self.num_kv_heads = self.total_kv_heads // tp_size @@ -115,7 +116,6 @@ def __init__( cache_config=cache_config, prefix=f"{prefix}.attn") - self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, @@ -137,7 +137,6 @@ def forward( dim=-1 ) - q, k = self.rotary_emb(position_ids, q, k) context_layer = self.attn( @@ -162,7 +161,7 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, + config.hidden_size, [intermediate_size] * 2, bias=config.use_bias, quant_config=quant_config, @@ -184,6 +183,7 @@ def forward(self, x): x, _ = self.down_proj(x) return x + class BailingMoE(nn.Module): def __init__( @@ -265,9 +265,9 @@ def __init__( intermediate_size = config.intermediate_size self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) self.attention = BailingAttention(config, - cache_config, - quant_config, - prefix=f"{prefix}.attention") + cache_config, + quant_config, + prefix=f"{prefix}.attention") self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) self.mlp = BailingMoE(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp") @@ -299,7 +299,7 @@ class 
BailingMoeModel(nn.Module): def __init__( self, - *, + *, vllm_config: VllmConfig, prefix: str = "", ): @@ -451,7 +451,7 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + inputs_embeds) return model_output def compute_logits( @@ -486,7 +486,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: if (("v_head" in name) or ("inv_freq" in name) or - (self.config.tie_word_embeddings and "lm_head" in name)): + (self.config.tie_word_embeddings and "lm_head" in name)): continue if self.config.norm_head and "lm_head.weight" in name: import torch.nn.functional as F From db76d88c6244a3e0cc1c9f8b15531d8d40343365 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Mon, 7 Jul 2025 13:01:03 +0000 Subject: [PATCH 05/17] Fix minor formatting issues Signed-off-by: vito.yy --- vllm/model_executor/models/bailing_moe.py | 122 +++++++++--------- .../transformers_utils/configs/bailing_moe.py | 15 +-- 2 files changed, 66 insertions(+), 71 deletions(-) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 039c805ecd59..a70c7300d29e 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -31,8 +31,7 @@ from vllm.attention import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul @@ -42,7 +41,7 @@ QKVParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -55,10 +54,8 @@ from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig from .interfaces import SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, - make_layers, +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) KVCache = tuple[torch.Tensor, torch.Tensor] @@ -84,14 +81,13 @@ def __init__( assert self.total_num_heads >= self.total_kv_heads self.num_heads = self.total_num_heads // tp_size - self.head_dim = config.head_dim or (self.hidden_size // + self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.q_size_per_rank = self.head_dim * self.num_heads self.num_kv_heads = self.total_kv_heads // tp_size self.kv_size_per_rank = self.num_kv_heads * self.head_dim - - self.scale = self.head_dim ** -0.5 + self.scale = self.head_dim**-0.5 self.query_key_value = QKVParallelLinear( self.hidden_size, @@ -103,11 +99,13 @@ def __init__( prefix=f"{prefix}.query_key_value", ) - self.dense = RowParallelLinear(self.total_num_heads * self.head_dim, - self.hidden_size, - bias=config.use_bias, - quant_config=quant_config, - prefix=f"{prefix}.dense",) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=config.use_bias, 
+ quant_config=quant_config, + prefix=f"{prefix}.dense", + ) self.attn = Attention(self.num_heads, self.head_dim, @@ -132,10 +130,10 @@ def forward( ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) - q, k, v = qkv.split( - [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], - dim=-1 - ) + q, k, v = qkv.split([ + self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank + ], + dim=-1) q, k = self.rotary_emb(position_ids, q, k) @@ -210,16 +208,14 @@ def __init__( bias=False, quant_config=None) - self.experts = FusedMoE( - num_experts=self.num_experts, - top_k=self.top_k, - hidden_size=self.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=self.norm_expert_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts" - ) + self.experts = FusedMoE(num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") if self.num_shared_experts > 0: intermediate_size = (config.moe_intermediate_size * @@ -229,8 +225,7 @@ def __init__( config=config, quant_config=quant_config, reduce_results=False, - prefix=f"{prefix}.shared_experts" - ) + prefix=f"{prefix}.shared_experts") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -239,9 +234,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) if self.num_shared_experts > 0: final_hidden_states = final_hidden_states + shared_output @@ -251,6 +245,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states) return final_hidden_states.view(num_tokens, hidden_size) + class BailingMoeBlock(nn.Module): def __init__( @@ -268,8 +263,13 @@ def __init__( cache_config, quant_config, prefix=f"{prefix}.attention") - self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) - self.mlp = BailingMoE(intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp") + self.post_attention_layernorm = RMSNorm(hidden_size, + eps=config.rms_norm_eps) + self.mlp = BailingMoE(intermediate_size, + config, + quant_config, + True, + prefix=f"{prefix}.mlp") def forward( self, @@ -314,7 +314,8 @@ def __init__( if get_pp_group().is_first_rank or (config.tie_word_embeddings and get_pp_group().is_last_rank): - self.word_embeddings = VocabParallelEmbedding(self.vocab_size, self.embed_dim) + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, self.embed_dim) else: self.word_embeddings = PPMissingLayer() @@ -328,20 +329,17 @@ def __init__( quant_config=quant_config, prefix=prefix, ), - prefix=f"{prefix}.layers" - ) + prefix=f"{prefix}.layers") self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) - ) + ["hidden_states", "residual"], config.hidden_size)) if get_pp_group().is_last_rank: self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() - + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return 
self.word_embeddings(input_ids) @@ -417,28 +415,20 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config self.max_position_embeddings = config.max_position_embeddings - self.model = BailingMoeModel( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model") - ) + self.model = BailingMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: - self.lm_head = ( - self.word_embeddings - if config.tie_word_embeddings - else ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config - ) - ) + self.lm_head = (self.word_embeddings if config.tie_word_embeddings + else ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config)) self.logits_processor = LogitsProcessor(config.vocab_size) else: self.lm_head = PPMissingLayer() self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) + self.model.make_empty_intermediate_tensors) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -459,7 +449,8 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) return logits def sample( @@ -470,7 +461,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), @@ -490,7 +482,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue if self.config.norm_head and "lm_head.weight" in name: import torch.nn.functional as F - loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7) + loaded_weight = F.normalize(loaded_weight, + dim=0, + p=2, + eps=1e-7) for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -538,7 +533,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) - return loaded_params \ No newline at end of file + return loaded_params diff --git a/vllm/transformers_utils/configs/bailing_moe.py b/vllm/transformers_utils/configs/bailing_moe.py index fd6260989d31..60315dc950be 100644 --- a/vllm/transformers_utils/configs/bailing_moe.py +++ b/vllm/transformers_utils/configs/bailing_moe.py @@ -22,8 +22,8 @@ def __init__( use_bias=True, # bailing only rms_norm_eps=1e-05, norm_head=False, # bailing only - tie_word_embeddings=False, # PretrainedConfig key, - # here change default value. + tie_word_embeddings=False, # PretrainedConfig key, + # here change default value. 
embedding_dropout=0.1, attention_dropout=0.1, output_dropout=0.1, @@ -66,11 +66,8 @@ def __init__( self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window self.max_window_layers = max_window_layers - self.head_dim = ( - head_dim - if head_dim is not None - else self.hidden_size // self.num_attention_heads - ) + self.head_dim = (head_dim if head_dim is not None else + self.hidden_size // self.num_attention_heads) self.rope_scaling = rope_scaling # MoE configs @@ -81,4 +78,6 @@ def __init__( self.moe_intermediate_size = moe_intermediate_size self.first_k_dense_replace = first_k_dense_replace - super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) \ No newline at end of file + super().__init__(pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs) From adf9d271d82bca774a00ec74e13ef747fddeb9b9 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Wed, 9 Jul 2025 06:36:14 +0000 Subject: [PATCH 06/17] Add content to supported_models.md and test files Signed-off-by: vito.yy --- docs/models/supported_models.md | 3 ++- tests/models/registry.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index e003a3e31717..97e82d45e939 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -316,6 +316,7 @@ Specified using `--task generate`. | `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | @@ -738,4 +739,4 @@ We have the following levels of testing for models: 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. +4. **Community Feedback**: We rely on the community to provide feedback on the models. 
If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. \ No newline at end of file diff --git a/tests/models/registry.py b/tests/models/registry.py index 48302f9d6648..8f2917f4b764 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -141,6 +141,8 @@ def check_available_online( trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", From 05815862d890dd365caedf267140f541d3323d8e Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Wed, 9 Jul 2025 06:57:19 +0000 Subject: [PATCH 07/17] Small fix Signed-off-by: vito.yy --- docs/models/supported_models.md | 2 +- tests/models/registry.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 97e82d45e939..9935430ecb57 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -739,4 +739,4 @@ We have the following levels of testing for models: 1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. \ No newline at end of file +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. 
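For context, the documentation row and the test-registry entry above wire `inclusionAI/Ling-lite-1.5` to the new `BailingMoeForCausalLM` architecture with `trust_remote_code=True`. Below is a minimal smoke-test sketch of how that support could be exercised once this series is applied; the prompt, `max_model_len`, and sampling settings are illustrative choices, not values taken from the patches.

```python
# Hypothetical smoke test for the newly registered BailingMoeForCausalLM
# architecture. The model name and trust_remote_code flag come from the
# registry entry above; the prompt, max_model_len, and sampling settings
# are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(
    model="inclusionAI/Ling-lite-1.5",  # reference checkpoint used in tests/models/registry.py
    trust_remote_code=True,             # matches the _HfExamplesInfo entry
    max_model_len=4096,                 # illustrative; keeps the smoke test small
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Briefly introduce the Ling MoE model."], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```

Greedy decoding (`temperature=0.0`) keeps the output directly comparable with the HuggingFace implementation, in the spirit of the strict-consistency test level described in `supported_models.md` above.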
diff --git a/tests/models/registry.py b/tests/models/registry.py index 8f2917f4b764..7bad81c53eb5 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -142,7 +142,7 @@ def check_available_online( "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", - trust_remote_code=True), + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", @@ -502,4 +502,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: raise ValueError(f"No example model defined for {model_id}") -HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) \ No newline at end of file +HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) From bf58e570f6dbc9981582e0d3e939b6a61ad9052e Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Tue, 8 Jul 2025 20:03:35 -0700 Subject: [PATCH 08/17] [feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640) Signed-off-by: Duncan Moss --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 10 ++++------ vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d771a7a54cfc..de588d512739 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, return out.to(dtype=out_dtype) -def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, - w1: torch.Tensor, +def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor, w2: torch.Tensor) -> bool: - def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): - return M >= 128 and N % 128 == 0 and K % 128 == 0 + def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): + return N % 128 == 0 and K % 128 == 0 - m = hidden_states.size(0) _, K, N = w2.size() - if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): + if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K): logger.debug( "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") return False diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fbbccbb34d90..d0ff44a38a4a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1180,7 +1180,7 @@ def fused_experts( apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)): assert apply_router_weight_on_input is False return run_cutlass_block_scaled_fused_experts( a=hidden_states, From cdff58bd21d9697f1fecb9c019522c6d5c3c4640 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 9 Jul 2025 12:03:41 +0900 Subject: [PATCH 09/17] Fix bullets in incremental_build.md (#20642) --- docs/contributing/incremental_build.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md index 33584fdd5d40..5ac80fa66bf2 100644 --- 
a/docs/contributing/incremental_build.md +++ b/docs/contributing/incremental_build.md @@ -84,6 +84,7 @@ Below is an example of what the generated `CMakeUserPresets.json` might look lik ``` **What do the various configurations mean?** + - `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically. - `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default. - `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable. From 3e53b33945615befa0551f627d03990f3c3bbb3b Mon Sep 17 00:00:00 2001 From: B-201 Date: Wed, 9 Jul 2025 11:15:44 +0800 Subject: [PATCH 10/17] [Misc] Fix the size of batched_dummy_mm_inputs in profile_run (#20434) Signed-off-by: bk-201 --- tests/models/registry.py | 3 ++- vllm/v1/worker/gpu_model_runner.py | 12 +++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 7bad81c53eb5..10da077e5b5a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -414,7 +414,8 @@ def check_available_online( hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 + max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8658d7d916f0..ef03626cf14d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2219,8 +2219,8 @@ def profile_run(self) -> None: encoder_budget = min(self.max_num_encoder_input_tokens, self.encoder_cache_size) - max_num_mm_items_encoder_budget = cdiv(encoder_budget, - max_tokens_per_mm_item) + max_num_mm_items_encoder_budget = encoder_budget // \ + max_tokens_per_mm_item # Check how many items of this modality can be supported by # the decoder budget. @@ -2233,8 +2233,10 @@ def profile_run(self) -> None: max_num_mm_items_decoder_budget = self.max_num_reqs * \ max_mm_items_per_req - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) + max_num_mm_items = max( + 1, + min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget)) logger.info( "Encoder cache will be initialized with a budget of %s tokens," @@ -2244,7 +2246,7 @@ def profile_run(self) -> None: # Create dummy batch of multimodal inputs. 
dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=max_tokens_per_mm_item, mm_counts={ dummy_data_modality: 1 }, From 0100e50ee8f591c9aaf1e535e605a88b1e676967 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Wed, 9 Jul 2025 00:34:28 -0700 Subject: [PATCH 11/17] [XPU] Use spawn with XPU multiprocessing (#20649) Signed-off-by: Dmitry Rogozhkin --- tests/utils.py | 7 ++++--- tests/v1/e2e/test_cascade_attention.py | 4 ++-- vllm/utils/__init__.py | 9 +++++++++ 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index a37872830dad..f4317e6bdb40 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -818,14 +818,15 @@ def create_new_process_for_each_test( Args: method: The process creation method. Can be either "spawn" or "fork". - If not specified, - it defaults to "spawn" on ROCm platforms and "fork" otherwise. + If not specified, it defaults to "spawn" on ROCm and XPU + platforms and "fork" otherwise. Returns: A decorator to run test functions in separate processes. """ if method is None: - method = "spawn" if current_platform.is_rocm() else "fork" + use_spawn = current_platform.is_rocm() or current_platform.is_xpu() + method = "spawn" if use_spawn else "fork" assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index 161bcd4d3ef9..f2f460513605 100644 --- a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -5,10 +5,10 @@ from vllm import LLM, SamplingParams -from ...utils import fork_new_process_for_each_test +from ...utils import create_new_process_for_each_test -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) def test_cascade_attention(example_system_message, monkeypatch, attn_backend): diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index bfdbd682464a..cf7320a19e4d 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1535,6 +1535,13 @@ def cuda_is_initialized() -> bool: return torch.cuda.is_initialized() +def xpu_is_initialized() -> bool: + """Check if XPU is initialized.""" + if not torch.xpu._is_compiled(): + return False + return torch.xpu.is_initialized() + + def cuda_get_device_properties(device, names: Sequence[str], init_cuda=False) -> tuple[Any, ...]: @@ -2848,6 +2855,8 @@ def _maybe_force_spawn(): reason = None if cuda_is_initialized(): reason = "CUDA is initialized" + elif xpu_is_initialized(): + reason = "XPU is initialized" elif is_in_ray_actor(): # even if we choose to spawn, we need to pass the ray address # to the subprocess so that it knows how to connect to the ray cluster. From a95d0d1576ea1de275403377913c8bd31bb2cacd Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Wed, 9 Jul 2025 15:36:58 +0800 Subject: [PATCH 12/17] [Intel GPU] support ray as distributed executor backend for XPU. 
(#20659) Signed-off-by: Kunshang Ji --- .buildkite/scripts/hardware_ci/run-xpu-test.sh | 2 ++ vllm/executor/ray_distributed_executor.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index a23abdc1ed6c..7589b48b584d 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -27,6 +27,8 @@ docker run \ "${image_name}" \ sh -c ' VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp cd tests pytest -v -s v1/core ' diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 6f11dcd19e9c..dec32f8e50fa 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -62,7 +62,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1 and not current_platform.is_xpu(): + if envs.VLLM_USE_V1: # V1 uses SPMD worker and compiled DAG os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" From bce930dddb6af93afa37184f100b48baec899a83 Mon Sep 17 00:00:00 2001 From: qscqesze Date: Wed, 9 Jul 2025 15:37:07 +0800 Subject: [PATCH 13/17] [Docs] fix minimax tool_calling docs error (#20667) Signed-off-by: qingjun --- docs/features/tool_calling.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 13a8386a2971..c68b3aef5828 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -268,10 +268,10 @@ Flags: `--tool-call-parser hermes` Supported models: -* `MiniMaxAi/MiniMax-M1-40k` (use with ) -* `MiniMaxAi/MiniMax-M1-80k` (use with ) +* `MiniMaxAi/MiniMax-M1-40k` (use with ) +* `MiniMaxAi/MiniMax-M1-80k` (use with ) -Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax.jinja` +Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja` ### DeepSeek-V3 Models (`deepseek_v3`) From 63cfe24c8bfa5f02616cb398794346a80f60ff34 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Wed, 9 Jul 2025 15:39:58 +0800 Subject: [PATCH 14/17] [Bugfix] Fix the issue where `reasoning_content` is `None` when Thinking is enabled and `tool_choice` is set to `'required'`.
(#20662) Signed-off-by: chaunceyjiang --- .../openai/test_completion_with_function_calling.py | 6 +++++- vllm/entrypoints/openai/serving_chat.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 84ad7a09165a..799648d3992e 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -145,7 +145,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "enable_thinking": enable_thinking } }) - + if enable_thinking: + assert chat_completion.choices[0].message.\ + reasoning_content is not None + assert chat_completion.choices[0].message.\ + reasoning_content != "" assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 else: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a802fbc3865f..451241d3f9f7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1049,6 +1049,7 @@ async def chat_completion_full_generator( message = ChatMessage( role=role, content="", + reasoning_content=reasoning_content, tool_calls=[ tool_call_class(function=FunctionCall( name=tool_call.name, From d57795f55e478798734bebfa5ef0149a8a867284 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 9 Jul 2025 10:02:41 +0200 Subject: [PATCH 15/17] [V1] [Doc] Update V1 docs for Mamba models (#20499) Signed-off-by: Thomas Parnell Co-authored-by: Cyrus Leung --- docs/models/supported_models.md | 12 ++++++------ docs/usage/v1_guide.md | 14 +++++++++++--- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 9935430ecb57..2c0f0fc00f68 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -317,7 +317,7 @@ Specified using `--task generate`. | `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | | `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -333,7 +333,7 @@ Specified using `--task generate`. | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | | `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. 
| ✅︎ | ✅︎ | | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | @@ -346,7 +346,7 @@ Specified using `--task generate`. | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | @@ -358,14 +358,14 @@ Specified using `--task generate`. | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | -| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | +| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | | `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. 
| ✅︎ | ✅︎ | ✅︎ | | `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | | `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | | `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | @@ -390,7 +390,7 @@ Specified using `--task generate`. | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | | | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 8b50802e6a8e..459ea2d676c1 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | 🚀 Optimized | | **Encoder-Decoder Models** | 🟠 Delayed | | **Embedding Models** | 🟢 Functional | -| **Mamba Models** | 🚧 WIP () | +| **Mamba Models** | 🟢 (Mamba-2), 🟡 (Mamba-1) | | **Multimodal Models** | 🟢 Functional | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. @@ -104,8 +104,16 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models -Models using selective state-space mechanisms instead of standard transformer attention (e.g., `MambaForCausalLM`, `JambaForCausalLM`) -will be supported via . +Models using selective state-space mechanisms instead of standard transformer attention are partially supported. +Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers +(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require +enforcing eager mode and disabling prefix caching in V1. + +Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that +these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention +backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible +using the `vllm serve` CLI yet). #### Encoder-Decoder Models From 9143ce6ea3711861973a0167316e83ad721c5ae5 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Wed, 9 Jul 2025 06:36:14 +0000 Subject: [PATCH 16/17] Add content to supported_models.md and test files Signed-off-by: vito.yy --- docs/models/supported_models.md | 2 +- tests/models/registry.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 2c0f0fc00f68..576372ac0609 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -739,4 +739,4 @@ We have the following levels of testing for models: 1.
**Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test. 2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test. 3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test. -4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. +4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category. \ No newline at end of file diff --git a/tests/models/registry.py b/tests/models/registry.py index 10da077e5b5a..360c07e0f18e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -143,6 +143,8 @@ def check_available_online( trust_remote_code=True), "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", trust_remote_code=True), + "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", From 483fe2e02e3d1f8b049c476092f1806c39140361 Mon Sep 17 00:00:00 2001 From: "vito.yy" Date: Wed, 9 Jul 2025 09:58:49 +0000 Subject: [PATCH 17/17] Resolve conflicts Signed-off-by: vito.yy --- tests/models/registry.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 360c07e0f18e..10da077e5b5a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -143,8 +143,6 @@ def check_available_online( trust_remote_code=True), "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", trust_remote_code=True), - "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", - trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m",
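To smoke-test the `BailingMoeForCausalLM` architecture registered by PATCH 01/17 and PATCH 16/17–17/17 above, a minimal offline-inference sketch is shown below. The model name and `trust_remote_code` flag mirror the registry entry in `tests/models/registry.py`; the prompt and sampling settings are illustrative assumptions, not part of the patch series.

```python
from vllm import LLM, SamplingParams

# Minimal smoke test for the newly registered BailingMoeForCausalLM architecture.
# Assumes a GPU with enough memory for the checkpoint; trust_remote_code mirrors
# the _HfExamplesInfo entry added to tests/models/registry.py above.
llm = LLM(model="inclusionAI/Ling-lite-1.5", trust_remote_code=True)

# Greedy decoding keeps the output deterministic for a quick eyeball check.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```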
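Similarly, the `v1_guide.md` text added in PATCH 15/17 describes how Mamba-2 models must be run on the V1 engine (eager mode, no prefix caching; hybrid models additionally need the FlashInfer backend and a non-standard attention block size). A sketch of the simpler pure Mamba-2 case through the offline API follows; the model choice comes from the supported-models table, while the prompt, token budget, and the explicit V1 opt-in are assumptions.

```python
import os

# Opt in to the V1 engine explicitly; the Mamba-2 support described in
# PATCH 15/17 applies to V1 only. (Assumption: V1 is not already the default.)
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mamba-Codestral-7B-v0.1",  # a Mamba2ForCausalLM checkpoint
    enforce_eager=True,           # V1 currently requires eager mode for Mamba layers
    enable_prefix_caching=False,  # prefix caching must be disabled for these models
)
out = llm.generate(["def quicksort(arr):"],
                   SamplingParams(temperature=0.0, max_tokens=48))
print(out[0].outputs[0].text)
```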