-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
[Model] Re-add the implicit conversion feature for as_seq_cls_model #20930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
0090c4b
84ffda3
0542acb
0a2be97
980f877
b537c85
c917036
32f0e50
898ccc6
5213862
e4c026a
6afa49e
fb6485f
d28bddb
e0303da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -22,7 +22,8 @@ | |||||||||
QuantizationConfig, QuantizeMethodBase) | ||||||||||
from vllm.model_executor.models import ModelRegistry | ||||||||||
from vllm.model_executor.models.adapters import (as_embedding_model, | ||||||||||
as_reward_model) | ||||||||||
as_reward_model, | ||||||||||
as_seq_cls_model) | ||||||||||
from vllm.model_executor.models.interfaces import SupportsQuant | ||||||||||
from vllm.utils import is_pin_memory_available | ||||||||||
|
||||||||||
|
@@ -238,22 +239,41 @@ def get_model_architecture( | |||||||||
vllm_supported_archs = ModelRegistry.get_supported_archs() | ||||||||||
vllm_not_supported = not any(arch in vllm_supported_archs | ||||||||||
for arch in architectures) | ||||||||||
|
||||||||||
if vllm_not_supported: | ||||||||||
# try automatic conversion in adapters.py | ||||||||||
for arch in architectures: | ||||||||||
if not arch.endswith("ForSequenceClassification"): | ||||||||||
DarkLight1337 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
continue | ||||||||||
|
||||||||||
assert model_config.task in ["auto", "classify"] | ||||||||||
model_config.task = "classify" | ||||||||||
new_arch = arch.replace("ForSequenceClassification", "ForCausalLM") | ||||||||||
vllm_supported = not any(arch in vllm_supported_archs | ||||||||||
for arch in architectures) | ||||||||||
if vllm_supported: | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is something strange here: `vllm_supported` has the exact same definition as `vllm_not_supported`.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is indeed a problem with this logic; thank you for pointing it out. |
||||||||||
architectures = [new_arch] | ||||||||||
vllm_not_supported = False | ||||||||||
break | ||||||||||
|
||||||||||
if (model_config.model_impl == ModelImpl.TRANSFORMERS or | ||||||||||
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): | ||||||||||
architectures = resolve_transformers_arch(model_config, architectures) | ||||||||||
logger.debug_once("Resolve transformers arch %s", str(architectures)) | ||||||||||
elif (model_config.quantization is not None | ||||||||||
and model_config.quantization not in mixtral_supported | ||||||||||
and "MixtralForCausalLM" in architectures): | ||||||||||
architectures = ["QuantMixtralForCausalLM"] | ||||||||||
|
||||||||||
model_cls, arch = ModelRegistry.resolve_model_cls(architectures) | ||||||||||
if model_config.task == "embed": | ||||||||||
logger.debug_once("Automatic conversion using `as_embedding_model`.") | ||||||||||
model_cls = as_embedding_model(model_cls) | ||||||||||
elif model_config.task == "classify": | ||||||||||
# Cannot automatically run as_seq_cls_model, | ||||||||||
# otherwise it will cause a circular reference on is_cross_encoder_model | ||||||||||
pass | ||||||||||
logger.debug_once("Automatic conversion using `as_seq_cls_model`.") | ||||||||||
model_cls = as_seq_cls_model(model_cls) | ||||||||||
elif model_config.task == "reward": | ||||||||||
logger.debug_once("Automatic conversion using `as_reward_model`.") | ||||||||||
model_cls = as_reward_model(model_cls) | ||||||||||
|
||||||||||
return model_cls, arch | ||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.