-
Notifications
You must be signed in to change notification settings - Fork 6k
[rfc][compile] compile method for DiffusionPipeline #11705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2027,6 +2027,40 @@ def _maybe_raise_error_if_group_offload_active( | |||||
return True | ||||||
return False | ||||||
|
||||||
def compile( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could make it a part of
For most cases, users want to just do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you expect any improvements from compile on the pre/post transformer blocks? If not, then yeah, moving it inside makes sense to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With this change, do we have to do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, to tackle those cases, we could do something like this: diffusers/src/diffusers/models/modeling_utils.py Line 1306 in 62cbde8
So, we include all the acceptable args and kwargs in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! Let me try this out. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like there are some advantages of
(2) and (3) are effectively consequences of (1). What do you think? @sayakpaul There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sayak suggested to phase this out. Have And then There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we will eventually ship |
||||||
self, | ||||||
compile_regions_for_transformer: bool = True, | ||||||
transformer_module_name: str = "transformer", | ||||||
Comment on lines
+2032
to
+2033
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two will have to have changed if we proceed to have it inside |
||||||
other_modules_names: List[str] = [], | ||||||
**compile_kwargs, | ||||||
): | ||||||
transformer = getattr(self, transformer_module_name, None) | ||||||
if transformer is None: | ||||||
raise ValueError( | ||||||
f"{transformer_module_name} not found in the pipeline. Set `transformer_module_name` to the correct module name." | ||||||
) | ||||||
|
||||||
if compile_regions_for_transformer: | ||||||
compile_region_classes = getattr(transformer, "compile_region_classes", None) | ||||||
if compile_region_classes is None: | ||||||
raise ValueError( | ||||||
f"{transformer_module_name} does not have `compile_region_classes` attribute. Set `compile_regions_for_transformer` to False." | ||||||
) | ||||||
|
||||||
for submod in transformer.modules(): | ||||||
if isinstance(submod, compile_region_classes): | ||||||
submod.compile(**compile_kwargs) | ||||||
else: | ||||||
transformer.compile(**compile_kwargs) | ||||||
|
||||||
for module_name in other_modules_names: | ||||||
module = getattr(self, module_name, None) | ||||||
if module is None: | ||||||
raise ValueError( | ||||||
f"{module_name} not found in the pipeline. Set `other_modules_names` to the correct module names." | ||||||
) | ||||||
module.compile(**compile_kwargs) | ||||||
|
||||||
|
||||||
class StableDiffusionMixin: | ||||||
r""" | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be a class-level attribute like this?
diffusers/src/diffusers/models/modeling_utils.py
Line 262 in 62cbde8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thats better.