diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 88c3823eaa19..7febe60795a3 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -120,13 +120,13 @@ def __init__(self,
         self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                              text_hidden_size,
                                              bias=multimodal_projector_bias,
-                                             quant_config=quant_config,
+                                             quant_config=None,
                                              prefix=f"{prefix}.linear_1")
         self.act = get_act_fn(projector_hidden_act)
         self.linear_2 = RowParallelLinear(text_hidden_size,
                                           text_hidden_size,
                                           bias=multimodal_projector_bias,
-                                          quant_config=quant_config,
+                                          quant_config=None,
                                           prefix=f"{prefix}.linear_2")
 
     def forward(self, image_features: torch.Tensor,
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 475d65a58b2a..ec0e40e29373 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import math
+import re
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, fields
 from functools import cached_property
@@ -25,6 +26,7 @@
 
 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -52,6 +54,8 @@
                     merge_multimodal_embeddings)
 from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
+logger = init_logger(__name__)
+
 try:
     from xformers import ops as xops
     USE_XFORMERS_OPS = True
@@ -334,6 +338,8 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
 
         raise ValueError("Only image modality is supported")
 
+    packed_modules_mapping = {}
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -480,6 +486,66 @@ def compute_logits(
         return self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
 
+    # Reverse mapping from HF to original Pixtral format
+    MISTRAL3_REVERSE_MAPPING = {
+        r"^language_model\.lm_head\.weight":
+        r"output.weight",
+        r"^language_model\.model\.norm\.weight":
+        r"norm.weight",
+        r"^language_model\.model\.embed_tokens\.weight":
+        r"tok_embeddings.weight",
+        r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight":
+        r"layers.\1.attention_norm.weight",
+        r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight":
+        r"layers.\1.ffn_norm.weight",
+        r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight":
+        r"layers.\1.attention.w\2.weight",
+        r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight":
+        r"layers.\1.feed_forward.w1.weight",
+        r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight":
+        r"layers.\1.feed_forward.w2.weight",
+        r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight":
+        r"layers.\1.feed_forward.w3.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight":
+        r"vision_encoder.transformer.layers.\1.attention_norm.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight":
+        r"vision_encoder.transformer.layers.\1.ffn_norm.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight",
+        r"^multi_modal_projector\.linear_1":
+        r"vision_language_adapter.w_in",
+        r"^multi_modal_projector\.linear_2":
+        r"vision_language_adapter.w_out",
+        r"^vision_tower\.ln_pre\.weight":
+        r"vision_encoder.ln_pre.weight",
+        r"^vision_tower\.patch_conv\.weight":
+        r"vision_encoder.patch_conv.weight",
+        r"^multi_modal_projector\.patch_merger\.merging_layer\.weight":
+        r"patch_merger.merging_layer.weight",
+        r"^multi_modal_projector\.norm\.weight":
+        r"pre_mm_projector_norm.weight",
+        r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$":
+        r"layers.\1.\2.\3"
+    }
+
+    def maybe_remap_mistral3(self, name: str,
+                             tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
+        """Remap HF-style weight names back to original Pixtral format."""
+
+        for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items():
+            new_name, n_replace = re.subn(pattern, replacement, name)
+            if n_replace > 0:
+                logger.debug("remapped %s to %s for Pixtral compat", name,
+                             new_name)
+                return new_name, tensor
+        return name, tensor  # Return unchanged if no match
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
 
         def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
@@ -504,13 +570,28 @@ def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
         vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
 
+        def inverse_permute_for_rope(tensor, n_heads, dim1, dim2):
+            """Reverse the permutation applied for ROPE in HF format."""
+            tensor = tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+            tensor = tensor.transpose(1, 2)
+            tensor = tensor.reshape(dim1, dim2)
+            return tensor
+
         def llm_weights_generator():
             # Single pass over weights
-            for name, w in weights:
+            remapped_weights = (self.maybe_remap_mistral3(name, w)
+                                for name, w in weights)
+            for name, w in remapped_weights:
                 if is_vision_encoder_weights((name, w)):
                     # Load vision encoder weights directly
                     trimmed_name = '.'.join(name.split(".")[1:])
                     param = vision_encoder_dict[trimmed_name]
+                    if "wq.weight" in trimmed_name or "wk.weight" in trimmed_name:
+                        n_heads = self.vision_args.num_attention_heads
+                        dim1 = param.shape[0]  # num_heads * head_dim
+                        dim2 = param.shape[1]  # hidden_size
+                        w = inverse_permute_for_rope(w, n_heads, dim1, dim2)
+                        logger.debug("reversed permute_for_rope for %s", name)
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_patch_merger((name, w)):
@@ -554,7 +635,7 @@ class VisionEncoderArgs:
     image_token_id: int
    adapter_bias: bool = True
     spatial_merge_size: int = 1
-    add_pre_mm_projector_layer_norm: bool = False
+    add_pre_mm_projector_layer_norm: bool = True
     mm_projector_id: str = ""
 
 
@@ -847,9 +928,10 @@ def __init__(
         super().__init__()
 
         mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
-
         self.spatial_merge_size = spatial_merge_size
         self.mlp_input_dim = mlp_input_dim
+        logger.debug("mlp_input_dim = %d (from %d * (%d ** 2))", mlp_input_dim,
+                     vision_encoder_dim, spatial_merge_size)
 
         self.merging_layer = nn.Linear(
             mlp_input_dim,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b100fe77e377..717b2304c275 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -216,6 +216,7 @@
     "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
+    "Mistral3ForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
     "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"),  # noqa: E501
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),  # noqa: E501
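
Note (reviewer aid, not part of the patch): a minimal standalone sketch of the HF-to-Pixtral name remapping added in pixtral.py. The two-entry `REVERSE_MAPPING` and the `remap()` helper below are trimmed, hypothetical stand-ins for `MISTRAL3_REVERSE_MAPPING` and `maybe_remap_mistral3`; only the first-match `re.subn` mechanics are the same.

```python
import re

# Trimmed subset of the reverse mapping: one LLM pattern, one vision pattern.
REVERSE_MAPPING = {
    r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight":
    r"layers.\1.attention.w\2.weight",
    r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight":
    r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
}


def remap(name: str) -> str:
    # First matching pattern wins, mirroring maybe_remap_mistral3.
    for pattern, replacement in REVERSE_MAPPING.items():
        new_name, n_replace = re.subn(pattern, replacement, name)
        if n_replace > 0:
            return new_name
    return name  # unchanged if nothing matches


assert remap("language_model.model.layers.0.self_attn.q_proj.weight") == \
    "layers.0.attention.wq.weight"
assert remap("vision_tower.transformer.layers.3.attention.k_proj.weight") == \
    "vision_encoder.transformer.layers.3.attention.wk.weight"
```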
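
Note (reviewer aid, not part of the patch): a round-trip check for the RoPE weight unpermutation applied to vision `wq`/`wk` in `load_weights`. `permute_for_rope` below is a hypothetical forward counterpart, written only so that the patch's inverse maps it back to the identity; the 16-head, 1024x1024 shape is arbitrary.

```python
import torch


def permute_for_rope(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Hypothetical forward permutation: per head, regroup rows from
    # interleaved rotary pairs (half_dim, 2) into split halves (2, half_dim).
    dim1, dim2 = w.shape
    return (w.view(n_heads, dim1 // n_heads // 2, 2, dim2)
            .transpose(1, 2).reshape(dim1, dim2))


def inverse_permute_for_rope(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Same body as the nested helper in load_weights, with dims read from w.
    dim1, dim2 = w.shape
    return (w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            .transpose(1, 2).reshape(dim1, dim2))


w = torch.randn(1024, 1024)  # e.g. 16 heads x 64 head_dim rows, hidden_size cols
assert torch.equal(inverse_permute_for_rope(permute_for_rope(w, 16), 16), w)
```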