Commit 8bcc35c

Apply yapf formatting via pre-commit
1 parent c54a949 commit 8bcc35c

1 file changed: +54 -30 lines


vllm/model_executor/models/pixtral.py

Lines changed: 54 additions & 30 deletions
@@ -339,13 +339,14 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         raise ValueError("Only image modality is supported")

     packed_modules_mapping = {}
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config
-
+
         dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
         vision_args = {
             key: value
@@ -487,42 +488,66 @@ def compute_logits(

     # Reverse mapping from HF to original Pixtral format
     MISTRAL3_REVERSE_MAPPING = {
-        r"^language_model\.lm_head\.weight": r"output.weight",
-        r"^language_model\.model\.norm\.weight": r"norm.weight",
-        r"^language_model\.model\.embed_tokens\.weight": r"tok_embeddings.weight",
-        r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight": r"layers.\1.attention_norm.weight",
-        r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight": r"layers.\1.ffn_norm.weight",
-        r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight": r"layers.\1.attention.w\2.weight",
-        r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight": r"layers.\1.feed_forward.w1.weight",
-        r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight": r"layers.\1.feed_forward.w2.weight",
-        r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight": r"layers.\1.feed_forward.w3.weight",
-        r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight": r"vision_encoder.transformer.layers.\1.attention_norm.weight",
-        r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight": r"vision_encoder.transformer.layers.\1.ffn_norm.weight",
-        r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight": r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
-        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight",
-        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight",
-        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight": r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight",
-        r"^multi_modal_projector\.linear_1": r"vision_language_adapter.w_in",
-        r"^multi_modal_projector\.linear_2": r"vision_language_adapter.w_out",
-        r"^vision_tower\.ln_pre\.weight": r"vision_encoder.ln_pre.weight",
-        r"^vision_tower\.patch_conv\.weight": r"vision_encoder.patch_conv.weight",
-        r"^multi_modal_projector\.patch_merger\.merging_layer\.weight": r"patch_merger.merging_layer.weight",
-        r"^multi_modal_projector\.norm\.weight": r"pre_mm_projector_norm.weight",
-        r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$": r"layers.\1.\2.\3"
+        r"^language_model\.lm_head\.weight":
+        r"output.weight",
+        r"^language_model\.model\.norm\.weight":
+        r"norm.weight",
+        r"^language_model\.model\.embed_tokens\.weight":
+        r"tok_embeddings.weight",
+        r"^language_model\.model\.layers\.(\d+)\.input_layernorm\.weight":
+        r"layers.\1.attention_norm.weight",
+        r"^language_model\.model\.layers\.(\d+)\.post_attention_layernorm\.weight":
+        r"layers.\1.ffn_norm.weight",
+        r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight":
+        r"layers.\1.attention.w\2.weight",
+        r"^language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.weight":
+        r"layers.\1.feed_forward.w1.weight",
+        r"^language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.weight":
+        r"layers.\1.feed_forward.w2.weight",
+        r"^language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.weight":
+        r"layers.\1.feed_forward.w3.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.attention_norm\.weight":
+        r"vision_encoder.transformer.layers.\1.attention_norm.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.ffn_norm\.weight":
+        r"vision_encoder.transformer.layers.\1.ffn_norm.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.attention\.(q|k|v|o)_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.attention.w\2.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.gate_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.feed_forward.w1.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.down_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.feed_forward.w2.weight",
+        r"^vision_tower\.transformer\.layers\.(\d+)\.feed_forward\.up_proj\.weight":
+        r"vision_encoder.transformer.layers.\1.feed_forward.w3.weight",
+        r"^multi_modal_projector\.linear_1":
+        r"vision_language_adapter.w_in",
+        r"^multi_modal_projector\.linear_2":
+        r"vision_language_adapter.w_out",
+        r"^vision_tower\.ln_pre\.weight":
+        r"vision_encoder.ln_pre.weight",
+        r"^vision_tower\.patch_conv\.weight":
+        r"vision_encoder.patch_conv.weight",
+        r"^multi_modal_projector\.patch_merger\.merging_layer\.weight":
+        r"patch_merger.merging_layer.weight",
+        r"^multi_modal_projector\.norm\.weight":
+        r"pre_mm_projector_norm.weight",
+        r"^language_model\.model\.layers\.(\d+)\.(.+)\.(g_idx|zp|scales|zeros|qweight|qzeros)$":
+        r"layers.\1.\2.\3"
     }

-    def maybe_remap_mistral3(self, name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
+    def maybe_remap_mistral3(self, name: str,
+                             tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
         """Remap HF-style weight names back to original Pixtral format."""

         for pattern, replacement in self.MISTRAL3_REVERSE_MAPPING.items():
             new_name, n_replace = re.subn(pattern, replacement, name)
             if n_replace > 0:
-                logger.debug(f"remapped %s to %s for Pixtral compat", name, new_name)
+                logger.debug("remapped %s to %s for Pixtral compat", name,
+                             new_name)
                 return new_name, tensor
         return name, tensor  # Return unchanged if no match

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
+
         def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_encoder")

@@ -554,7 +579,8 @@ def inverse_permute_for_rope(tensor, n_heads, dim1, dim2):

         def llm_weights_generator():
             # Single pass over weights
-            remapped_weights = (self.maybe_remap_mistral3(name, w) for name, w in weights)
+            remapped_weights = (self.maybe_remap_mistral3(name, w)
+                                for name, w in weights)
             for name, w in remapped_weights:
                 if is_vision_encoder_weights((name, w)):
                     # Load vision encoder weights directly
@@ -565,9 +591,7 @@ def llm_weights_generator():
                     dim1 = param.shape[0]  # num_heads * head_dim
                     dim2 = param.shape[1]  # hidden_size
                     w = inverse_permute_for_rope(w, n_heads, dim1, dim2)
-                    logger.debug(
-                        "reversed permute_for_rope for %s", name
-                    )
+                    logger.debug("reversed permute_for_rope for %s", name)
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_patch_merger((name, w)):

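For readers skimming the diff, the behavior being reformatted here is the regex-based weight-name remapping. Below is a minimal standalone Python sketch, not part of the commit, that mirrors the re.subn loop from maybe_remap_mistral3 using two of the patterns copied from MISTRAL3_REVERSE_MAPPING; the sample weight names are invented for illustration only.

    import re

    # Two patterns copied from MISTRAL3_REVERSE_MAPPING in the diff above.
    REVERSE_MAPPING = {
        r"^language_model\.model\.layers\.(\d+)\.self_attn\.(q|k|v|o)_proj\.weight":
        r"layers.\1.attention.w\2.weight",
        r"^multi_modal_projector\.linear_1":
        r"vision_language_adapter.w_in",
    }

    def remap(name: str) -> str:
        # Mirrors the loop in maybe_remap_mistral3: first matching pattern wins.
        for pattern, replacement in REVERSE_MAPPING.items():
            new_name, n_replace = re.subn(pattern, replacement, name)
            if n_replace > 0:
                return new_name
        return name  # unchanged if nothing matches

    print(remap("language_model.model.layers.7.self_attn.q_proj.weight"))
    # -> layers.7.attention.wq.weight
    print(remap("multi_modal_projector.linear_1.weight"))
    # -> vision_language_adapter.w_in.weight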