[Feature] AutoModel can load components using model_index.json #11401

Open
wants to merge 13 commits into
base: main
from
29 changes: 24 additions & 5 deletions src/diffusers/models/auto_model.py
@@ -16,9 +16,12 @@
import os
from typing import Optional, Union

from huggingface_hub import constants, hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args

from .. import pipelines
Member

I think if we move this import inside the try block, we should be able to get rid of the circular import problem: https://github.com/huggingface/diffusers/actions/runs/14787632114/job/41518909088?pr=11401#step:15:68
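
A minimal sketch of the deferred-import pattern suggested above, using the standard-library json module as a stand-in for the pipelines module. The function name dump_config is hypothetical and not part of diffusers. The import statement runs when the function is called rather than when the enclosing module is loaded, which is what lets this pattern sidestep an import cycle.

```python
def dump_config(config: dict) -> str:
    """Serialize a config dict (hypothetical helper, for illustration only)."""
    # Deferred import: executed at call time, not when this module is imported.
    # In the PR this would be the import of `pipelines` inside from_pretrained.
    import json

    return json.dumps(config, sort_keys=True)
```

The trade-off is that an import failure surfaces on the first call instead of at module load time.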

from ..configuration_utils import ConfigMixin
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates


class AutoModel(ConfigMixin):
@@ -156,12 +159,28 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
"subfolder": subfolder,
}

config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
Collaborator

I would avoid using exceptions for control flow and simplify this a bit:

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        library = None
        orig_class_name = None
        # Imported here rather than at module level to avoid a circular import
        from diffusers import pipelines

        # Always attempt to fetch model_index.json first
        try:
            cls.config_name = "model_index.json"
            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

            if subfolder is not None and subfolder in config:
                library, orig_class_name = config[subfolder]

        except EntryNotFoundError as e:
            logger.debug(e)

        # Unable to load from model_index.json so fallback to loading from config
        if library is None and orig_class_name is None:
            cls.config_name = "config.json"
            load_config_kwargs.update({"subfolder": subfolder})

            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
            orig_class_name = config["_class_name"]
            library = "diffusers"

        model_cls, _ = get_class_obj_and_candidates(
            library_name=library,
            class_name=orig_class_name,
            importable_classes=ALL_IMPORTABLE_CLASSES,
            pipelines=pipelines,
            is_pipeline_module=hasattr(pipelines, library),
        )

orig_class_name = config["_class_name"]
try:
mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"}
Collaborator

Move this section under an if subfolder is not None:

if subfolder is not None:
     try: 
          ...

I think we are not supporting a local path for pretrained_model_or_path here, are we? It would be a bit more complex if we did.

Contributor Author

@ishan-modi ishan-modi May 2, 2025

What would be the complexity in supporting a local path here?

Shouldn't we also be able to load models from a repo structure with no folder hierarchy using this code? The following is a working example:

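import torch
from diffusers import AutoModel

try:
    # Flat repo (no folder hierarchy), so no subfolder is passed
    control_net = AutoModel.from_pretrained(
        "ishan24/Sana_600M_1024px_ControlNet_diffusers",
        torch_dtype=torch.float16,
    )
    print("test passed!")
except Exception as e:
    print(f"test failed: {e}")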

Collaborator

I think for a local path, you can append the subfolder to the path as well (we don't have to consider that for now),

like unet = AutoModel.from_pretrained("ishan24/SDXL_diffusers/unet")

Contributor Author

OK, so we don't want to support loading models from flat repos like this.

Because doing the following would mean that we don't support the above case:

if subfolder is not None:
     try: 
          ...

Collaborator

@yiyixuxu yiyixuxu May 2, 2025

Oh no, we need to support flat repos; we need to support all repos. (We don't need to support some edge cases with a local path where the user includes the subfolder directly in the file path rather than passing it as the subfolder argument.)

So basically, here:
if subfolder is not None -> we try the model_index.json approach first; if that fails, we try the config.json approach
if subfolder is None -> we try the config.json approach directly (I think it is not meaningful to use model_index.json when subfolder is None, because you need to use subfolder as the name to locate the info there, no?)
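
The two branches above can be sketched roughly as follows. This is only an illustration under simplifying assumptions: it reads from a local directory instead of the Hub, and resolve_component_class is a hypothetical helper, not a diffusers function. It assumes component entries in model_index.json have the form ["library", "ClassName"], as implied by `library, orig_class_name = config[subfolder]` in the diff.

```python
import json
import os
import tempfile
from typing import Optional, Tuple


def resolve_component_class(repo_dir: str, subfolder: Optional[str] = None) -> Tuple[str, str]:
    """Return (library, class_name) for a component in a local repo directory.

    Hypothetical helper for illustration; not part of diffusers.
    """
    if subfolder is not None:
        # Branch 1: look the component up by name in model_index.json.
        index_path = os.path.join(repo_dir, "model_index.json")
        if os.path.isfile(index_path):
            with open(index_path) as f:
                index = json.load(f)
            entry = index.get(subfolder)
            if isinstance(entry, list) and len(entry) == 2:
                return entry[0], entry[1]

    # Branch 2 (also the fallback): read config.json and assume `diffusers`.
    config_dir = repo_dir if subfolder is None else os.path.join(repo_dir, subfolder)
    with open(os.path.join(config_dir, "config.json")) as f:
        config = json.load(f)
    return "diffusers", config["_class_name"]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as repo:
        # Flat repo: only a root-level config.json, no model_index.json.
        with open(os.path.join(repo, "config.json"), "w") as f:
            json.dump({"_class_name": "SanaControlNetModel"}, f)
        print(resolve_component_class(repo))
```

Checking for the file and the key explicitly, rather than wrapping the lookup in a broad try/except, keeps the fallback from hiding unrelated failures.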

Contributor Author

got it, made the change

mindex_kwargs["filename"] = "model_index.json"
config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs)
config = cls.load_config(config_path, **load_config_kwargs)
library, orig_class_name = config[subfolder]
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=pipelines,
is_pipeline_module=hasattr(pipelines, library),
component_name=subfolder,
Collaborator

Suggested change
- component_name=subfolder,
+ component_name=None,

cache_dir=constants.HF_HUB_CACHE,
Collaborator

Suggested change
- cache_dir=constants.HF_HUB_CACHE,
+ cache_dir=None,

Let's just pass None here, since we don't actually have a cache_dir, so it's not very meaningful,
and update the other function so it works with cache_dir=None and component_name=None.

)
except Exception:
Member

We should catch the specific Exception here instead of making it generic. This will help eliminate other side-effects.
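
As a rough illustration of the point about narrow exception handling, the sketch below falls back to a default only for the two failure modes it expects and lets anything else propagate. The helper read_class_name is hypothetical and works on a JSON string rather than on a Hub download; which exception types the PR should actually catch is not settled in this thread.

```python
import json


def read_class_name(config_text: str, default: str = "UnknownModel") -> str:
    """Parse a JSON config string and return its "_class_name" (hypothetical helper)."""
    try:
        config = json.loads(config_text)
        return config["_class_name"]
    except (json.JSONDecodeError, KeyError):
        # Only malformed JSON or a missing key triggers the fallback.
        # Other errors (for example passing None) still surface to the caller.
        return default
```

With a bare `except Exception`, a programming error such as passing the wrong argument type would silently take the fallback path instead of being reported.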

# Fallback to loading the config from the config.json file and `diffusers` library
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
library = importlib.import_module("diffusers")
orig_class_name = config["_class_name"]
model_cls = getattr(library, orig_class_name, None)

library = importlib.import_module("diffusers")

model_cls = getattr(library, orig_class_name, None)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pipeline_loading_utils.py
@@ -332,7 +332,7 @@ def maybe_raise_or_warn(


def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name, cache_dir
Collaborator

umm I think we should update the code itself so that component_name and cache_dir can be None - they are only needed for custom code, which we don't support yet with AutoModel

Member

I just tried setting component_name and cache_dir to None values and ran the tests from #11401 (comment) and they ran fine for me.

Collaborator

What do you mean? The first line would already throw an error, no?
os.path.join(None, None)

Member

I meant when I changed to the original code:

library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None

It didn't error out for me 👀

Contributor Author

Oh, they should run fine even without it, because of this line

Collaborator

@yiyixuxu yiyixuxu May 2, 2025

Well, it is wrapped inside a try/except...
But what I meant is that we should update this function so that component_name and cache_dir are optional arguments (they are meant to be; they are only needed for custom code).
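
One possible shape for that change, sketched under assumptions: os.path.join(None, None) raises TypeError, so the join is skipped unless both values are provided. The helper name component_folder_or_none is hypothetical; the actual fix inside get_class_obj_and_candidates may look different.

```python
import os
from typing import Optional


def component_folder_or_none(
    cache_dir: Optional[str] = None, component_name: Optional[str] = None
) -> Optional[str]:
    """Build the component folder path only when both parts are given (hypothetical helper)."""
    if cache_dir is None or component_name is None:
        # No custom-code folder to look in; the caller should skip that lookup.
        return None
    return os.path.join(cache_dir, component_name)
```

A caller would then check the result for None before looking for custom code in that folder.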

):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)