From 01a9b23906692c567871076c8d61901cd396338d Mon Sep 17 00:00:00 2001
From: yyt
Date: Sun, 2 Nov 2025 14:09:03 +0000
Subject: [PATCH] rope cache

---
 .../models/transformers/transformer_flux.py   | 20 +++++++++++---------
 .../models/transformers/transformer_wan.py    |  8 +++++---
 src/diffusers/pipelines/flux/pipeline_flux.py |  1 +
 src/diffusers/pipelines/wan/pipeline_wan.py   |  5 +++++
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 16c526f437f2..77bc85032626 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -633,6 +633,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
+        self.image_rotary_emb = None
 
     def forward(
         self,
@@ -717,11 +718,12 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        if self.image_rotary_emb is None:
+            if is_torch_npu_available():
+                freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+                self.image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+            else:
+                self.image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
@@ -735,7 +737,7 @@ def forward(
                     hidden_states,
                     encoder_hidden_states,
                     temb,
-                    image_rotary_emb,
+                    self.image_rotary_emb,
                     joint_attention_kwargs,
                 )
 
@@ -744,7 +746,7 @@ def forward(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
-                    image_rotary_emb=image_rotary_emb,
+                    image_rotary_emb=self.image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
@@ -767,7 +769,7 @@ def forward(
                     hidden_states,
                     encoder_hidden_states,
                     temb,
-                    image_rotary_emb,
+                    self.image_rotary_emb,
                     joint_attention_kwargs,
                 )
 
@@ -776,7 +778,7 @@ def forward(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
-                    image_rotary_emb=image_rotary_emb,
+                    image_rotary_emb=self.image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
 
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index dd75fb124f1a..2e566f1daaaa 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -613,6 +613,7 @@ def __init__(
         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
         self.gradient_checkpointing = False
+        self.rotary_emb = None
 
     def forward(
         self,
@@ -644,7 +645,8 @@ def forward(
         post_patch_height = height // p_h
         post_patch_width = width // p_w
 
-        rotary_emb = self.rope(hidden_states)
+        if self.rotary_emb is None:
+            self.rotary_emb = self.rope(hidden_states)
 
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
@@ -673,11 +675,11 @@ def forward(
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.blocks:
                 hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                    block, hidden_states, encoder_hidden_states, timestep_proj, self.rotary_emb
                 )
         else:
             for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, self.rotary_emb)
 
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 5041e352f73d..732e98c048f0 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -1008,6 +1008,7 @@ def __call__(
 
         # Offload all models
        self.maybe_free_model_hooks()
+        self.transformer.image_rotary_emb = None
 
         if not return_dict:
             return (image,)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 78fe71ea9138..e7226d336ac8 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -650,6 +650,11 @@ def __call__(
 
         # Offload all models
         self.maybe_free_model_hooks()
+
+        self.transformer.rotary_emb = None
+        if self.transformer_2 is not None:
+            self.transformer_2.rotary_emb = None
+
         if not return_dict:
             return (video,)
 
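
Note (not part of the patch): a minimal usage sketch of the caching behavior, assuming the patched FluxPipeline and the public "black-forest-labs/FLUX.1-dev" checkpoint; model id and generation settings are illustrative. The rotary embeddings depend only on the token ids (resolution and prompt length), not on the timestep, so the transformer computes them once on the first denoising step and reuses the cached tensors for the remaining steps. The pipeline clears transformer.image_rotary_emb after the run so a later call at a different resolution rebuilds the cache instead of reusing stale values.

import torch
from diffusers import FluxPipeline

# Hypothetical usage sketch, not an official example.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# First call: transformer.image_rotary_emb starts as None, so the first denoising
# step computes pos_embed(ids) once and every later step reuses the cached tensors.
image = pipe("a photo of a cat", height=1024, width=1024, num_inference_steps=28).images[0]

# The pipeline resets transformer.image_rotary_emb to None after __call__, so a
# second call at another resolution recomputes the embeddings for the new grid.
image_small = pipe("a photo of a cat", height=512, width=512, num_inference_steps=28).images[0]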