-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[Feature] AutoModel can load components using model_index.json #11401
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
e506314
85024b0
6a0d0be
d86b0f2
314b6cc
528e002
6e92f40
76ea98d
0e53ad0
f697631
5614a15
f6b6b42
4e5cac1
24f16f6
684384c
0fe68cd
2950372
67e3404
694b81c
3bf51cd
13420fb
af007ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,15 +12,18 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
# See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
import importlib | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Optional, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
from huggingface_hub.utils import validate_hf_hub_args | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..configuration_utils import ConfigMixin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ..utils import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger = logging.get_logger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
class AutoModel(ConfigMixin): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
config_name = "config.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -153,15 +156,60 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi | |||||||||||||||||||||||||||||||||||||||||||||||||||||
"token": token, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"local_files_only": local_files_only, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"revision": revision, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"subfolder": subfolder, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
orig_class_name = config["_class_name"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
library = importlib.import_module("diffusers") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
library = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
orig_class_name = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from diffusers import pipelines | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Always attempt to fetch model_index.json first | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cls.config_name = "model_index.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if subfolder is not None and subfolder in config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
library, orig_class_name = config[subfolder] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
load_config_kwargs.update({"subfolder": subfolder}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
except EnvironmentError as e: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.debug(e) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Unable to load from model_index.json so fallback to loading from config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if library is None and orig_class_name is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cls.config_name = "config.json" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if "_class_name" in config: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# If we find a class name in the config, we can try to load the model as a diffusers model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
orig_class_name = config["_class_name"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
library = "diffusers" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
load_config_kwargs.update({"subfolder": subfolder}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# If we don't find a class name in the config, we can try to load the model as a transformers model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if "architectures" in config and len(config["architectures"]) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if len(config["architectures"]) > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
orig_class_name = config["architectures"][0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
library = "transformers" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Couldn't find model associated with the config file at {pretrained_model_or_path}." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks @yiyixuxu, using for example, AutoModelForCausalLM - helps us import architecture like Let me know if we want to ignore all the other mappings and just use AutoModel for import ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @DN6 can you also let me know do you think here? my opinion is that we should not implement our own logic to load transformer models (that's not part of a diffusers repo), so ok to either dispatch to their AutoModel or throw a warning for not supporting but open to other thoughts :)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I agree @yiyixuxu. For now let's keep the logic simple. We can refactor later if we need to include things like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made the change ! |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_cls, _ = get_class_obj_and_candidates( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
library_name=library, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
class_name=orig_class_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
importable_classes=ALL_IMPORTABLE_CLASSES, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
pipelines=pipelines, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
is_pipeline_module=hasattr(pipelines, library), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
model_cls = getattr(library, orig_class_name, None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if model_cls is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
|
||
from transformers import AlbertForMaskedLM, CLIPTextModel | ||
|
||
from diffusers.models import AutoModel, UNet2DConditionModel | ||
|
||
|
||
class TestAutoModel(unittest.TestCase): | ||
@patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}]) | ||
def test_load_from_config_diffusers_with_subfolder(self, mock_load_config): | ||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet") | ||
assert isinstance(model, UNet2DConditionModel) | ||
|
||
@patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"architectures": [ "CLIPTextModel"]}]) | ||
def test_load_from_config_transformers_with_subfolder(self, mock_load_config): | ||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") | ||
assert isinstance(model, CLIPTextModel) | ||
|
||
def test_load_from_config_without_subfolder(self): | ||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-albert") | ||
assert isinstance(model, AlbertForMaskedLM) | ||
|
||
def test_load_from_model_index(self): | ||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder") | ||
assert isinstance(model, CLIPTextModel) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would avoid using exceptions for control flow and simplify this a bit