@@ -101,7 +101,7 @@ def __call__(
 
         # Repeat KV heads if attn.kv_heads < attn.heads
         key = repeat_kv(key, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
-        value = repeat_kv(query, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
+        value = repeat_kv(value, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
 
         # TODO: call dispatch_attention_fn here to dispatch the implementation to a backend? e.g. FlashAttn
         # hidden_states = dispatch_attention_fn(
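
For context, a minimal sketch of what a repeat_kv helper with this shape convention typically does (the implementation below is an assumption for illustration, not this repository's code): it tiles the N_KV key/value heads so they match the N query heads before attention, which is why the fix must pass value, not query.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Tile KV heads for grouped-query attention:
    # [B, N_KV, L, H] --> [B, N_KV * n_rep, L, H].
    if n_rep == 1:
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
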
@@ -175,13 +175,12 @@ def __call__(
 
         # Repeat KV heads if attn.kv_heads < attn.heads
         key = repeat_kv(key, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
-        value = repeat_kv(query, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
+        value = repeat_kv(value, attn.kv_groups)  # [B, N_KV, L, H] --> [B, N, L, H]
 
         # TODO: call dispatch_attention_fn here to dispatch the implementation to a backend? e.g. FlashAttn
         # hidden_states = dispatch_attention_fn(
         #     query, key, value, attn_mask=attention_mask, backend=self._attention_backend
         # )
-        # TODO: check SDPA call here
        hidden_states = F.scaled_dot_product_attention(
             query,
             key,
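
A small, self-contained usage sketch of F.scaled_dot_product_attention with repeated KV heads; all sizes are illustrative, and Tensor.repeat_interleave stands in for the repeat_kv helper above.

import torch
import torch.nn.functional as F

B, N, N_KV, L, H = 2, 8, 2, 16, 64                 # illustrative sizes; kv_groups = N // N_KV
query = torch.randn(B, N, L, H)
key = torch.randn(B, N_KV, L, H)
value = torch.randn(B, N_KV, L, H)

key = key.repeat_interleave(N // N_KV, dim=1)      # [B, N_KV, L, H] --> [B, N, L, H]
value = value.repeat_interleave(N // N_KV, dim=1)  # [B, N_KV, L, H] --> [B, N, L, H]

# SDPA expects [B, num_heads, L, head_dim] for all three inputs.
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=None)
print(hidden_states.shape)                         # torch.Size([2, 8, 16, 64])

Recent PyTorch releases also accept enable_gqa=True on scaled_dot_product_attention, which repeats KV heads inside the kernel instead of materializing them; whether that is usable here depends on the minimum supported torch version.
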
@@ -348,12 +347,12 @@ def __init__(
         self.down_proj = nn.Linear(self.intermediate_size, self.dim_out, bias=bias)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.up_proj(hidden_states)
+        expanded_hidden_states = self.up_proj(hidden_states)
 
         gated_hidden_states = self.gate_proj(hidden_states)
         gated_hidden_states = self.act_fn(gated_hidden_states)
 
-        hidden_states = gated_hidden_states * hidden_states
+        hidden_states = gated_hidden_states * expanded_hidden_states
         hidden_states = self.down_proj(hidden_states)
         return hidden_states
 
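
A minimal, self-contained sketch of the gated feed-forward pattern this hunk fixes; the class name, constructor arguments, and SiLU activation are assumptions, but the data flow mirrors the corrected forward above: the activated gate branch multiplies the un-activated up projection, which must therefore not be overwritten before the product is taken.

import torch
import torch.nn as nn

class GatedFeedForwardSketch(nn.Module):
    # Illustrative gated MLP: gate branch is activated, multiplied with the
    # un-activated up projection, then projected back down.
    def __init__(self, dim: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.up_proj = nn.Linear(dim, intermediate_size, bias=bias)
        self.gate_proj = nn.Linear(dim, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, dim, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded_hidden_states = self.up_proj(hidden_states)
        gated_hidden_states = self.act_fn(self.gate_proj(hidden_states))
        return self.down_proj(gated_hidden_states * expanded_hidden_states)
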
@@ -389,7 +388,7 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        temb: torch.Tensor,  # temb is not used in Dream (time-invariant model)
+        temb: Optional[torch.Tensor] = None,  # temb is not used in Dream (time-invariant model)
         attention_mask: Optional[torch.Tensor] = None,
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor: