diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 1e9e28471d89..a0f96390062a 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
+    _regions_for_compile = []
 
     def __init__(self):
         super().__init__()
@@ -1402,6 +1403,54 @@ def float(self, *args):
         else:
             return super().float(*args)
 
+    @wraps(torch.nn.Module.compile)
+    def compile(self, use_regional_compile: bool = True, *args, **kwargs):
+        """
+        Compile this model with `torch.compile`.
+
+        When `use_regional_compile` is True (the default), only the repeated block classes named in
+        `_regions_for_compile` (falling back to `_no_split_modules` when that is empty) are compiled
+        individually instead of compiling the whole model, which reduces first-iteration compilation
+        latency while retaining most of the runtime speedup. When False, or when neither attribute
+        names any region, the full model is compiled at once. All remaining positional and keyword
+        arguments are forwarded to `torch.nn.Module.compile`.
+        """
+        if use_regional_compile:
+            regions_for_compile = getattr(self, "_regions_for_compile", None)
+
+            if not regions_for_compile:
+                logger.warning(
+                    "_regions_for_compile attribute is empty. Using _no_split_modules to find compile regions."
+                )
+                regions_for_compile = getattr(self, "_no_split_modules", None)
+
+            if not regions_for_compile:
+                logger.warning(
+                    "Both _regions_for_compile and _no_split_modules attribute are empty. "
+                    "Set _regions_for_compile for the model to benefit from regional compilation. "
+                    "Falling back to full model compilation, which could have high first iteration "
+                    "latency."
+                )
+                # No regions to compile: compile the full model and return so we do not fall
+                # through and iterate over `regions_for_compile=None` below (TypeError).
+                super().compile(*args, **kwargs)
+                return
+
+            has_compiled_region = False
+            for submod in self.modules():
+                if submod.__class__.__name__ in regions_for_compile:
+                    has_compiled_region = True
+                    submod.compile(*args, **kwargs)
+
+            if not has_compiled_region:
+                raise ValueError(
+                    f"Regional compilation failed because {regions_for_compile} classes are not found in the model. "
+                    "Either set them correctly, or set `use_regional_compile` to False while calling compile, e.g. "
+                    "pipe.transformer.compile(use_regional_compile=False) to fallback to full model compilation, "
+                    "which could have high iteration latency."
+                )
+        else:
+            super().compile(*args, **kwargs)
+
     @classmethod
     def _load_pretrained_model(
         cls,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 541576b13b78..a0d2155cc025 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
     _supports_gradient_checkpointing = True
     _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _regions_for_compile = _no_split_modules
 
     @register_to_config
     def __init__(
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index baa0ede4184e..fb693804c6fe 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _regions_for_compile = _no_split_modules
 
     @register_to_config
     def __init__(