From 46bfc5871a2f8b18ae3dc50f23151d459bdd60ae Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Fri, 6 Jun 2025 07:20:25 +0000 Subject: [PATCH 01/10] commit 1 Signed-off-by: YAO Matrix --- .../pipelines/audioldm2/pipeline_audioldm2.py | 6 ++---- src/diffusers/pipelines/consisid/consisid_utils.py | 2 +- .../pipelines/controlnet/pipeline_controlnet.py | 4 ++-- .../controlnet/pipeline_controlnet_img2img.py | 4 ++-- .../controlnet/pipeline_controlnet_inpaint.py | 4 ++-- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 4 ++-- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 6 +++--- .../pipeline_controlnet_union_inpaint_sd_xl.py | 4 ++-- .../pipeline_controlnet_union_sd_xl_img2img.py | 6 +++--- .../pipelines/controlnet_xs/pipeline_controlnet_xs.py | 6 +++--- .../kandinsky/pipeline_kandinsky_combined.py | 6 +++--- .../kandinsky2_2/pipeline_kandinsky2_2_combined.py | 8 ++++---- .../pipelines/kolors/pipeline_kolors_img2img.py | 4 ++-- src/diffusers/pipelines/musicldm/pipeline_musicldm.py | 11 ++++++----- .../pipelines/pag/pipeline_pag_controlnet_sd.py | 6 +++--- .../pag/pipeline_pag_controlnet_sd_inpaint.py | 4 ++-- .../pipelines/pag/pipeline_pag_controlnet_sd_xl.py | 2 +- .../pag/pipeline_pag_controlnet_sd_xl_img2img.py | 6 +++--- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 6 +++++- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 4 ++-- src/diffusers/pipelines/pipeline_utils.py | 10 +++------- src/diffusers/pipelines/sana/pipeline_sana_sprint.py | 7 +++++-- .../pipelines/sana/pipeline_sana_sprint_img2img.py | 6 ++++-- .../pipeline_stable_cascade_combined.py | 4 ++-- .../pipelines/stable_diffusion/convert_from_ckpt.py | 5 +++-- .../pipeline_stable_diffusion_xl_img2img.py | 4 ++-- .../pipeline_text_to_video_zero.py | 4 ++-- .../wuerstchen/pipeline_wuerstchen_combined.py | 4 ++-- src/diffusers/utils/peft_utils.py | 4 ++-- src/diffusers/utils/torch_utils.py | 6 ++++++ 31 files changed, 85 insertions(+), 74 deletions(-) diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py index dd70fb82fff4..c34415595507 100644 --- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -41,7 +41,7 @@ replace_example_docstring, ) from ...utils.import_utils import is_transformers_version -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import empty_device_cache, randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel @@ -267,9 +267,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) - device_mod = getattr(torch, device.type, None) - if hasattr(device_mod, "empty_cache") and device_mod.is_available(): - device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + empty_device_cache(device.type) model_sequence = [ self.text_encoder.text_model, diff --git a/src/diffusers/pipelines/consisid/consisid_utils.py b/src/diffusers/pipelines/consisid/consisid_utils.py index 23811a4986e3..521d4d787e54 100644 --- a/src/diffusers/pipelines/consisid/consisid_utils.py +++ b/src/diffusers/pipelines/consisid/consisid_utils.py @@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype): Parameters: - model_path: Path to the directory 
containing model files. - - device: The device (e.g., 'cuda', 'cpu') where models will be loaded. + - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded. - dtype: Data type (e.g., torch.float32) for model inference. Returns: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index e14c51ab94ce..12aac7ff0d8a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -37,7 +37,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1339,7 +1339,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2102b34d5ffb..f3cb34d96d3a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -36,7 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1311,7 +1311,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 3fc1206ba5f6..bff14b3fbeb5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -38,7 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1500,7 +1500,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, 
return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 51ddd997142b..b38b60a20ae1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -51,7 +51,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -1858,7 +1858,7 @@ def denoising_value_valid(dnv): if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 8aa47fd4277a..442cf11e243a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1465,7 +1465,7 @@ def __call__( # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if torch.cuda.is_available() and (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 52e76a220454..08c6d02f27f9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -53,7 +53,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -921,7 +921,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() image = image.to(device=device, dtype=dtype) @@ -1632,7 +1632,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index fc160573ba4e..22751365a9c2 100644 --- 
a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -51,7 +51,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -1766,7 +1766,7 @@ def denoising_value_valid(dnv): if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index bd98b7975082..3bcaf228abf9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -53,7 +53,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -876,7 +876,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() image = image.to(device=device, dtype=dtype) @@ -1574,7 +1574,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index ecc7cd4ad726..160cab754bc7 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -36,7 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -851,7 +851,7 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if is_controlnet_compiled and is_torch_higher_equal_2_1: + if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ 
-900,7 +900,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py index f6d445a4869e..25f90d58a423 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py @@ -193,7 +193,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a @@ -411,7 +411,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -652,7 +652,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py index 2ebd995eb58d..c104e4bcced0 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -179,7 +179,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -407,7 +407,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` @@ -417,7 +417,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -656,7 +656,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 4b0f8cdb62d3..62132f2ea9d9 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -25,7 +25,7 @@ from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import KolorsPipelineOutput from .text_encoder import ChatGLMModel @@ -618,7 +618,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() image = image.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py index d28864298d42..c0455c5b4198 100644 --- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py +++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py @@ -35,7 +35,7 @@ logging, replace_example_docstring, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin @@ -396,8 +396,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its `forward` + method is called, and the model remains in accelerator until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
""" if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): @@ -405,11 +405,12 @@ def enable_model_cpu_offload(self, gpu_id=0): else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - device = torch.device(f"cuda:{gpu_id}") + device_type = get_device() + device = torch.device(f"{device_type}:{gpu_id}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + empty_device_cache() # otherwise we don't see the memory savings (but they probably exist) model_sequence = [ self.text_encoder.text_model, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index 205868762935..e1e7477c3574 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -36,7 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1228,7 +1228,7 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if torch.cuda.is_available() and (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) @@ -1309,7 +1309,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py index f8a58a665475..5c08e33bb2e2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -37,7 +37,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1521,7 +1521,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py 
b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 259b8939ce99..a5c71ac72c02 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -1498,7 +1498,7 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if torch.cuda.is_available() and (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index dc20ea95cdbe..3b4e608153e3 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -52,7 +52,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from .pag_utils import PAGMixin @@ -926,7 +926,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() image = image.to(device=device, dtype=dtype) @@ -1648,7 +1648,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 02ede6c3d6c6..a3558f3a6bdc 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -35,6 +35,7 @@ logging, replace_example_docstring, ) +from ...utils.import_utils import is_torch_version from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pixart_alpha.pipeline_pixart_alpha import ( @@ -917,9 +918,12 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - except torch.cuda.OutOfMemoryError as e: + except oom_error as e: warnings.warn( f"{e}. \n" f"Try to use VAE tiling for large images. 
For example: \n" diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 92eb45a72e7f..74009c581a42 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -49,7 +49,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import empty_device_cache, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from .pag_utils import PAGMixin @@ -716,7 +716,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() image = image.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8184573b02ef..f50f932be733 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -67,7 +67,7 @@ numpy_to_pil, ) from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card -from ..utils.torch_utils import get_device, is_compiled_module +from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module if is_torch_npu_available(): @@ -1167,9 +1167,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t self._offload_device = device self.to("cpu", silence_dtype_warnings=True) - device_mod = getattr(torch, device.type, None) - if hasattr(device_mod, "empty_cache") and device_mod.is_available(): - device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + empty_device_cache(device.type) all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)} @@ -1280,9 +1278,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) - device_mod = getattr(torch, self.device.type, None) - if hasattr(device_mod, "empty_cache") and device_mod.is_available(): - device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + empty_device_cache(self.device.type) for name, model in self.components.items(): if not isinstance(model, torch.nn.Module): diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 46edbf7c33ef..c17d12427b03 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -38,7 +38,8 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import get_device, randn_tensor +from ...utils.import_utils import is_torch_version from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN from .pipeline_output import SanaPipelineOutput @@ -864,9 +865,11 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError try: 
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - except torch.cuda.OutOfMemoryError as e: + except oom_error as e: warnings.warn( f"{e}. \n" f"Try to use VAE tiling for large images. For example: \n" diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py index f71b980ffc84..090d30382798 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py @@ -39,7 +39,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import get_device, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN from .pipeline_output import SanaPipelineOutput @@ -952,9 +952,11 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - except torch.cuda.OutOfMemoryError as e: + except oom_error as e: warnings.warn( f"{e}. \n" f"Try to use VAE tiling for large images. For example: \n" diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py index a7c273fbe1e9..c387ab7b0501 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -125,7 +125,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` @@ -135,7 +135,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. 
Models are moved to a `torch.device('meta')` and loaded on a diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 568ae7f7d671..6c0221d2092a 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -53,6 +53,7 @@ ) from ...utils import is_accelerate_available, logging from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT +from ...utils.torch_utils import get_device from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from ..paint_by_example import PaintByExampleImageEncoder from ..pipeline_utils import DiffusionPipeline @@ -1272,7 +1273,7 @@ def download_from_original_stable_diffusion_ckpt( checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") else: if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = get_device() checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) else: checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) @@ -1842,7 +1843,7 @@ def download_controlnet_from_original_ckpt( checkpoint[key] = f.get_tensor(key) else: if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = get_device() checkpoint = torch.load(checkpoint_path, map_location=device) else: checkpoint = torch.load(checkpoint_path, map_location=device) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index dfbbdaeac5a3..d4248a14e560 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -50,7 +50,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import empty_device_cache, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput @@ -704,7 +704,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - torch.cuda.empty_cache() + empty_device_cache() image = image.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py index 7a8db4a8e522..59b7982087e1 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -23,7 +23,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import empty_accelerator_cache, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionSafetyChecker @@ -758,7 +758,7 @@ def __call__( # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") - torch.cuda.empty_cache() + empty_accelerator_cache() if output_type == "latent": image = latents diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py 
b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py index e4756efbac92..e531389518f0 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py @@ -112,7 +112,7 @@ def __init__( def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) - def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` @@ -122,7 +122,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device) - def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): + def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None): r""" Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗 Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 7d0a6faa7afb..dae2a257d88c 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -22,6 +22,7 @@ from packaging import version from .import_utils import is_peft_available, is_torch_available +from .torch_utils import empty_device_cache if is_torch_available(): @@ -95,8 +96,7 @@ def recurse_remove_peft_layers(model): setattr(model, name, new_module) del module - if torch.cuda.is_available(): - torch.cuda.empty_cache() + empty_device_cache() return model diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index bb5674092d09..ccf2cc5cc7cf 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -172,3 +172,9 @@ def get_device(): return "xpu" else: return "cpu" + +def empty_device_cache(device_type: Optional[str] = None): + if device_type is None: + device_type = get_device() + device_mod = getattr(torch, device_type, torch.cuda) + device_mod.empty_cache() From cc2f3f87f93906293caaef1a1c716d7ccfd72cc3 Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Fri, 6 Jun 2025 07:38:27 +0000 Subject: [PATCH 02/10] patch 2 Signed-off-by: YAO Matrix --- .../pipelines/controlnet/pipeline_controlnet_img2img.py | 4 ++-- .../controlnet/pipeline_controlnet_union_inpaint_sd_xl.py | 4 ++-- .../pipelines/controlnet_xs/pipeline_controlnet_xs.py | 4 ++-- src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py | 4 ++-- src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py | 4 ++-- .../pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py | 4 ++-- src/diffusers/pipelines/sana/pipeline_sana.py | 6 +++++- src/diffusers/pipelines/sana/pipeline_sana_controlnet.py | 6 +++++- .../text_to_video_synthesis/pipeline_text_to_video_zero.py | 4 ++-- 9 files changed, 24 insertions(+), 16 deletions(-)
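[Editor's note, not part of the applied diff; `git am` ignores text between a diffstat and the first `diff --git` line. Patches 01 and 02 converge on the device-agnostic helpers added to src/diffusers/utils/torch_utils.py. A minimal standalone sketch of that pattern follows, using only public torch APIs; `get_device` here is a simplified stand-in for the library's accelerator probe, the CPU early-return is an extra safety guard added for this note, and the OOM-type selection is a defensive variant of the one the Sana pipelines use:

    from typing import Optional

    import torch


    def get_device() -> str:
        # Simplified stand-in for diffusers' accelerator probe:
        # prefer CUDA, then Intel XPU, then fall back to CPU.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        return "cpu"


    def empty_device_cache(device_type: Optional[str] = None) -> None:
        # Release cached allocator blocks on the active accelerator so the
        # memory savings from CPU offload actually become visible.
        if device_type is None:
            device_type = get_device()
        if device_type == "cpu":  # nothing to flush on CPU (editor's guard)
            return
        device_mod = getattr(torch, device_type, torch.cuda)
        device_mod.empty_cache()


    # Backend-agnostic OOM type, as in the Sana pipelines below: torch >= 2.5
    # exposes a unified torch.OutOfMemoryError, while older releases only have
    # per-backend variants such as torch.cuda.OutOfMemoryError.
    torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
    oom_error = getattr(torch, "OutOfMemoryError", None) or getattr(
        torch_accelerator_module, "OutOfMemoryError", RuntimeError
    )
]

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py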
b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index f3cb34d96d3a..1fdc285c26e9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -36,7 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1311,7 +1311,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - empty_accelerator_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index 22751365a9c2..a0b59dd0f851 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -51,7 +51,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput @@ -1766,7 +1766,7 @@ def denoising_value_valid(dnv): if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - empty_accelerator_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 160cab754bc7..f04cd947c731 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -36,7 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -900,7 +900,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - empty_accelerator_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 62132f2ea9d9..93dcd36a7021 100644 --- 
a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -25,7 +25,7 @@ from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import empty_accelerator_cache, randn_tensor +from ...utils.torch_utils import empty_device_cache, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import KolorsPipelineOutput from .text_encoder import ChatGLMModel @@ -618,7 +618,7 @@ def prepare_latents( # Offload text encoder if `enable_model_cpu_offload` was enabled if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.text_encoder_2.to("cpu") - empty_accelerator_cache() + empty_device_cache() image = image.to(device=device, dtype=dtype) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index e1e7477c3574..dc1c466327d4 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -36,7 +36,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, is_torch_version, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1309,7 +1309,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - empty_accelerator_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py index 5c08e33bb2e2..dbe5532f6f47 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -37,7 +37,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import empty_accelerator_cache, is_compiled_module, randn_tensor +from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -1521,7 +1521,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - empty_accelerator_cache() + empty_devcie_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 6c2c1fb0a8f1..78abd55d3fdf 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -38,6 +38,7 @@ scale_lora_layers, 
unscale_lora_layers, ) +from ...utils.import_utils import is_torch_version from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( @@ -982,9 +983,12 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - except torch.cuda.OutOfMemoryError as e: + except oom_error as e: warnings.warn( f"{e}. \n" f"Try to use VAE tiling for large images. For example: \n" diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index 593e0895e43d..a957b455924a 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -38,6 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) +from ...utils.import_utils import is_torch_version from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( @@ -1078,9 +1079,12 @@ def __call__( image = latents else: latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - except torch.cuda.OutOfMemoryError as e: + except oom_error as e: warnings.warn( f"{e}. \n" f"Try to use VAE tiling for large images. 
For example: \n" diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py index 59b7982087e1..617f8ebbf1ea 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -23,7 +23,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import empty_accelerator_cache, randn_tensor +from ...utils.torch_utils import empty_device_cache, randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionSafetyChecker @@ -758,7 +758,7 @@ def __call__( # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") - empty_accelerator_cache() + empty_device_cache() if output_type == "latent": image = latents From 7184867ec20cf495855b77df223b553c3f345b4b Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 6 Jun 2025 15:48:13 +0800 Subject: [PATCH 03/10] Update pipeline_pag_sana.py --- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index a3558f3a6bdc..624b1e18e03a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -35,8 +35,7 @@ logging, replace_example_docstring, ) -from ...utils.import_utils import is_torch_version -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, From 6acd1e8845b0dbf35a6385242024835cec5e641f Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 6 Jun 2025 15:49:48 +0800 Subject: [PATCH 04/10] Update pipeline_sana.py --- src/diffusers/pipelines/sana/pipeline_sana.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 78abd55d3fdf..b40d48d46022 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -38,8 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.import_utils import is_torch_version -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils import is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, From f08c8a586d7b503bdd1a475b7d7ee767745e7c48 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 6 Jun 2025 15:50:49 +0800 Subject: [PATCH 05/10] Update pipeline_sana_controlnet.py --- src/diffusers/pipelines/sana/pipeline_sana_controlnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index a957b455924a..ecca994c2376 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -38,8 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.import_utils import is_torch_version -from ...utils.torch_utils import randn_tensor +from ...utils.torch_utils 
import is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, From b0c45fdf0be6b7775ff4686258be36f4670297f1 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 6 Jun 2025 15:52:07 +0800 Subject: [PATCH 06/10] Update pipeline_sana_sprint_img2img.py --- src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py index 090d30382798..bbf16ab03c03 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py @@ -39,7 +39,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import get_device, randn_tensor +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN from .pipeline_output import SanaPipelineOutput From dea64fcb9aa46c2c5dade6a3164c45e3fb5b41f3 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 6 Jun 2025 15:53:18 +0800 Subject: [PATCH 07/10] Update pipeline_sana_sprint.py --- src/diffusers/pipelines/sana/pipeline_sana_sprint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index c17d12427b03..557bd6ed8d24 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -38,8 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import get_device, randn_tensor -from ...utils.import_utils import is_torch_version +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN from .pipeline_output import SanaPipelineOutput From d2a37848a708f7816d1382c02a48ea80c5d9cbbb Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Sun, 8 Jun 2025 23:48:50 +0000 Subject: [PATCH 08/10] fix style Signed-off-by: YAO Matrix --- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 6 +++++- src/diffusers/pipelines/musicldm/pipeline_musicldm.py | 7 ++++--- .../pipelines/pag/pipeline_pag_controlnet_sd.py | 6 +++++- .../pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py | 2 +- .../pipelines/pag/pipeline_pag_controlnet_sd_xl.py | 6 +++++- src/diffusers/pipelines/pag/pipeline_pag_sana.py | 9 ++++++--- src/diffusers/pipelines/sana/pipeline_sana.py | 9 ++++++--- src/diffusers/pipelines/sana/pipeline_sana_controlnet.py | 9 ++++++--- src/diffusers/pipelines/sana/pipeline_sana_sprint.py | 6 +++++- .../pipelines/sana/pipeline_sana_sprint_img2img.py | 6 +++++- src/diffusers/utils/torch_utils.py | 3 ++- 11 files changed, 50 insertions(+), 19 deletions(-)
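[Editor's note, not part of the applied diff: this style pass reflows the CUDA-graphs guard that patch 01 added to the compiled ControlNet pipelines. A standalone sketch of the guarded step marker, assuming `is_compiled_module` reduces to the `OptimizedModule` isinstance check diffusers uses elsewhere:

    import torch
    from packaging import version

    # Mirrors diffusers' is_torch_version(">=", "2.1") check.
    is_torch_higher_equal_2_1 = version.parse(torch.__version__) >= version.parse("2.1")


    def maybe_mark_cudagraph_step(unet: torch.nn.Module, controlnet: torch.nn.Module) -> None:
        # torch.compile wraps modules in OptimizedModule.
        compiled = torch._dynamo.eval_frame.OptimizedModule
        is_unet_compiled = isinstance(unet, compiled)
        is_controlnet_compiled = isinstance(controlnet, compiled)
        # cudagraph_mark_step_begin() is CUDA-specific, so the new
        # torch.cuda.is_available() guard keeps compiled runs working on
        # other accelerators such as XPU.
        if (
            torch.cuda.is_available()
            and (is_unet_compiled and is_controlnet_compiled)
            and is_torch_higher_equal_2_1
        ):
            torch._inductor.cudagraph_mark_step_begin()
]

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 442cf11e243a..d178f382db47 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1465,7 +1465,11 @@ def __call__( # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if torch.cuda.is_available() and (is_unet_compiled and 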
is_controlnet_compiled) and is_torch_higher_equal_2_1: + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py index c0455c5b4198..1adb9af41c69 100644 --- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py +++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py @@ -396,9 +396,10 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its `forward` - method is called, and the model remains in accelerator until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its + `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are + lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution + of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index dc1c466327d4..2d472382da41 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -1228,7 +1228,11 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if torch.cuda.is_available() and (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py index dbe5532f6f47..35a62b3d474a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -1521,7 +1521,7 @@ def __call__( if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") self.controlnet.to("cpu") - empty_devcie_cache() + empty_device_cache() if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index a5c71ac72c02..dbabfe423d47 100644 --- 
a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -1498,7 +1498,11 @@ def __call__( for i, t in enumerate(timesteps): # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 - if torch.cuda.is_available() and (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: + if ( + torch.cuda.is_available() + and (is_unet_compiled and is_controlnet_compiled) + and is_torch_higher_equal_2_1 + ): torch._inductor.cudagraph_mark_step_begin() # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0])) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 624b1e18e03a..78320223491c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -35,7 +35,7 @@ logging, replace_example_docstring, ) -from ...utils.torch_utils import is_torch_version, randn_tensor +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, @@ -918,8 +918,11 @@ def __call__( else: latents = latents.to(self.vae.dtype) torch_accelerator_module = getattr(torch, get_device(), torch.cuda) - oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfM -emoryError + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except oom_error as e: diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index b40d48d46022..8a7d574cd5df 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -38,7 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_torch_version, randn_tensor +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, @@ -983,8 +983,11 @@ def __call__( else: latents = latents.to(self.vae.dtype) torch_accelerator_module = getattr(torch, get_device(), torch.cuda) - oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfM -emoryError + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except oom_error as e: diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index ecca994c2376..bf96fe902e3a 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -38,7 +38,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ...utils.torch_utils import is_torch_version, randn_tensor +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor from ..pipeline_utils import DiffusionPipeline from ..pixart_alpha.pipeline_pixart_alpha import ( ASPECT_RATIO_512_BIN, @@ -1079,8 
+1079,11 @@ def __call__( else: latents = latents.to(self.vae.dtype) torch_accelerator_module = getattr(torch, get_device(), torch.cuda) - oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfM -emoryError + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except oom_error as e: diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 557bd6ed8d24..d13506cc7120 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -865,7 +865,11 @@ def __call__( else: latents = latents.to(self.vae.dtype) torch_accelerator_module = getattr(torch, get_device(), torch.cuda) - oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except oom_error as e: diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py index bbf16ab03c03..99cc101f06d0 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py @@ -953,7 +953,11 @@ def __call__( else: latents = latents.to(self.vae.dtype) torch_accelerator_module = getattr(torch, get_device(), torch.cuda) - oom_error = torch.OutOfMemoryError if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) try: image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] except oom_error as e: diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index ccf2cc5cc7cf..35bcd46ede5b 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -173,7 +173,8 @@ def get_device(): else: return "cpu" -def empty_device_cache(device_type: Optional[str] = None) + +def empty_device_cache(device_type: Optional[str] = None): if device_type is None: device_type = get_device() device_mod = getattr(torch, device_type, torch.cuda) From 39b4c15529c5ba6cf85e2f772e6691a84f9f32a6 Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Wed, 11 Jun 2025 13:35:54 +0000 Subject: [PATCH 09/10] fix fat-thumb while merge conflict Signed-off-by: YAO Matrix --- src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index 9640bcc1e1a6..874b4531be78 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -36,6 +36,8 @@ scale_lora_layers, unscale_lora_layers, ) +from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor +from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion.pipeline_output import 
StableDiffusionPipelineOutput from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker From 1f063a19d3a4e0dd1ae79369bfbd769bf25d3170 Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Fri, 13 Jun 2025 14:14:44 +0000 Subject: [PATCH 10/10] fix ci issues Signed-off-by: YAO Matrix --- src/diffusers/pipelines/pipeline_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a90fce398c9c..4fb3a541f509 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1311,8 +1311,9 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un self._offload_device = device if self.device.type != "cpu": + orig_device_type = self.device.type self.to("cpu", silence_dtype_warnings=True) - empty_device_cache(self.device.type) + empty_device_cache(orig_device_type) for name, model in self.components.items(): if not isinstance(model, torch.nn.Module):
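
Note for reviewers: taken together, the hunks above converge on the device-agnostic helpers in src/diffusers/utils/torch_utils.py. The standalone sketch below condenses that pattern. Only the getattr(torch, ...) module resolution, the "cpu" fallback, and the torch >= 2.5 OutOfMemoryError selection are taken from the hunks; the probe order in get_device(), the empty_cache()/is_available() guard, the inline packaging-based version check standing in for diffusers' is_torch_version, and the demo block are illustrative assumptions, not the exact library code.

    from typing import Optional

    import torch
    from packaging import version


    def get_device() -> str:
        # Probe order is an assumption for this sketch; the hunks only show
        # the real helper's final `return "cpu"` fallback.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        return "cpu"


    def empty_device_cache(device_type: Optional[str] = None) -> None:
        # Mirrors the PATCH 08 signature: resolve the backend module
        # dynamically so one call serves cuda, xpu, or any backend that
        # exposes empty_cache(); on a cpu-only box this is a no-op.
        if device_type is None:
            device_type = get_device()
        device_mod = getattr(torch, device_type, torch.cuda)
        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
            device_mod.empty_cache()


    def pick_oom_error(device_type: Optional[str] = None):
        # Same selection as the sana/sana_sprint hunks: torch >= 2.5 exposes
        # a backend-neutral torch.OutOfMemoryError, while older releases only
        # have the per-backend class (e.g. torch.cuda.OutOfMemoryError). Like
        # the library code, the fallback assumes an accelerator backend.
        if version.parse(torch.__version__).release >= (2, 5):
            return torch.OutOfMemoryError
        torch_accelerator_module = getattr(torch, device_type or get_device(), torch.cuda)
        return torch_accelerator_module.OutOfMemoryError


    if __name__ == "__main__":
        # PATCH 10's ordering caveat in miniature: snapshot the device type
        # before anything moves to cpu, then clear that backend's cache.
        dev = get_device()
        oom_cls = pick_oom_error(dev)
        print(f"accelerator: {dev}, OOM class: {oom_cls.__name__}")
        empty_device_cache(dev)  # safe no-op on cpu-only machines

The one-line fix in the final commit matters for the same reason the demo snapshots `dev` first: enable_sequential_cpu_offload must read self.device.type before self.to("cpu"), otherwise empty_device_cache would be asked to clear the already-empty cpu cache instead of the accelerator's.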