Skip to content

[Feature] AutoModel can load components using model_index.json #11401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
40 changes: 34 additions & 6 deletions src/diffusers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import os
from typing import Optional, Union

from huggingface_hub.utils import validate_hf_hub_args
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args

from ..configuration_utils import ConfigMixin
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates


class AutoModel(ConfigMixin):
Expand Down Expand Up @@ -156,12 +158,38 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
"subfolder": subfolder,
}

config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would avoid using exceptions for control flow and simplify this a bit

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        library = None
        orig_class_name = None
        from diffusers import pipelines

        # Always attempt to fetch model_index.json first
        try:
            cls.config_name = "model_index.json"
            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

            if subfolder is not None and subfolder in config:
                library, orig_class_name = config[subfolder]

        except EntryNotFoundError as e:
            logger.debug(e)

        # Unable to load from model_index.json so fallback to loading from config
        if library is None and orig_class_name is None:
            cls.config_name = "config.json"
            load_config_kwargs.update({"subfolder": subfolder})

            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
            orig_class_name = config["_class_name"]
            library = "diffusers"

        model_cls, _ = get_class_obj_and_candidates(
            library_name=library,
            class_name=orig_class_name,
            importable_classes=ALL_IMPORTABLE_CLASSES,
            pipelines=pipelines,
            is_pipeline_module=hasattr(pipelines, library), 
        )

orig_class_name = config["_class_name"]
if subfolder is not None:
try:
# To avoid circular import problem.
from diffusers import pipelines

mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"}
mindex_kwargs["filename"] = "model_index.json"
config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs)
config = cls.load_config(config_path, **load_config_kwargs)
library, orig_class_name = config[subfolder]
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=pipelines,
is_pipeline_module=hasattr(pipelines, library),
)
except EntryNotFoundError:
# If `model_index.json` is not found, we load the model from the
# `config.json` file and `diffusers` library.
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
library = importlib.import_module("diffusers")
orig_class_name = config["_class_name"]
model_cls = getattr(library, orig_class_name, None)
else:
# If `subfolder` is not provided, we load the model from the
# `config.json` file and `diffusers` library.
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
library = importlib.import_module("diffusers")
orig_class_name = config["_class_name"]
model_cls = getattr(library, orig_class_name, None)

library = importlib.import_module("diffusers")

model_cls = getattr(library, orig_class_name, None)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,14 @@ def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)

class_obj = getattr(pipeline_module, class_name)
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
Expand Down
20 changes: 20 additions & 0 deletions tests/models/test_models_auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
from unittest.mock import patch

from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPTextModel

from diffusers.models import AutoModel, UNet2DConditionModel


class TestAutoModel(unittest.TestCase):
@patch("diffusers.models.auto_model.hf_hub_download", side_effect=EntryNotFoundError("File not found"))
def test_from_pretrained_falls_back_on_entry_error(self, mock_hf_hub_download):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
assert isinstance(model, UNet2DConditionModel)

def test_from_pretrained_loads_successfully(
self
):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
assert isinstance(model, CLIPTextModel)