AutoModel #11115
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice! I think this will be so much better as a UX improvement. Cc: @vladmandic you might like this :)
Let's add docs and tests after @DN6 reviews.
I anticipate this functionality will be quite helpful in a future version of diffusers. Does it make sense to add this PR to the future-release section of the diffusers roadmap?
Hi @ParagEkbote, I've added this to the roadmap, we will aim to get it included with the next release.
an idea - to really make automodel as simple as possible, add default
@vladmandic Do you have an example pseudo-code of your expected usage?
Sorry, I was too fast with my comment. If using auto-pipeline, we don't know what the model is ahead of time. One thing with AutoModel I'd love to see is the model type.
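The model type asked for here is in fact known before any weights are loaded: diffusers records the concrete class name under the `_class_name` key of the model's `config.json`, which is exactly what the code in this PR reads. A minimal sketch of that idea (the config dict below is a hypothetical, trimmed example, not a real checkpoint config):

```python
import json

# Hypothetical, trimmed example of the config.json saved alongside a model;
# the "_class_name" key records the concrete model class.
raw_config = '{"_class_name": "UNet2DConditionModel", "sample_size": 64}'
config = json.loads(raw_config)

# The model type is available before any weights are loaded.
orig_class_name = config["_class_name"]
print(orig_class_name)  # UNet2DConditionModel
```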
src/diffusers/models/auto_model.py (outdated)

    config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
    orig_class_name = config["_class_name"]

    model_cls = _get_task_class(AUTO_MODEL_MAPPING, orig_class_name)
Rather than maintain this big mapping, can't we use importlib? All the models are available in the main diffusers init, right? Similar to how we do it here:
diffusers/src/diffusers/pipelines/pipeline_loading_utils.py, lines 354 to 358 in c51b6bd:

    # else we just import it from the library.
    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)
    class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
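The suggestion boils down to resolving the class by name from the library's top-level namespace instead of a hand-maintained mapping. A minimal sketch of that mechanism (`resolve_class` is a hypothetical helper name; it is demonstrated with the stdlib `collections` module so the sketch runs without diffusers installed):

```python
import importlib


def resolve_class(class_name: str, library_name: str = "diffusers"):
    """Resolve a class by name from a library's top-level namespace."""
    library = importlib.import_module(library_name)
    return getattr(library, class_name)


# Demonstrated with the stdlib so the sketch is self-contained:
cls = resolve_class("OrderedDict", library_name="collections")
print(cls.__name__)  # OrderedDict
```

With this approach, any model class exported from the main diffusers init is picked up automatically, with no mapping to keep in sync.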
Yeah, agree, we probably don't need a mapping here; we could use the same logic as in from_pretrained.
If we can get rid of the mapping, everything will be supported automatically too, I think.
    @@ -895,6 +910,21 @@ def from_pretrained(cls, *args, **kwargs):
            requires_backends(cls, ["torch"])


    class TransformerTemporalModel(metaclass=DummyObject):
why this?
It was missing and needed to be added for the test to pass.
Nice work 👍🏽
What does this PR do?
Fixes #10059
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.