20 changes: 11 additions & 9 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -633,6 +633,7 @@ def __init__(
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

         self.gradient_checkpointing = False
+        self.image_rotary_emb = None

     def forward(
         self,
@@ -717,11 +718,12 @@ def forward
             img_ids = img_ids[0]

         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        if self.image_rotary_emb is None:
+            if is_torch_npu_available():
+                freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+                self.image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+            else:
+                self.image_rotary_emb = self.pos_embed(ids)

         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
@@ -735,7 +737,7 @@ def forward(
                     hidden_states,
                     encoder_hidden_states,
                     temb,
-                    image_rotary_emb,
+                    self.image_rotary_emb,
                     joint_attention_kwargs,
                 )

@@ -744,7 +746,7 @@ def forward(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
-                    image_rotary_emb=image_rotary_emb,
+                    image_rotary_emb=self.image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

@@ -767,7 +769,7 @@ def forward(
                     hidden_states,
                     encoder_hidden_states,
                     temb,
-                    image_rotary_emb,
+                    self.image_rotary_emb,
                     joint_attention_kwargs,
                 )

@@ -776,7 +778,7 @@ def forward(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
-                    image_rotary_emb=image_rotary_emb,
+                    image_rotary_emb=self.image_rotary_emb,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )
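The Flux change above makes the rotary position embeddings a lazily cached attribute: they are computed once on the first forward pass and reused on every later denoising step until the attribute is cleared. A minimal, self-contained sketch of that caching pattern, using a placeholder pos_embed rather than the real FluxPosEmbed module:

import torch
import torch.nn as nn


class LazyRotaryCache(nn.Module):
    # Illustrative stand-in for the caching added to the Flux transformer; not the diffusers class.
    def __init__(self):
        super().__init__()
        self.image_rotary_emb = None  # empty until the first forward pass

    def pos_embed(self, ids: torch.Tensor):
        # Placeholder for the real rotary-embedding computation.
        return torch.cos(ids.float()), torch.sin(ids.float())

    def forward(self, ids: torch.Tensor):
        # Compute once, then reuse on every later call; only a cheap attribute check remains.
        if self.image_rotary_emb is None:
            self.image_rotary_emb = self.pos_embed(ids)
        return self.image_rotary_emb


cache = LazyRotaryCache()
ids = torch.arange(8).unsqueeze(-1)
first = cache(ids)
second = cache(ids)            # returns the cached tuple, no recomputation
assert first is second
cache.image_rotary_emb = None  # what the pipeline does at the end of a run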
8 changes: 5 additions & 3 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -613,6 +613,7 @@ def __init__(
         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

         self.gradient_checkpointing = False
+        self.rotary_emb = None

     def forward(
         self,
@@ -644,7 +645,8 @@ def forward
         post_patch_height = height // p_h
         post_patch_width = width // p_w

-        rotary_emb = self.rope(hidden_states)
+        if self.rotary_emb is None:
+            self.rotary_emb = self.rope(hidden_states)

         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
@@ -673,11 +675,11 @@ def forward
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.blocks:
                 hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                    block, hidden_states, encoder_hidden_states, timestep_proj, self.rotary_emb
                 )
         else:
             for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, self.rotary_emb)

         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
1 change: 1 addition & 0 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -1008,6 +1008,7 @@ def __call__(

         # Offload all models
         self.maybe_free_model_hooks()
+        self.image_rotary_emb = None

         if not return_dict:
             return (image,)
5 changes: 5 additions & 0 deletions src/diffusers/pipelines/wan/pipeline_wan.py
@@ -650,6 +650,11 @@ def __call__(
         # Offload all models
         self.maybe_free_model_hooks()

+
+        self.transformer.rotary_emb = None
+        if self.transformer_2 is not None:
+            self.transformer_2.rotary_emb = None
+
         if not return_dict:
             return (video,)

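Resetting the cached embeddings at the end of __call__ is what keeps the cache from leaking across runs: a later call at a different resolution or frame count yields different position ids, so the embeddings must be recomputed. A hedged usage sketch for the Wan pipeline follows; the checkpoint id and generation arguments are illustrative and not taken from this PR:

import torch
from diffusers import WanPipeline

# Checkpoint id is an assumption for illustration; any Wan text-to-video checkpoint applies.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# First run: the transformer computes rotary embeddings once and reuses them for every step.
video_a = pipe(prompt="a cat surfing a wave", height=480, width=832, num_frames=81).frames[0]

# __call__ clears transformer.rotary_emb (and transformer_2.rotary_emb when present) afterwards,
# so a run at a different size recomputes the embeddings instead of reusing stale ones.
video_b = pipe(prompt="a cat surfing a wave", height=480, width=480, num_frames=49).frames[0]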