Commit ab42820 (1 parent: 8d081de)

Refactor CogVideoX transformer forward (#10789)

File tree: 1 file changed (+1, -8)


src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 1 addition & 8 deletions
@@ -503,14 +503,7 @@ def forward(
             attention_kwargs=attention_kwargs,
         )

-        if not self.config.use_rotary_positional_embeddings:
-            # CogVideoX-2B
-            hidden_states = self.norm_final(hidden_states)
-        else:
-            # CogVideoX-5B
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-            hidden_states = self.norm_final(hidden_states)
-            hidden_states = hidden_states[:, text_seq_length:]
+        hidden_states = self.norm_final(hidden_states)

         # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
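A likely rationale for collapsing the branch (this is my reading of the diff, not stated in the commit): if `norm_final` is a per-token `LayerNorm`, then concatenating the text tokens, normalizing, and slicing them back off (the old CogVideoX-5B path) produces the same video-token activations as normalizing the video tokens directly. A minimal sketch, using a plain `nn.LayerNorm` as a stand-in for the model's `norm_final` and made-up sequence lengths:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 16
norm_final = nn.LayerNorm(dim)  # stand-in for the model's norm_final

text_seq_length, video_seq_length = 4, 8
encoder_hidden_states = torch.randn(2, text_seq_length, dim)  # text tokens
hidden_states = torch.randn(2, video_seq_length, dim)         # video tokens

# Old 5B path: concat text + video, normalize, then drop the text tokens.
old = norm_final(
    torch.cat([encoder_hidden_states, hidden_states], dim=1)
)[:, text_seq_length:]

# New unified path: normalize the video tokens directly.
new = norm_final(hidden_states)

# LayerNorm normalizes each token independently over the feature dim,
# so slicing before or after the norm gives the same result.
assert torch.allclose(old, new, atol=1e-6)
```

Under that assumption the 2B and 5B paths are numerically identical, which is why a single `self.norm_final(hidden_states)` call can replace the whole conditional.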
