Commit 353401f

Author: Sophie Xhonneux (sophie-xhonneux)

Add forgotten LayerNorm (ecmwf#687)

* Add forgotten LayerNorm
* Apply ruff

Co-authored-by: Sophie Xhonneux <sxhonneux@clariden-ln001.cscs.ch>

1 parent 51aa5d5, commit 353401f

File tree

1 file changed (+2, -0 lines)

src/weathergen/model/attention.py

Lines changed: 2 additions & 0 deletions
@@ -310,6 +310,7 @@ def __init__(
         lnorm = norm if with_qk_lnorm else torch.nn.Identity
         self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
         self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
+        self.lnorm_kv = lnorm(dim_embed_kv, eps=norm_eps)
 
         self.dtype = attention_dtype
         assert with_flash, "Only flash attention supported at the moment"
@@ -319,6 +320,7 @@ def forward(self, x_q, x_kv, x_lens=None, x_kv_lens=None, ada_ln_aux=None):
         if self.with_residual:
             x_q_in = x_q
         x_q = x_q if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
+        x_kv = self.lnorm_kv(x_kv)
 
         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
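
For context, the two added lines normalize the key/value input before it is projected onto heads. Below is a minimal, runnable sketch of a cross-attention block with the fix in place. Only lnorm_kv, dim_embed_kv, dim_head_proj, the with_qk_lnorm/Identity gating, and norm_eps come from the diff above; the class name CrossAttentionSketch, the projection layers, and the use of PyTorch's scaled_dot_product_attention in place of flash attention are illustrative assumptions, not the WeatherGen implementation.

# Minimal sketch under the assumptions stated above; not the WeatherGen code.
import torch


class CrossAttentionSketch(torch.nn.Module):
    def __init__(self, dim_embed_q, dim_embed_kv, num_heads, with_qk_lnorm=True, norm_eps=1e-5):
        super().__init__()
        assert dim_embed_q % num_heads == 0
        self.num_heads = num_heads
        self.dim_head_proj = dim_embed_q // num_heads

        self.proj_q = torch.nn.Linear(dim_embed_q, dim_embed_q, bias=False)
        self.proj_k = torch.nn.Linear(dim_embed_kv, dim_embed_q, bias=False)
        self.proj_v = torch.nn.Linear(dim_embed_kv, dim_embed_q, bias=False)
        self.proj_out = torch.nn.Linear(dim_embed_q, dim_embed_q, bias=False)

        # Same gating trick as the diff: torch.nn.Identity ignores its arguments.
        lnorm = torch.nn.LayerNorm if with_qk_lnorm else torch.nn.Identity
        self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
        self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
        self.lnorm_kv = lnorm(dim_embed_kv, eps=norm_eps)  # the added layer

    def forward(self, x_q, x_kv):
        # The added forward line: normalize the key/value stream before projection.
        x_kv = self.lnorm_kv(x_kv)

        b, n_q, _ = x_q.shape
        n_kv = x_kv.shape[1]
        q = self.proj_q(x_q).view(b, n_q, self.num_heads, self.dim_head_proj)
        k = self.proj_k(x_kv).view(b, n_kv, self.num_heads, self.dim_head_proj)
        v = self.proj_v(x_kv).view(b, n_kv, self.num_heads, self.dim_head_proj)

        # Per-head q/k normalization, as in the context lines of the diff.
        q, k = self.lnorm_q(q), self.lnorm_k(k)

        # Plain scaled dot-product attention stands in for flash attention here.
        out = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
        return self.proj_out(out.reshape(b, n_q, -1))


# Quick shape check: queries and keys/values with different embedding widths.
attn = CrossAttentionSketch(dim_embed_q=64, dim_embed_kv=32, num_heads=4)
out = attn(torch.randn(2, 8, 64), torch.randn(2, 16, 32))
print(out.shape)  # torch.Size([2, 8, 64])

As the commit title suggests, the query path was already normalized in the diff's context lines (via lnorm_in_q when ada_ln_aux is given), while x_kv previously entered the k/v projections unnormalized; the added lnorm_kv applies the corresponding normalization to the key/value stream.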
