Commit 68d5ec7

hmellor authored and minpeter committed

Improve Transformers backend model loading QoL (vllm-project#17039)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent a133146 commit 68d5ec7

File tree

  • vllm/model_executor/model_loader/utils.py

1 file changed: +8 -5 lines changed
vllm/model_executor/model_loader/utils.py

Lines changed: 8 additions & 5 deletions
@@ -55,7 +55,10 @@ def resolve_transformers_arch(model_config: ModelConfig,
     #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
     # },
     auto_modules = {
-        name: get_class_from_dynamic_module(module, model_config.model)
+        name:
+        get_class_from_dynamic_module(module,
+                                      model_config.model,
+                                      revision=model_config.revision)
         for name, module in sorted(auto_map.items(), key=lambda x: x[0])
     }
     custom_model_module = auto_modules.get("AutoModel")
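
This hunk threads the model revision through to Transformers' dynamic module loader, so custom modeling code is fetched from the same revision as the checkpoint rather than the repo's default branch. Below is a minimal sketch of the effect; get_class_from_dynamic_module and its revision parameter are real Transformers APIs, while the repo name, auto_map entry, and revision value are hypothetical examples, not from the commit:

from transformers.dynamic_module_utils import get_class_from_dynamic_module

# Hypothetical trust-remote-code repo and auto_map entry, for illustration only.
auto_map = {"AutoModel": "my-org/my-model--modeling_custom.MyModel"}

# Before the change, the remote class was resolved without a revision, i.e.
# from the default branch. After the change, the remote code is pinned to the
# same revision the weights are loaded from (model_config.revision).
model_cls = get_class_from_dynamic_module(
    auto_map["AutoModel"],
    "my-org/my-model",
    revision="v1.0",  # assumed stand-in for model_config.revision
)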
@@ -97,10 +100,10 @@ def get_model_architecture(
         architectures = ["QuantMixtralForCausalLM"]

     vllm_supported_archs = ModelRegistry.get_supported_archs()
-    is_vllm_supported = any(arch in vllm_supported_archs
-                            for arch in architectures)
-    if (not is_vllm_supported
-            or model_config.model_impl == ModelImpl.TRANSFORMERS):
+    vllm_not_supported = not any(arch in vllm_supported_archs
+                                 for arch in architectures)
+    if (model_config.model_impl == ModelImpl.TRANSFORMERS or
+            model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
         architectures = resolve_transformers_arch(model_config, architectures)

     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
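
This hunk changes when the Transformers backend is chosen: it is now used when explicitly requested, or as an automatic fallback when no architecture is natively supported and the user has not forced the vLLM implementation. Under the old condition, an unsupported architecture triggered the fallback even with model_impl forced to vLLM; now that case proceeds to the registry lookup instead. A self-contained sketch of the new condition follows; the ModelImpl members TRANSFORMERS and VLLM appear in the diff, while AUTO and the string values are assumptions for illustration:

from enum import Enum

class ModelImpl(str, Enum):
    # TRANSFORMERS and VLLM are referenced in the diff; AUTO is assumed.
    AUTO = "auto"
    VLLM = "vllm"
    TRANSFORMERS = "transformers"

def uses_transformers_backend(model_impl: ModelImpl,
                              vllm_not_supported: bool) -> bool:
    # Mirrors the condition introduced by the commit. `and` binds tighter
    # than `or`, so this reads: explicitly requested, OR (not forced to
    # vLLM AND no natively supported architecture).
    return (model_impl == ModelImpl.TRANSFORMERS or
            model_impl != ModelImpl.VLLM and vllm_not_supported)

# Explicit request always wins.
assert uses_transformers_backend(ModelImpl.TRANSFORMERS, False)
# Automatic fallback only when nothing is natively supported.
assert uses_transformers_backend(ModelImpl.AUTO, True)
assert not uses_transformers_backend(ModelImpl.AUTO, False)
# Forcing the vLLM implementation disables the fallback entirely.
assert not uses_transformers_backend(ModelImpl.VLLM, True)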
