Commit 6067995

update
1 parent 08152c5 commit 6067995

2 files changed (+3, -3 lines)


src/diffusers/models/transformers/transformer_wan.py

1 addition, 1 deletion

@@ -261,7 +261,7 @@ def unfuse_projections(self):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs,
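
A hedged reading of this hunk: with `encoder_hidden_states` defaulting to `None`, callers that only need self-attention can omit it, and the whole argument list can be passed positionally. The class below, `ToyWanAttention`, is a hypothetical stand-in invented here to mirror the post-commit signature; it is not the real diffusers attention module and the mixing step is a placeholder, not attention math:

```python
from typing import Optional, Tuple

import torch


class ToyWanAttention(torch.nn.Module):
    """Hypothetical stand-in mirroring the post-commit forward signature above."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Self-attention when encoder_hidden_states is None, cross-attention otherwise.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        # Placeholder mixing step; the real module computes scaled dot-product attention.
        return hidden_states + context.mean(dim=1, keepdim=True)


attn = ToyWanAttention()
x = torch.randn(1, 4, 8)

# Before the commit, encoder_hidden_states was a required argument; now a pure
# self-attention caller can omit it or pass None positionally.
out_kw = attn(hidden_states=x)
out_pos = attn(x, None, None, None)
assert torch.equal(out_kw, out_pos)
```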

src/diffusers/models/transformers/transformer_wan_vace.py

2 additions, 2 deletions

@@ -110,12 +110,12 @@ def forward(
         norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
             control_hidden_states
         )
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

         # 2. Cross-attention
         norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         control_hidden_states = control_hidden_states + attn_output

         # 3. Feed-forward
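
Read against the signature updated in transformer_wan.py, the rewritten calls are purely positional: `self.attn1(norm_hidden_states, None, None, rotary_emb)` binds `hidden_states`, `encoder_hidden_states=None`, `attention_mask=None`, `rotary_emb`, and `self.attn2(norm_hidden_states, encoder_hidden_states, None, None)` is the cross-attention counterpart with no mask or rotary embedding. Continuing the hypothetical `ToyWanAttention` sketch above (reusing `attn` and `x` defined there), a quick check that positional and keyword forms bind identically:

```python
# Self-attention: the commit's positional form vs. the old keyword form.
self_pos = attn(x, None, None, None)              # (hidden, encoder, mask, rotary)
self_kw = attn(hidden_states=x, rotary_emb=None)
assert torch.equal(self_pos, self_kw)

# Cross-attention: encoder states go in the second positional slot.
ctx = torch.randn(1, 6, 8)
cross_pos = attn(x, ctx, None, None)
cross_kw = attn(hidden_states=x, encoder_hidden_states=ctx)
assert torch.equal(cross_pos, cross_kw)
```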
