
Commit 14601f5

[Config] Refactor mistral configs (#20570)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 042d131 commit 14601f5

File tree

3 files changed: +167 additions, -113 deletions


vllm/model_executor/models/llama.py

Lines changed: 3 additions & 0 deletions
@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "qscale_act": "input_scale",
         "qscale_weight": "weight_scale",
         "kv_fake_quantizer.qscale_act": "kv_scale",
+        "q_fake_quantizer.qscale_act": "attn.q_scale",
+        "k_fake_quantizer.qscale_act": "k_scale",
+        "v_fake_quantizer.qscale_act": "v_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",

vllm/transformers_utils/config.py

Lines changed: 44 additions & 113 deletions
@@ -7,7 +7,7 @@
 import time
 from functools import cache, partial
 from pathlib import Path
-from typing import Any, Callable, Literal, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import huggingface_hub
 from huggingface_hub import get_safetensors_metadata, hf_hub_download
@@ -42,6 +42,7 @@
                                              SkyworkR1VChatConfig, SolarConfig,
                                              Telechat2Config, UltravoxConfig)
 # yapf: enable
+from vllm.transformers_utils.configs.mistral import adapt_config_dict
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import resolve_obj_by_qualname
 
@@ -394,7 +395,16 @@ def get_config(
         config = _maybe_remap_hf_config_attrs(config)
 
     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision, **kwargs)
+        # This function loads a params.json config which
+        # should be used when loading models in mistral format
+        config_dict = _download_mistral_config_file(model, revision)
+        if (max_position_embeddings :=
+                config_dict.get("max_position_embeddings")) is None:
+            max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
+                model, revision, **kwargs)
+            config_dict["max_position_embeddings"] = max_position_embeddings
+
+        config = adapt_config_dict(config_dict)
     else:
         supported_formats = [
             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
@@ -693,117 +703,6 @@ def _reduce_config(config: VllmConfig):
                      exc_info=e)
 
 
-def load_params_config(model: Union[str, Path], revision: Optional[str],
-                       **kwargs) -> PretrainedConfig:
-    # This function loads a params.json config which
-    # should be used when loading models in mistral format
-
-    config_file_name = "params.json"
-
-    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
-    if config_dict is None:
-        raise ValueError(
-            f"Failed to load mistral '{config_file_name}' config for model "
-            f"{model}. Please check if the model is a mistral-format model "
-            f"and if the config file exists.")
-    assert isinstance(config_dict, dict)
-
-    config_mapping = {
-        "dim": "hidden_size",
-        "norm_eps": "rms_norm_eps",
-        "n_kv_heads": "num_key_value_heads",
-        "n_layers": "num_hidden_layers",
-        "n_heads": "num_attention_heads",
-        "hidden_dim": "intermediate_size",
-    }
-
-    def recurse_elems(elem: Any):
-        if isinstance(elem, dict):
-            config_dict = {}
-            for key, value in elem.items():
-                key = config_mapping.get(key, key)
-                config_dict[key] = recurse_elems(value)
-
-            return config_dict
-        else:
-            return elem
-
-    config_dict["model_type"] = config_dict.get("model_type", "transformer")
-    config_dict["hidden_act"] = config_dict.get("activation", "silu")
-    config_dict["tie_word_embeddings"] = config_dict.get(
-        "tie_embeddings", False)
-
-    if config_dict.get("max_position_embeddings") is None:
-        max_position_embeddings = 128_000
-        try:
-            trust_remote_code_val = kwargs.get("trust_remote_code", False)
-            hf_config = get_config(model=model,
-                                   trust_remote_code=trust_remote_code_val,
-                                   revision=revision,
-                                   config_format=ConfigFormat.HF)
-            if hf_value := hf_config.get_text_config().max_position_embeddings:
-                max_position_embeddings = hf_value
-        except Exception as e:
-            logger.warning(
-                "The params.json file is missing 'max_position_embeddings'"
-                " and could not get a value from the HF config."
-                " Defaulting to 128000",
-                exc_info=e)
-        config_dict["max_position_embeddings"] = max_position_embeddings
-
-    if config_dict.get("quantization") is not None:
-        quantization = config_dict.get("quantization", {})
-        if quantization.get("qformat_weight") == "fp8_e4m3":
-            # This maps to the FP8 static per-tensor quantization scheme
-            quantization_config = {
-                "quant_method": "fp8",
-                "activation_scheme": "static"
-            }
-        elif quantization.get("quant_method") == "compressed-tensors":
-            # Pass through the quantization config to compressed-tensors
-            quantization_config = quantization
-        else:
-            raise ValueError(
-                f"Found unknown quantization='{quantization}' in config")
-
-        config_dict["quantization_config"] = quantization_config
-
-    config_type: Literal["text",
-                         "multimodal"] = "multimodal" if config_dict.get(
-                             "vision_encoder") is not None else "text"
-
-    if config_dict.get("moe") is not None:
-        config_dict["architectures"] = ["MixtralForCausalLM"]
-    else:
-        config_dict["architectures"] = ["MistralForCausalLM"]
-
-    if config_type == "multimodal":
-        multimodal_config = config_dict.pop("vision_encoder")
-        quantization_config = config_dict.get("quantization_config", {})
-
-        config_dict = {
-            "text_config": config_dict,
-            "vision_config": multimodal_config
-        }
-        config_dict["architectures"] = ["PixtralForConditionalGeneration"]
-        config_dict["model_type"] = "pixtral"
-        if quantization_config:
-            config_dict["quantization_config"] = quantization_config
-
-    config_dict.update(kwargs)
-
-    config_dict = recurse_elems(config_dict)
-
-    # transform to HF config format
-    if config_type == "multimodal":
-        config_dict["text_config"] = PretrainedConfig(
-            **config_dict["text_config"])
-        config_dict["vision_config"] = PretrainedConfig(
-            **config_dict["vision_config"])
-
-    return PretrainedConfig(**config_dict)
-
-
 def get_hf_image_processor_config(
     model: Union[str, Path],
     hf_token: Optional[Union[bool, str]] = None,
@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
         )
     except Exception:
         return None
+
+
+def _download_mistral_config_file(model, revision) -> dict:
+    config_file_name = "params.json"
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
+    if config_dict is None:
+        raise ValueError(
+            f"Failed to load mistral '{config_file_name}' config for model "
+            f"{model}. Please check if the model is a mistral-format model "
+            f"and if the config file exists.")
+    assert isinstance(config_dict, dict)
+    return config_dict
+
+
+def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
+    max_position_embeddings = 128_000
+    try:
+        trust_remote_code_val = kwargs.get("trust_remote_code", False)
+        hf_config = get_config(model=model,
+                               trust_remote_code=trust_remote_code_val,
+                               revision=revision,
+                               config_format=ConfigFormat.HF)
+        if hf_value := hf_config.get_text_config().max_position_embeddings:
+            max_position_embeddings = hf_value
+    except Exception as e:
+        logger.warning(
+            "The params.json file is missing 'max_position_embeddings'"
+            " and could not get a value from the HF config."
+            " Defaulting to 128000",
+            exc_info=e)
+
+    return max_position_embeddings
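
Editor's note: to make the refactored MISTRAL branch easier to follow in isolation, here is a self-contained sketch of the same flow using stand-in helpers. The download_params, retrieve_max_pos_fallback, and load_mistral_config names and the sample dict are hypothetical; the actual code calls _download_mistral_config_file, _maybe_retrieve_max_pos_from_hf, and adapt_config_dict as shown in the diff above.

from typing import Any


def download_params(model: str) -> dict[str, Any]:
    # In vLLM this reads params.json from the Hub; here we return a sample.
    return {"dim": 4096, "n_layers": 32, "n_heads": 32, "norm_eps": 1e-5}


def retrieve_max_pos_fallback(model: str) -> int:
    # In vLLM this tries the HF config first and defaults to 128_000.
    return 128_000


def load_mistral_config(model: str) -> dict[str, Any]:
    config_dict = download_params(model)
    # Backfill max_position_embeddings only when params.json omits it,
    # mirroring the walrus-operator check in the diff above.
    if (max_pos := config_dict.get("max_position_embeddings")) is None:
        max_pos = retrieve_max_pos_fallback(model)
        config_dict["max_position_embeddings"] = max_pos
    return config_dict


print(load_mistral_config("mistralai/some-model")["max_position_embeddings"])
# -> 128000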
vllm/transformers_utils/configs/mistral.py (new file)

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+from transformers import PretrainedConfig
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def adapt_config_dict(config_dict: dict[str, Any],
+                      **kwargs) -> PretrainedConfig:
+    config_dict.update(kwargs)
+    config_dict = _remap_general_mistral_args(config_dict)
+
+    if bool(config_dict.get("quantization")):
+        config_dict = _remap_mistral_quantization_args(config_dict)
+
+    if bool(config_dict.get("moe")):
+        config_dict["architectures"] = ["MixtralForCausalLM"]
+    else:
+        config_dict["architectures"] = ["MistralForCausalLM"]
+
+    if bool(config_dict.get("yarn")):
+        config_dict = _remap_mistral_yarn_args(config_dict)
+    if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
+            or config_dict.get("vision_encoder")):
+        config_dict = _remap_mistral_vision_args(config_dict)
+
+    config = PretrainedConfig.from_dict(config_dict)
+
+    logger.debug("Initialized config", config)
+
+    return config
+
+
+def _remap_mistral_vision_args(config: dict) -> dict:
+    if config.get("multimodal"):
+        vision_config = config.pop("multimodal")
+    else:
+        vision_config = config.pop("vision_encoder")
+
+    quant_config = config.get("quantization_config")
+    config = {
+        "model_type": "pixtral",
+        "architectures": ["PixtralForConditionalGeneration"],
+        "text_config": PretrainedConfig.from_dict(config),
+        "vision_config": PretrainedConfig.from_dict(vision_config),
+    }
+    if quant_config:
+        config["quantization_config"] = quant_config
+    return config
+
+
+def _remap_mistral_yarn_args(config: dict) -> dict:
+    # Direct remaps: yarn.X -> rope_scaling.Y
+    # Source keys are from mistral.model.args.YarnArgs
+    _map = {
+        "beta": "beta_fast",
+        "alpha": "beta_slow",
+    }
+    yarn_config = config.get("yarn") or {}
+    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
+    config["rope_scaling"] = {
+        "rope_type": "yarn",
+        "mscale_all_dim": 1,  # We hardcoded this to 1
+        **renamed_yarn_config
+    }
+    return config
+
+
+def _remap_general_mistral_args(config: dict) -> dict:
+    # Mistral key -> HF key
+    config_mapping = {
+        "dim": "hidden_size",
+        "norm_eps": "rms_norm_eps",
+        "n_kv_heads": "num_key_value_heads",
+        "n_layers": "num_hidden_layers",
+        "n_heads": "num_attention_heads",
+        "hidden_dim": "intermediate_size",
+    }
+    # HF key -> (Mistral key, default value)
+    top_level_mapping_with_default = {
+        "model_type": ("model_type", "transformer"),
+        "hidden_act": ("activation", "silu"),
+        "tie_word_embeddings": ("tied_embeddings", False),
+        "max_seq_len": ("max_seq_len", 128_000),
+        "max_position_embeddings": ("max_position_embeddings", 128_000),
+    }
+
+    for key, new_key in config_mapping.items():
+        if key in config:
+            config[new_key] = config.pop(key)
+
+    for new_key, (key,
+                  default_value) in top_level_mapping_with_default.items():
+        config[new_key] = config.pop(key, default_value)
+
+    return config
+
+
+def _remap_mistral_quantization_args(config: dict) -> dict:
+    quantization = config.get("quantization", {})
+    if quantization.get("qformat_weight") == "fp8_e4m3":
+        # This maps to the FP8 static per-tensor quantization scheme
+        quantization_config = {
+            "quant_method": "fp8",
+            "activation_scheme": "static"
+        }
+    elif quantization.get("quant_method") == "compressed-tensors":
+        # Pass through the quantization config to compressed-tensors
+        quantization_config = quantization
+    else:
+        raise ValueError(
+            f"Found unknown quantization='{quantization}' in config")
+
+    config["quantization_config"] = quantization_config
+
+    return config
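
Editor's note: as a rough illustration of what the general remapping in the new file produces, the sketch below applies the same mistral-to-HF key mapping and defaults to a made-up params.json-style dict. It is a standalone reimplementation for illustration, not an import of _remap_general_mistral_args.

# Standalone illustration of the key remapping performed by
# _remap_general_mistral_args above; the sample dict is hypothetical.
params = {
    "dim": 5120,
    "n_layers": 40,
    "n_heads": 32,
    "n_kv_heads": 8,
    "hidden_dim": 14336,
    "norm_eps": 1e-5,
    "activation": "silu",
}

MISTRAL_TO_HF = {
    "dim": "hidden_size",
    "norm_eps": "rms_norm_eps",
    "n_kv_heads": "num_key_value_heads",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "hidden_dim": "intermediate_size",
}

# Rename mistral keys to their HF equivalents, then fill common defaults.
hf_like = {MISTRAL_TO_HF.get(k, k): v for k, v in params.items()}
hf_like.setdefault("hidden_act", hf_like.pop("activation", "silu"))
hf_like.setdefault("model_type", "transformer")
hf_like.setdefault("max_position_embeddings", 128_000)

print(hf_like["hidden_size"], hf_like["num_hidden_layers"])
# -> 5120 40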
