[Feature] AutoModel can load components using model_index.json #11401
base: main
Changes from 9 commits
e506314
85024b0
6a0d0be
d86b0f2
314b6cc
528e002
6e92f40
76ea98d
0e53ad0
f697631
5614a15
f6b6b42
4e5cac1
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -16,9 +16,12 @@ | |||||
import os | ||||||
from typing import Optional, Union | ||||||
|
||||||
from huggingface_hub import constants, hf_hub_download | ||||||
from huggingface_hub.utils import validate_hf_hub_args | ||||||
|
||||||
from .. import pipelines | ||||||
from ..configuration_utils import ConfigMixin | ||||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates | ||||||
|
||||||
|
||||||
class AutoModel(ConfigMixin): | ||||||
|
@@ -156,12 +159,28 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi | |||||
"subfolder": subfolder, | ||||||
} | ||||||
|
||||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||||||
I would avoid using exceptions for control flow and simplify this a bit:

    load_config_kwargs = {
        "cache_dir": cache_dir,
        "force_download": force_download,
        "proxies": proxies,
        "token": token,
        "local_files_only": local_files_only,
        "revision": revision,
    }
    library = None
    orig_class_name = None
    from diffusers import pipelines
    # Always attempt to fetch model_index.json first
    try:
        cls.config_name = "model_index.json"
        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        if subfolder is not None and subfolder in config:
            library, orig_class_name = config[subfolder]
    except EntryNotFoundError as e:
        logger.debug(e)
    # Unable to load from model_index.json so fallback to loading from config
    if library is None and orig_class_name is None:
        cls.config_name = "config.json"
        load_config_kwargs.update({"subfolder": subfolder})
        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
        orig_class_name = config["_class_name"]
        library = "diffusers"
    model_cls, _ = get_class_obj_and_candidates(
        library_name=library,
        class_name=orig_class_name,
        importable_classes=ALL_IMPORTABLE_CLASSES,
        pipelines=pipelines,
        is_pipeline_module=hasattr(pipelines, library),
    )
||||||
orig_class_name = config["_class_name"] | ||||||
try: | ||||||
mindex_kwargs = {k: v for k, v in load_config_kwargs.items() if k != "subfolder"} | ||||||
Move this section under a

    if subfolder is not None:
        try:
            ...

I think we are not supporting local path for

What would be the complexity of supporting a local path here? Shouldn't we also be able to load models from a structure with no folder hierarchy using this code? The following is a working example:

    try:
        control_net = AutoModel.from_pretrained(
            "ishan24/Sana_600M_1024px_ControlNet_diffusers",
            torch_dtype=torch.float16
        )
        print(f"test passed!")
    except Exception as e:
        print(f"test failed: {e}")

I think for a local path, you can append the subfolder to the path as well (we don't have to consider that for now), like

    unet = AutoModel.from_pretrained("ishan24/SDXL_diffusers/unet")

OK, so we don't want to support loading models from flat repos like this? Doing the following would mean that we don't support the above case:

    if subfolder is not None:
        try:
            ...

Oh no, we need to support flat repos; we need to support all repos (we don't need to support some edge cases with local paths where the user includes the subfolder directly in the file path, not as so basically here

Got it, made the change.
||||||
mindex_kwargs["filename"] = "model_index.json" | ||||||
config_path = hf_hub_download(pretrained_model_or_path, **mindex_kwargs) | ||||||
config = cls.load_config(config_path, **load_config_kwargs) | ||||||
library, orig_class_name = config[subfolder] | ||||||
model_cls, _ = get_class_obj_and_candidates( | ||||||
library_name=library, | ||||||
class_name=orig_class_name, | ||||||
importable_classes=ALL_IMPORTABLE_CLASSES, | ||||||
pipelines=pipelines, | ||||||
is_pipeline_module=hasattr(pipelines, library), | ||||||
component_name=subfolder, | ||||||
Suggested change
|
||||||
cache_dir=constants.HF_HUB_CACHE, | ||||||
Suggested change
Let's just pass None here, since we don't actually have a cache_dir, so it is not very meaningful.
||||||
) | ||||||
except Exception: | ||||||
We should catch the specific exception here instead of a generic one. This will help eliminate other side effects.
||||||
# Fallback to loading the config from the config.json file and `diffusers` library | ||||||
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) | ||||||
library = importlib.import_module("diffusers") | ||||||
orig_class_name = config["_class_name"] | ||||||
model_cls = getattr(library, orig_class_name, None) | ||||||
|
||||||
library = importlib.import_module("diffusers") | ||||||
|
||||||
model_cls = getattr(library, orig_class_name, None) | ||||||
if model_cls is None: | ||||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") | ||||||
|
||||||
|
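The lookup order discussed in the review (consult model_index.json for the requested subfolder first, then fall back to the component's own config.json and the `diffusers` library) can be sketched without touching the Hub. This is only an illustration under stated assumptions: `resolve_component` and `_write_json` are hypothetical helpers, the repos are temporary local folders, and the class names in the JSON files are sample data. It uses a file-existence check rather than a try/except, in the spirit of the "avoid exceptions for control flow" comment, and it covers both a nested repo and a flat repo.

```python
import json
import os
import tempfile


def resolve_component(repo_dir, subfolder=None):
    """Hypothetical helper: return (library, class_name) for a component.

    Checks model_index.json first, then falls back to config.json.
    """
    library, class_name = None, None

    # Step 1: look the subfolder up in model_index.json, if both exist.
    index_path = os.path.join(repo_dir, "model_index.json")
    if subfolder is not None and os.path.isfile(index_path):
        with open(index_path) as f:
            index = json.load(f)
        if subfolder in index:
            library, class_name = index[subfolder]

    # Step 2: fall back to config.json (flat repo or missing index entry).
    if library is None and class_name is None:
        config_dir = repo_dir if subfolder is None else os.path.join(repo_dir, subfolder)
        with open(os.path.join(config_dir, "config.json")) as f:
            config = json.load(f)
        library, class_name = "diffusers", config["_class_name"]

    return library, class_name


def _write_json(path, data):
    """Hypothetical helper: write a JSON file, creating parent folders."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(data, f)


with tempfile.TemporaryDirectory() as root:
    # Nested repo: components are listed in model_index.json.
    nested = os.path.join(root, "nested")
    _write_json(
        os.path.join(nested, "model_index.json"),
        {"text_encoder": ["transformers", "CLIPTextModel"]},
    )
    # Flat repo: only a top-level config.json, no model_index.json.
    flat = os.path.join(root, "flat")
    _write_json(os.path.join(flat, "config.json"), {"_class_name": "SanaControlNetModel"})

    nested_result = resolve_component(nested, subfolder="text_encoder")
    flat_result = resolve_component(flat)

print(nested_result, flat_result)
```

With this ordering, a flat repo never needs a model_index.json, and a component owned by another library is resolved from the index entry instead of being assumed to live in `diffusers`.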
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -332,7 +332,7 @@ def maybe_raise_or_warn( | |||
|
||||
|
||||
def get_class_obj_and_candidates( | ||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None | ||||
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name, cache_dir | ||||
Umm, I think we should update the code itself so that component_name and cache_dir can be

I just tried setting

What do you mean? The first line would already throw an error, no?

I meant when I changed to the original code:
It didn't error out for me 👀

Oh, they should run fine even without it, because of this line

Well, it is wrapped inside a try/except...
||||
): | ||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects""" | ||||
component_folder = os.path.join(cache_dir, component_name) | ||||
|
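The question in the thread above, whether the `os.path.join(cache_dir, component_name)` line fails when both arguments are left unset, can be checked in isolation. The sketch below is not the library's code: `component_folder_or_none` is a hypothetical helper showing one way to guard the join, assuming the intent is for both parameters to accept None.

```python
import os


def component_folder_or_none(cache_dir=None, component_name=None):
    """Hypothetical helper: build the folder path only when both parts are set."""
    if cache_dir is None or component_name is None:
        return None
    return os.path.join(cache_dir, component_name)


# The unguarded call raises TypeError when given None.
try:
    os.path.join(None, None)
    unguarded_error = None
except TypeError as exc:
    unguarded_error = exc

guarded = component_folder_or_none()
joined = component_folder_or_none("cache", "unet")
print(type(unguarded_error).__name__, guarded, joined)
```

If the unguarded call sits inside a broad try/except at the call site, the TypeError is swallowed and the fallback path runs, which may be why the change appeared not to error out.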
I think if we move this import inside the try block, we should be able to get rid of the circular import problem: https://github.com/huggingface/diffusers/actions/runs/14787632114/job/41518909088?pr=11401#step:15:68
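As a rough illustration of the idea of deferring an import into the try block, here is a minimal sketch. It assumes the import in question currently runs at module load time; `get_class_lazily` is a hypothetical name, and standard-library modules stand in for the real packages.

```python
import importlib


def get_class_lazily(library_name, class_name):
    """Hypothetical helper: import the library at call time, not at module load."""
    try:
        # Deferred import: nothing is imported until this function is called,
        # so two modules that reference each other can both finish loading first.
        library = importlib.import_module(library_name)
    except ImportError:
        return None
    return getattr(library, class_name, None)


found = get_class_lazily("collections", "OrderedDict")
missing = get_class_lazily("no_such_library_xyz", "Foo")
print(found, missing)
```

The trade-off is that an import failure surfaces on the first call rather than at startup, so the except clause should stay narrow.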