7 | 7 | import time
8 | 8 | from functools import cache, partial
9 | 9 | from pathlib import Path
10 |    | -from typing import Any, Callable, Literal, Optional, TypeVar, Union
   | 10 | +from typing import Any, Callable, Optional, TypeVar, Union
11 | 11 |
12 | 12 | import huggingface_hub
13 | 13 | from huggingface_hub import get_safetensors_metadata, hf_hub_download

42 | 42 |                                              SkyworkR1VChatConfig, SolarConfig,
43 | 43 |                                              Telechat2Config, UltravoxConfig)
44 | 44 | # yapf: enable
   | 45 | +from vllm.transformers_utils.configs.mistral import adapt_config_dict
45 | 46 | from vllm.transformers_utils.utils import check_gguf_file
46 | 47 | from vllm.utils import resolve_obj_by_qualname
47 | 48 |
@@ -394,7 +395,16 @@ def get_config(
394 | 395 |         config = _maybe_remap_hf_config_attrs(config)
395 | 396 |
396 | 397 |     elif config_format == ConfigFormat.MISTRAL:
397 |     | -        config = load_params_config(model, revision, **kwargs)
    | 398 | +        # This function loads a params.json config which
    | 399 | +        # should be used when loading models in mistral format
    | 400 | +        config_dict = _download_mistral_config_file(model, revision)
    | 401 | +        if (max_position_embeddings :=
    | 402 | +                config_dict.get("max_position_embeddings")) is None:
    | 403 | +            max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
    | 404 | +                model, revision, **kwargs)
    | 405 | +        config_dict["max_position_embeddings"] = max_position_embeddings
    | 406 | +
    | 407 | +        config = adapt_config_dict(config_dict)
398 | 408 |     else:
399 | 409 |         supported_formats = [
400 | 410 |             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
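Note: a minimal sketch of how this new branch gets exercised. The call shape mirrors the `get_config(...)` call inside `_maybe_retrieve_max_pos_from_hf` added below; the repo id is only an illustrative example of a checkpoint that ships a mistral-style `params.json`, and the final attribute access assumes `adapt_config_dict` surfaces `max_position_embeddings` on the returned config.

```python
# Sketch only: drive the new ConfigFormat.MISTRAL branch of get_config.
# Assumes ConfigFormat and get_config are importable from this module and
# uses an illustrative repo id that ships a mistral-style params.json.
from vllm.transformers_utils.config import ConfigFormat, get_config

config = get_config(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative example
    trust_remote_code=False,
    revision=None,
    config_format=ConfigFormat.MISTRAL,
)
# max_position_embeddings is filled in by the branch above, either from
# params.json, from the HF config fallback, or from the 128_000 default.
print(config.max_position_embeddings)
```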
@@ -693,117 +703,6 @@ def _reduce_config(config: VllmConfig):
693 | 703 |                 exc_info=e)
694 | 704 |
695 | 705 |
696 |     | -def load_params_config(model: Union[str, Path], revision: Optional[str],
697 |     | -                       **kwargs) -> PretrainedConfig:
698 |     | -    # This function loads a params.json config which
699 |     | -    # should be used when loading models in mistral format
700 |     | -
701 |     | -    config_file_name = "params.json"
702 |     | -
703 |     | -    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
704 |     | -    if config_dict is None:
705 |     | -        raise ValueError(
706 |     | -            f"Failed to load mistral '{config_file_name}' config for model "
707 |     | -            f"{model}. Please check if the model is a mistral-format model "
708 |     | -            f"and if the config file exists.")
709 |     | -    assert isinstance(config_dict, dict)
710 |     | -
711 |     | -    config_mapping = {
712 |     | -        "dim": "hidden_size",
713 |     | -        "norm_eps": "rms_norm_eps",
714 |     | -        "n_kv_heads": "num_key_value_heads",
715 |     | -        "n_layers": "num_hidden_layers",
716 |     | -        "n_heads": "num_attention_heads",
717 |     | -        "hidden_dim": "intermediate_size",
718 |     | -    }
719 |     | -
720 |     | -    def recurse_elems(elem: Any):
721 |     | -        if isinstance(elem, dict):
722 |     | -            config_dict = {}
723 |     | -            for key, value in elem.items():
724 |     | -                key = config_mapping.get(key, key)
725 |     | -                config_dict[key] = recurse_elems(value)
726 |     | -
727 |     | -            return config_dict
728 |     | -        else:
729 |     | -            return elem
730 |     | -
731 |     | -    config_dict["model_type"] = config_dict.get("model_type", "transformer")
732 |     | -    config_dict["hidden_act"] = config_dict.get("activation", "silu")
733 |     | -    config_dict["tie_word_embeddings"] = config_dict.get(
734 |     | -        "tie_embeddings", False)
735 |     | -
736 |     | -    if config_dict.get("max_position_embeddings") is None:
737 |     | -        max_position_embeddings = 128_000
738 |     | -        try:
739 |     | -            trust_remote_code_val = kwargs.get("trust_remote_code", False)
740 |     | -            hf_config = get_config(model=model,
741 |     | -                                   trust_remote_code=trust_remote_code_val,
742 |     | -                                   revision=revision,
743 |     | -                                   config_format=ConfigFormat.HF)
744 |     | -            if hf_value := hf_config.get_text_config().max_position_embeddings:
745 |     | -                max_position_embeddings = hf_value
746 |     | -        except Exception as e:
747 |     | -            logger.warning(
748 |     | -                "The params.json file is missing 'max_position_embeddings'"
749 |     | -                " and could not get a value from the HF config."
750 |     | -                " Defaulting to 128000",
751 |     | -                exc_info=e)
752 |     | -        config_dict["max_position_embeddings"] = max_position_embeddings
753 |     | -
754 |     | -    if config_dict.get("quantization") is not None:
755 |     | -        quantization = config_dict.get("quantization", {})
756 |     | -        if quantization.get("qformat_weight") == "fp8_e4m3":
757 |     | -            # This maps to the FP8 static per-tensor quantization scheme
758 |     | -            quantization_config = {
759 |     | -                "quant_method": "fp8",
760 |     | -                "activation_scheme": "static"
761 |     | -            }
762 |     | -        elif quantization.get("quant_method") == "compressed-tensors":
763 |     | -            # Pass through the quantization config to compressed-tensors
764 |     | -            quantization_config = quantization
765 |     | -        else:
766 |     | -            raise ValueError(
767 |     | -                f"Found unknown quantization='{quantization}' in config")
768 |     | -
769 |     | -        config_dict["quantization_config"] = quantization_config
770 |     | -
771 |     | -    config_type: Literal["text",
772 |     | -                         "multimodal"] = "multimodal" if config_dict.get(
773 |     | -                             "vision_encoder") is not None else "text"
774 |     | -
775 |     | -    if config_dict.get("moe") is not None:
776 |     | -        config_dict["architectures"] = ["MixtralForCausalLM"]
777 |     | -    else:
778 |     | -        config_dict["architectures"] = ["MistralForCausalLM"]
779 |     | -
780 |     | -    if config_type == "multimodal":
781 |     | -        multimodal_config = config_dict.pop("vision_encoder")
782 |     | -        quantization_config = config_dict.get("quantization_config", {})
783 |     | -
784 |     | -        config_dict = {
785 |     | -            "text_config": config_dict,
786 |     | -            "vision_config": multimodal_config
787 |     | -        }
788 |     | -        config_dict["architectures"] = ["PixtralForConditionalGeneration"]
789 |     | -        config_dict["model_type"] = "pixtral"
790 |     | -        if quantization_config:
791 |     | -            config_dict["quantization_config"] = quantization_config
792 |     | -
793 |     | -    config_dict.update(kwargs)
794 |     | -
795 |     | -    config_dict = recurse_elems(config_dict)
796 |     | -
797 |     | -    # transform to HF config format
798 |     | -    if config_type == "multimodal":
799 |     | -        config_dict["text_config"] = PretrainedConfig(
800 |     | -            **config_dict["text_config"])
801 |     | -        config_dict["vision_config"] = PretrainedConfig(
802 |     | -            **config_dict["vision_config"])
803 |     | -
804 |     | -    return PretrainedConfig(**config_dict)
805 |     | -
806 |     | -
807 | 706 | def get_hf_image_processor_config(
808 | 707 |     model: Union[str, Path],
809 | 708 |     hf_token: Optional[Union[bool, str]] = None,
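For readers skimming the diff: the deleted `load_params_config` above is where `params.json` keys were renamed to HF attribute names. A condensed, self-contained sketch of that rename step is shown below; it is lifted from the deleted code, and `adapt_config_dict` (which presumably now owns this logic in `vllm/transformers_utils/configs/mistral.py`) is not shown in this diff and may differ.

```python
# Condensed from the deleted load_params_config above. The real
# adapt_config_dict in vllm.transformers_utils.configs.mistral is not part
# of this diff and may implement this differently.
from typing import Any

MISTRAL_TO_HF_KEYS = {
    "dim": "hidden_size",
    "norm_eps": "rms_norm_eps",
    "n_kv_heads": "num_key_value_heads",
    "n_layers": "num_hidden_layers",
    "n_heads": "num_attention_heads",
    "hidden_dim": "intermediate_size",
}


def rename_keys(elem: Any) -> Any:
    """Recursively rename mistral params.json keys to their HF equivalents."""
    if isinstance(elem, dict):
        return {MISTRAL_TO_HF_KEYS.get(k, k): rename_keys(v)
                for k, v in elem.items()}
    return elem
```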
@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
920 | 819 |         )
921 | 820 |     except Exception:
922 | 821 |         return None
    | 822 | +
    | 823 | +
    | 824 | +def _download_mistral_config_file(model, revision) -> dict:
    | 825 | +    config_file_name = "params.json"
    | 826 | +    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
    | 827 | +    if config_dict is None:
    | 828 | +        raise ValueError(
    | 829 | +            f"Failed to load mistral '{config_file_name}' config for model "
    | 830 | +            f"{model}. Please check if the model is a mistral-format model "
    | 831 | +            f"and if the config file exists.")
    | 832 | +    assert isinstance(config_dict, dict)
    | 833 | +    return config_dict
    | 834 | +
    | 835 | +
    | 836 | +def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
    | 837 | +    max_position_embeddings = 128_000
    | 838 | +    try:
    | 839 | +        trust_remote_code_val = kwargs.get("trust_remote_code", False)
    | 840 | +        hf_config = get_config(model=model,
    | 841 | +                               trust_remote_code=trust_remote_code_val,
    | 842 | +                               revision=revision,
    | 843 | +                               config_format=ConfigFormat.HF)
    | 844 | +        if hf_value := hf_config.get_text_config().max_position_embeddings:
    | 845 | +            max_position_embeddings = hf_value
    | 846 | +    except Exception as e:
    | 847 | +        logger.warning(
    | 848 | +            "The params.json file is missing 'max_position_embeddings'"
    | 849 | +            " and could not get a value from the HF config."
    | 850 | +            " Defaulting to 128000",
    | 851 | +            exc_info=e)
    | 852 | +
    | 853 | +    return max_position_embeddings
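For reference, a hypothetical `params.json` of the shape `_download_mistral_config_file` returns; the values are made up and only the key names come from the mapping in the deleted `load_params_config`. When `max_position_embeddings` is missing, the new branch in `get_config` calls `_maybe_retrieve_max_pos_from_hf`, which falls back to 128,000 if the HF config cannot be read either.

```python
# Hypothetical params.json contents (values are illustrative only).
params = {
    "dim": 4096,          # -> hidden_size
    "n_layers": 32,       # -> num_hidden_layers
    "n_heads": 32,        # -> num_attention_heads
    "n_kv_heads": 8,      # -> num_key_value_heads
    "hidden_dim": 14336,  # -> intermediate_size
    "norm_eps": 1e-5,     # -> rms_norm_eps
    # No "max_position_embeddings": get_config will consult the HF config
    # via _maybe_retrieve_max_pos_from_hf and default to 128_000 on failure.
}
```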