Skip to content

[Feature] AutoModel can load components using model_index.json #11401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 55 additions & 7 deletions src/diffusers/models/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
from typing import Optional, Union

from huggingface_hub.utils import validate_hf_hub_args

from ..configuration_utils import ConfigMixin
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
from ..utils import logging


logger = logging.get_logger(__name__)

class AutoModel(ConfigMixin):
config_name = "config.json"

Expand Down Expand Up @@ -153,15 +156,60 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
"token": token,
"local_files_only": local_files_only,
"revision": revision,
"subfolder": subfolder,
}

config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would avoid using exceptions for control flow and simplify this a bit

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        library = None
        orig_class_name = None
        from diffusers import pipelines

        # Always attempt to fetch model_index.json first
        try:
            cls.config_name = "model_index.json"
            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

            if subfolder is not None and subfolder in config:
                library, orig_class_name = config[subfolder]

        except EntryNotFoundError as e:
            logger.debug(e)

        # Unable to load from model_index.json so fallback to loading from config
        if library is None and orig_class_name is None:
            cls.config_name = "config.json"
            load_config_kwargs.update({"subfolder": subfolder})

            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
            orig_class_name = config["_class_name"]
            library = "diffusers"

        model_cls, _ = get_class_obj_and_candidates(
            library_name=library,
            class_name=orig_class_name,
            importable_classes=ALL_IMPORTABLE_CLASSES,
            pipelines=pipelines,
            is_pipeline_module=hasattr(pipelines, library), 
        )

orig_class_name = config["_class_name"]

library = importlib.import_module("diffusers")
library = None
orig_class_name = None
from diffusers import pipelines

# Always attempt to fetch model_index.json first
try:
cls.config_name = "model_index.json"
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)

if subfolder is not None and subfolder in config:
library, orig_class_name = config[subfolder]
load_config_kwargs.update({"subfolder": subfolder})

except EnvironmentError as e:
logger.debug(e)

# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
cls.config_name = "config.json"
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)

if "_class_name" in config:
# If we find a class name in the config, we can try to load the model as a diffusers model
orig_class_name = config["_class_name"]
library = "diffusers"
load_config_kwargs.update({"subfolder": subfolder})
else:
# If we don't find a class name in the config, we can try to load the model as a transformers model
logger.warning(
f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model."
)
if "architectures" in config and len(config["architectures"]) > 0:
if len(config["architectures"]) > 1:
logger.warning(
f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}"
)
orig_class_name = config["architectures"][0]
library = "transformers"
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
else:
raise ValueError(
f"Couldn't find model associated with the config file at {pretrained_model_or_path}."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
else:
# If we don't find a class name in the config, we can try to load the model as a transformers model
logger.warning(
f"Doesn't look like a diffusers model. Loading {pretrained_model_or_path} as a transformer model."
)
if "architectures" in config and len(config["architectures"]) > 0:
if len(config["architectures"]) > 1:
logger.warning(
f"Found multiple architectures in {pretrained_model_or_path}. Using the first one: {config['architectures'][0]}"
)
orig_class_name = config["architectures"][0]
library = "transformers"
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
else:
raise ValueError(
f"Couldn't find model associated with the config file at {pretrained_model_or_path}."
)
elif "model_type" in config:
logger.warning(
f"Loading {config['model_type']} as a transformer model from {pretrained_model_or_path}."
)
from transformers import AutoModel
# we can use the AutoModel from transformers here I think?
....
else:
raise ValueError(...)

Copy link
Contributor Author

@ishan-modi ishan-modi May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @yiyixuxu, using AutoModel will allow us to only load from this mapping. There are many other mappings in the file that we might want to load from, hence I am using the architectures.

for example,

AutoModelForCausalLM - helps us import architecture like MarianForCausalLM while AutoModel doesn't

Let me know if we want to ignore all the other mappings and just use AutoModel for import ?

Copy link
Collaborator

@yiyixuxu yiyixuxu May 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @DN6 can you also let me know what you think here?

my opinion is that we should not implement our own logic to load transformer models (that's not part of a diffusers repo), so ok to either dispatch to their AutoModel or throw a warning for not supporting

but open to other thoughts :)

Copy link
Collaborator

@DN6 DN6 May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree @yiyixuxu. For now let's keep the logic simple. We can refactor later if we need to include things like AutoModelForCausalLM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the change !


model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=pipelines,
is_pipeline_module=hasattr(pipelines, library),
)

model_cls = getattr(library, orig_class_name, None)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,14 @@ def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name)
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)

class_obj = getattr(pipeline_module, class_name)
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
Expand Down
26 changes: 26 additions & 0 deletions tests/models/test_models_auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import unittest
from unittest.mock import patch

from transformers import AlbertForMaskedLM, CLIPTextModel

from diffusers.models import AutoModel, UNet2DConditionModel


class TestAutoModel(unittest.TestCase):
    """Tests for ``diffusers.models.AutoModel.from_pretrained`` resolution.

    ``from_pretrained`` first tries the repo-level ``model_index.json`` and,
    if that lookup fails, falls back to the per-subfolder ``config.json``.
    NOTE(review): except for the mocked config lookups below, these tests
    download real checkpoints from the Hugging Face Hub — they require
    network access (or a warm local cache) to pass.
    """

    # side_effect list: call 1 (model_index.json lookup) raises
    # EnvironmentError to force the fallback path; call 2 returns the
    # fallback config.json payload with a diffusers ``_class_name``.
    @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}])
    def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
        """Fallback config.json with ``_class_name`` resolves to a diffusers class."""
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
        assert isinstance(model, UNet2DConditionModel)

    # Same fallback path, but the config carries a transformers-style
    # ``architectures`` list instead of ``_class_name``, exercising the
    # transformers branch of the loader.
    @patch("diffusers.models.AutoModel.load_config", side_effect=[EnvironmentError("File not found"), {"architectures": [ "CLIPTextModel"]}])
    def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
        """Fallback config.json with ``architectures`` resolves to a transformers class."""
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
        assert isinstance(model, CLIPTextModel)

    def test_load_from_config_without_subfolder(self):
        """Unmocked load of a plain transformers repo (no subfolder, no model_index.json)."""
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-albert")
        assert isinstance(model, AlbertForMaskedLM)

    def test_load_from_model_index(self):
        """Unmocked load where the component class comes from model_index.json."""
        model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
        assert isinstance(model, CLIPTextModel)