[rfc][compile] compile method for DiffusionPipeline #11705

Open · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -286,6 +286,8 @@ def __init__(

    self.gradient_checkpointing = False

    self.compile_region_classes = (FluxTransformerBlock, FluxSingleTransformerBlock)
Member
Should this be a class-level attribute like this?

_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]

Contributor Author
Yes, that's better.
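For illustration, the class-level version might look like this (a sketch; base classes simplified, and the leading underscore is assumed by analogy with _automatically_saved_args):

class FluxTransformer2DModel(ModelMixin):
    # Defined once at class level, mirroring how _automatically_saved_args is declared
    _compile_region_classes = (FluxTransformerBlock, FluxSingleTransformerBlock)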


@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -403,6 +403,8 @@ def __init__(

    self.gradient_checkpointing = False

    self.compile_region_classes = (WanTransformerBlock,)

def forward(
    self,
    hidden_states: torch.Tensor,
34 changes: 34 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -2027,6 +2027,40 @@ def _maybe_raise_error_if_group_offload_active(
        return True
    return False

def compile(
Member
We could make it part of ModelMixin rather than DiffusionPipeline, I believe:

class ModelMixin(torch.nn.Module, PushToHubMixin):

In most cases, users just want to do pipe.transformer.compile(). So perhaps it's easier for this to be added to ModelMixin?

Contributor Author
Do you expect any improvements from compile on the pre/post transformer blocks? If not, then yeah, moving it inside makes sense to me.

Contributor Author (@anijain2305, Jun 13, 2025)

With this change, do we have to do pipe.transformer.compile()? Hmm, this could be confusing because torch.nn.Module also has a method named compile. If we move it into ModelMixin, we might want a different method name.

Member
Well, to tackle those cases, we could do something like this:

@wraps(torch.nn.Module.cuda)

So, we include all the acceptable args and kwargs in the ModelMixin `compile()` method but additionally include the regional/hierarchical compilation-related kwargs. Would that work?
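A rough sketch of that idea (hypothetical: the _compile_region_classes attribute and compile_regions kwarg are assumed, not the final API):

from functools import wraps

import torch

class ModelMixin(torch.nn.Module):
    @wraps(torch.nn.Module.compile)
    def compile(self, *args, compile_regions: bool = False, **kwargs):
        # compile_regions is an extra kwarg on top of the torch.nn.Module.compile
        # signature; when True, only the repeated block classes are compiled.
        if compile_regions:
            region_classes = getattr(self, "_compile_region_classes", ())
            for submod in self.modules():
                if isinstance(submod, region_classes):
                    submod.compile(*args, **kwargs)
        else:
            super().compile(*args, **kwargs)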

Contributor Author

Nice! Let me try this out.


I feel like there are some advantages of Pipeline.compile:

  1. It captures the user intent better, e.g., "I just want to accelerate this pipeline" (the word "compile" is also great because it signals "expect some initial delay").
  2. It's more future-proof (diffusers devs get to tweak things to accommodate changes from diffusers/pytorch versions, or even future hardware and pipeline architectures/perf characteristics).
  3. It's more user-friendly (small things like .transformer do have a noticeable effect on user experience).

(2) and (3) are effectively consequences of (1). What do you think? @sayakpaul

Contributor Author
Sayak suggested doing this in phases:

Have pipeline.transformer.compile as the first PR.

And then pipeline.compile in the future, which can internally call pipeline.transformer.compile. It's actually better this way because the compile region classes are also hidden from the pipe.
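A sketch of how the two phases might look to users (hypothetical, mirroring the plan above):

# Phase 1 (this PR): regional compilation exposed on the model
pipe.transformer.compile()

# Phase 2 (future): pipeline-level entry point delegating to the model,
# hiding the compile region classes from the pipeline
pipe.compile()  # internally calls pipe.transformer.compile()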

Member
Yes, we will eventually ship pipe.compile() :)

    self,
    compile_regions_for_transformer: bool = True,
    transformer_module_name: str = "transformer",
Comment on lines +2032 to +2033
Member

These two will have to change if we proceed with having it inside ModelMixin.
    other_modules_names: List[str] = [],
    **compile_kwargs,
):
    transformer = getattr(self, transformer_module_name, None)
    if transformer is None:
        raise ValueError(
            f"{transformer_module_name} not found in the pipeline. Set `transformer_module_name` to the correct module name."
        )

    if compile_regions_for_transformer:
        # Regional compilation: compile only the repeated transformer block
        # classes, which cuts compile time while retaining most of the speedup.
        compile_region_classes = getattr(transformer, "compile_region_classes", None)
        if compile_region_classes is None:
            raise ValueError(
                f"{transformer_module_name} does not have a `compile_region_classes` attribute. Set `compile_regions_for_transformer` to False."
            )

        for submod in transformer.modules():
            if isinstance(submod, compile_region_classes):
                submod.compile(**compile_kwargs)
    else:
        # Fall back to compiling the transformer as a whole.
        transformer.compile(**compile_kwargs)

    # Optionally compile other pipeline components (e.g., text encoders, VAE).
    for module_name in other_modules_names:
        module = getattr(self, module_name, None)
        if module is None:
            raise ValueError(
                f"{module_name} not found in the pipeline. Set `other_modules_names` to the correct module names."
            )
        module.compile(**compile_kwargs)


class StableDiffusionMixin:
r"""
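For reference, a usage sketch of this pipeline-level entry point (checkpoint and kwargs are illustrative only):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# With compile_regions_for_transformer=True (the default), this compiles each
# FluxTransformerBlock / FluxSingleTransformerBlock instance and forwards
# torch.compile kwargs such as `mode` through **compile_kwargs.
pipe.compile(mode="max-autotune")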