Commit ff03fc2

Fix shape errors in Dream transformer
1 parent 89b868a commit ff03fc2

File tree

1 file changed (+5, -6 lines)

src/diffusers/models/transformers/transformer_dream.py

Lines changed: 5 additions & 6 deletions
@@ -101,7 +101,7 @@ def __call__(

         # Repeat KV heads if attn.kv_heads < attn.heads
         key = repeat_kv(key, attn.kv_groups) # [B, N_KV, L, H] --> [B, N, L, H]
-        value = repeat_kv(query, attn.kv_groups) # [B, N_KV, L, H] --> [B, N, L, H]
+        value = repeat_kv(value, attn.kv_groups) # [B, N_KV, L, H] --> [B, N, L, H]

         # TODO: call dispatch_attention_fn here to dispatch the implementation to a backend? e.g. FlashAttn
         # hidden_states = dispatch_attention_fn(
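
The change in both attention processors is the same: value was being rebuilt by repeating query across the key/value head groups, so the tensors fed to attention as values were actually query activations. Repeating value itself restores the intended grouped-query layout. For reference, a minimal sketch of what a repeat_kv helper with the annotated [B, N_KV, L, H] --> [B, N, L, H] behavior conventionally looks like (illustrative only; the helper actually defined in this file is not part of the diff):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand [B, N_KV, L, H] to [B, N_KV * n_rep, L, H] by repeating each
    # key/value head n_rep times (Llama-style grouped-query attention).
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
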
@@ -175,13 +175,12 @@ def __call__(

         # Repeat KV heads if attn.kv_heads < attn.heads
         key = repeat_kv(key, attn.kv_groups) # [B, N_KV, L, H] --> [B, N, L, H]
-        value = repeat_kv(query, attn.kv_groups) # [B, N_KV, L, H] --> [B, N, L, H]
+        value = repeat_kv(value, attn.kv_groups) # [B, N_KV, L, H] --> [B, N, L, H]

         # TODO: call dispatch_attention_fn here to dispatch the implementation to a backend? e.g. FlashAttn
         # hidden_states = dispatch_attention_fn(
         #     query, key, value, attn_mask=attention_mask, backend=self._attention_backend
         # )
-        # TODO: check SDPA call here
         hidden_states = F.scaled_dot_product_attention(
             query,
             key,
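
Once key and value have been repeated up to the full head count, query, key, and value all share the [B, N, L, H] layout, so the plain F.scaled_dot_product_attention call that follows needs no GQA-specific handling. A small standalone shape check under that assumption (the dimensions below are made up for illustration):

import torch
import torch.nn.functional as F

B, N, N_KV, L, H = 2, 8, 2, 16, 64
query = torch.randn(B, N, L, H)
key = torch.randn(B, N_KV, L, H).repeat_interleave(N // N_KV, dim=1)    # [B, N_KV, L, H] --> [B, N, L, H]
value = torch.randn(B, N_KV, L, H).repeat_interleave(N // N_KV, dim=1)  # [B, N_KV, L, H] --> [B, N, L, H]
out = F.scaled_dot_product_attention(query, key, value)                 # [B, N, L, H]
assert out.shape == (B, N, L, H)
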
@@ -348,12 +347,12 @@ def __init__(
         self.down_proj = nn.Linear(self.intermediate_size, self.dim_out, bias=bias)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.up_proj(hidden_states)
+        expanded_hidden_states = self.up_proj(hidden_states)

         gated_hidden_states = self.gate_proj(hidden_states)
         gated_hidden_states = self.act_fn(gated_hidden_states)

-        hidden_states = gated_hidden_states * hidden_states
+        hidden_states = gated_hidden_states * expanded_hidden_states
         hidden_states = self.down_proj(hidden_states)
         return hidden_states

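
In the MLP, the up projection previously overwrote hidden_states, so gate_proj was then applied to a tensor of width intermediate_size rather than to the block input, and the final gating multiplied by the wrong operand as well. Keeping the up projection in expanded_hidden_states restores the usual gated-MLP data flow, down_proj(act_fn(gate_proj(x)) * up_proj(x)). A self-contained sketch of that flow (class name, SiLU activation, and sizes are illustrative, not taken from the file):

import torch
from torch import nn

class GatedMLPSketch(nn.Module):
    # Illustrative stand-in for the fixed forward pass:
    #   y = down_proj( act_fn(gate_proj(x)) * up_proj(x) )
    def __init__(self, dim: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.gate_proj = nn.Linear(dim, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(dim, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, dim, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded_hidden_states = self.up_proj(hidden_states)                 # [..., dim] -> [..., intermediate]
        gated_hidden_states = self.act_fn(self.gate_proj(hidden_states))     # gate is computed from the block input
        return self.down_proj(gated_hidden_states * expanded_hidden_states)  # back to [..., dim]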

@@ -389,7 +388,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        temb: torch.Tensor, # temb is not used in Dream (time-invariant model)
+        temb: Optional[torch.Tensor] = None, # temb is not used in Dream (time-invariant model)
         attention_mask: Optional[torch.Tensor] = None,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
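
Since Dream is time-invariant, temb is never used inside the block; giving it a None default means callers no longer have to pass a placeholder embedding, e.g. (hypothetical usage; block, hidden_states, attention_mask, and rotary_emb are assumed to already exist):

hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)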
