Skip to content

[Feature] AutoModel can load components using model_index.json #11401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/diffusers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import os
from typing import Optional, Union

from huggingface_hub import constants, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args

from .. import pipelines
from ..configuration_utils import ConfigMixin
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates


class AutoModel(ConfigMixin):
Expand Down Expand Up @@ -156,12 +159,28 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
"subfolder": subfolder,
}

config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
try:
mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"}
mindex_kwargs["filename"] = "model_index.json"
config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs)
config = cls.load_config(config_path, **load_config_kwargs)
library, orig_class_name = config[subfolder]
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=pipelines,
is_pipeline_module=hasattr(pipelines, library),
component_name=subfolder,
cache_dir=constants.HF_HUB_CACHE,
)
except Exception:
# Fallback to loading the config from the config.json file and `diffusers` library
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
library = importlib.import_module("diffusers")
orig_class_name = config["_class_name"]
model_cls = getattr(library, orig_class_name, None)

library = importlib.import_module("diffusers")

model_cls = getattr(library, orig_class_name, None)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def maybe_raise_or_warn(


def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name, cache_dir
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)
Expand Down