
Commit 44e782c

[DRAFT]: Prediction head architecture clean-up (ecmwf#481)
* - Avoid the time encoding being 0
  - eps in layer norms to 10^-3
  - bf16
* Make the attention dtype and norm eps configurable
* Fix gitignore and add config files
* Shuffle config files into sensible folders
* Implement first attempt at new prediction heads
* Fix some bugs
* Fix trainer compile + FSDP
* Fix trainer and better defaults
* Choose AdaLN
* Correlate predictions per cell (previously this PR treated them as independent)
* Make things more parameter efficient
* Revert "Make things more parameter efficient" (it made things way worse; reverts commit 0f31bf1)
* Improve the prediction heads at small sizes
* Improve the stability of training (two main changes: better beta1 and beta2 values in AdamW and remove GELU)
* Add some more regularisation, in particular to prevent training divergences and overfitting
* Add the dropout that was forgotten in the MLPs
* Tune the learning rate
* Add the original prediction heads (CAREFUL: untested!)
* Fix bugs and ruff
* Restore the old version of the last part
* Start fixing the defaults
* Delete HPC-specific configs
* Defaults and documentation
* Apply ruff
* Clean up code
* Add one more comment

---------

Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
1 parent 3860c1c commit 44e782c
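
The commit settles on AdaLN conditioning for the prediction heads, and the diffs below call an `AdaLayerNorm(dim_embed, dim_aux, norm_eps=...)` module with an auxiliary embedding. As rough orientation, adaptive layer norm regresses the scale and shift of a parameter-free LayerNorm from that auxiliary input; the sketch below only illustrates the pattern and is not the repository's implementation.

```python
import torch

# Illustrative AdaLN sketch: a parameter-free LayerNorm whose scale and shift
# come from an auxiliary conditioning embedding (e.g. time/coordinate features).
# The real AdaLayerNorm in weathergen may differ in detail.
class AdaLayerNormSketch(torch.nn.Module):
    def __init__(self, dim_embed, dim_aux, norm_eps=1e-3):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim_embed, elementwise_affine=False, eps=norm_eps)
        self.to_scale_shift = torch.nn.Linear(dim_aux, 2 * dim_embed)

    def forward(self, x, aux):
        scale, shift = self.to_scale_shift(aux).chunk(2, dim=-1)
        return self.norm(x) * (1.0 + scale) + shift


x = torch.randn(16, 256)   # token embeddings
aux = torch.randn(16, 32)  # conditioning embedding
print(AdaLayerNormSketch(256, 32)(x, aux).shape)  # torch.Size([16, 256])
```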

File tree

10 files changed: +621 -118 lines


config/default_config.yml

Lines changed: 5 additions & 4 deletions
```diff
@@ -32,6 +32,7 @@ ae_global_att_dense_rate: 0.2
 ae_global_block_factor: 64
 ae_global_mlp_hidden_factor: 2

+decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
 pred_adapter_kv: False
 pred_self_attention: True
 pred_dyadic_dims: False
@@ -90,11 +91,11 @@ samples_per_validation: 512
 shuffle: True

 lr_scaling_policy: "sqrt"
-lr_start: 0.000001
-lr_max: 0.0001
-lr_final_decay: 0.000001
+lr_start: 1e-6
+lr_max: 5e-5
+lr_final_decay: 1e-6
 lr_final: 0.0
-lr_steps_warmup: 256
+lr_steps_warmup: 512
 lr_steps_cooldown: 512
 lr_policy_warmup: "cosine"
 lr_policy_decay: "linear"
```
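
The learning-rate block above pairs a cosine warmup with a linear decay. A minimal sketch of how such a schedule is typically evaluated from these settings follows; the function name, the total-step count and the exact interpolation are illustrative assumptions, not the WeatherGenerator trainer.

```python
import math

# Illustrative combination of the lr_* settings above: cosine warmup from
# lr_start to lr_max, then linear decay towards lr_final_decay. Names and the
# total-step count are assumptions, not the repository implementation.
def lr_at_step(step, lr_start=1e-6, lr_max=5e-5, lr_final_decay=1e-6,
               steps_warmup=512, steps_total=4096):
    if step < steps_warmup:
        # "cosine" warmup: smooth ramp from lr_start to lr_max
        t = step / steps_warmup
        return lr_start + (lr_max - lr_start) * 0.5 * (1.0 - math.cos(math.pi * t))
    # "linear" decay from lr_max down to lr_final_decay
    t = (step - steps_warmup) / max(1, steps_total - steps_warmup)
    return lr_max + (lr_final_decay - lr_max) * min(1.0, t)

print(lr_at_step(0), lr_at_step(512), lr_at_step(4096))  # ~1e-6, 5e-5, ~1e-6
```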

pyproject.toml

Lines changed: 3 additions & 1 deletion
```diff
@@ -105,7 +105,9 @@ ignore = [
   "SIM401",
   # To ignore, not relevant for us
   "SIM108", # in case additional norm layer supports are added in future
-  "N817" # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
+  "N817", # we use heavy acronyms, e.g., allowing 'import LongModuleName as LMN' (LMN is accepted)
+  "E731", # overly restrictive and less readable code
+  "N812", # prevents us following the convention for importing torch.nn.functional as F
 ]

 [tool.ruff.lint.flake8-tidy-imports.banned-api]
```

src/weathergen/datasets/tokenizer_masking.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -301,6 +301,17 @@ def id(arg):
         target_tokens = self.masker.mask_target(target_tokens_cells, coords, geoinfos, source)

         target_tokens_lens = [len(t) for t in target_tokens]
+        total_target = sum(target_tokens_lens)
+
+        # sampling the number of targets according to sampling_rate_target
+        samples = (torch.empty(total_target).uniform_() < sampling_rate_target).split(
+            target_tokens_lens
+        )
+        target_tokens = [
+            (tokens[samples]) for tokens, samples in zip(target_tokens, samples, strict=False)
+        ]
+        target_tokens_lens = [len(t) for t in target_tokens]
+
         if torch.tensor(target_tokens_lens).sum() == 0:
             return (torch.tensor([]), torch.tensor([]), torch.tensor([]), torch.tensor([]))

```
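
The added block subsamples target tokens with a single Bernoulli draw over all tokens, split back per cell by the per-cell lengths. A self-contained sketch of that mechanism with toy tensors (`sampling_rate_target` and the token lists here are made up):

```python
import torch

# Toy stand-ins for the per-cell target token tensors.
torch.manual_seed(0)
sampling_rate_target = 0.5
target_tokens = [torch.arange(4), torch.arange(6), torch.arange(3)]

target_tokens_lens = [len(t) for t in target_tokens]
total_target = sum(target_tokens_lens)

# One uniform draw per token, thresholded to a boolean keep-mask, then split
# back into per-cell masks so each cell keeps a random subset of its tokens.
masks = (torch.empty(total_target).uniform_() < sampling_rate_target).split(target_tokens_lens)
target_tokens = [tokens[mask] for tokens, mask in zip(target_tokens, masks, strict=False)]
print([len(t) for t in target_tokens])  # each cell keeps roughly half of its tokens
```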

src/weathergen/model/attention.py

Lines changed: 131 additions & 43 deletions
```diff
@@ -25,6 +25,7 @@ def __init__(
         num_heads,
         dim_head_proj=None,
         dropout_rate=0.0,
+        with_residual=True,
         with_qk_lnorm=True,
         with_flash=True,
         norm_type="LayerNorm",
@@ -38,6 +39,7 @@ def __init__(
         self.num_heads = num_heads
         self.dropout_rate = dropout_rate
         self.with_flash = with_flash
+        self.with_residual = with_residual
         self.softcap = softcap

         assert dim_embed % num_heads == 0
@@ -50,8 +52,6 @@ def __init__(

         if dim_aux is not None:
             self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps)
-        else:
-            self.lnorm = norm(dim_embed, eps=norm_eps)
         self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
@@ -71,7 +71,7 @@ def __init__(
     #########################################
     def forward(self, x, x_lens, ada_ln_aux=None):
         x_in = x
-        x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
+        x = x if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)

         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
@@ -94,8 +94,12 @@ def forward(self, x, x_lens, ada_ln_aux=None):
             dropout_p=self.dropout_rate,
         )

-        # return x_in + self.dropout( self.proj_out( outs.flatten( -2, -1)) )
-        return x_in + self.proj_out(outs.flatten(-2, -1))
+        x = self.proj_out(outs.flatten(-2, -1))
+
+        if self.with_residual:
+            x = x_in + x
+
+        return x


 ####################################################################################################
@@ -107,6 +111,7 @@ def __init__(
         num_heads,
         dim_head_proj=None,
         dropout_rate=0.0,
+        with_residual=True,
         with_qk_lnorm=True,
         with_flash=True,
         norm_type="LayerNorm",
@@ -167,7 +172,11 @@ def forward(self, x, x_lens=None):

         outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze()

-        return x_in + self.dropout(self.proj_out(outs.flatten(-2, -1)))
+        x = self.proj_out(outs.flatten(-2, -1))
+        if self.with_residual:
+            x = x_in + x
+
+        return x


 ####################################################################################################
```
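
Across these classes the same refactor recurs: the residual connection becomes optional via `with_residual`, and the input norm is applied only when an AdaLN auxiliary is provided. A condensed sketch of that control flow (standalone, with a generic `attention_fn` placeholder rather than the repository's flash-attention call):

```python
import torch

# Condensed control flow of the refactored attention blocks: pre-norm only when
# an AdaLN auxiliary is given, and the residual connection gated by a flag.
# `attention_fn` is a placeholder for the attention + output projection in the real code.
def attention_block(x, lnorm, attention_fn, ada_ln_aux=None, with_residual=True):
    x_in = x
    x = x if ada_ln_aux is None else lnorm(x, ada_ln_aux)  # conditioning-dependent norm
    x = attention_fn(x)
    if with_residual:
        x = x_in + x
    return x

x = torch.randn(8, 64)
out = attention_block(x, lnorm=None, attention_fn=torch.nn.Linear(64, 64))
print(out.shape)  # torch.Size([8, 64])
```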
```diff
@@ -284,9 +293,6 @@ def __init__(

         if dim_aux is not None:
             self.lnorm_in_q = AdaLayerNorm(dim_embed_q, dim_aux, norm_eps=norm_eps)
-        else:
-            self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps)
-        self.lnorm_in_kv = norm(dim_embed_kv, eps=norm_eps)

         self.proj_heads_q = torch.nn.Linear(dim_embed_q, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_k = torch.nn.Linear(
@@ -309,11 +315,10 @@ def __init__(
         assert with_flash, "Only flash attention supported at the moment"

     #########################################
-    def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
+    def forward(self, x_q, x_kv, x_lens=None, x_kv_lens=None, ada_ln_aux=None):
         if self.with_residual:
             x_q_in = x_q
-        x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
-        x_kv = self.lnorm_in_kv(x_kv)
+        x_q = x_q if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)

         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
@@ -324,15 +329,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
         vs = self.proj_heads_v(x_kv).reshape(s)

         if x_kv_lens is not None:
-            cum_x_q_lens = torch.cumsum(x_q_lens, 0, dtype=torch.int32)
+            cum_x_q_lens = torch.cumsum(x_lens, 0, dtype=torch.int32)
             cum_x_kv_lens = torch.cumsum(x_kv_lens, 0, dtype=torch.int32)
             outs = flash_attn_varlen_func(
                 qs,
                 ks,
                 vs,
                 cum_x_q_lens,
                 cum_x_kv_lens,
-                x_q_lens.max(),
+                x_lens.max(),
                 x_kv_lens.max(),
                 softcap=self.softcap,
                 dropout_p=self.dropout_rate,
@@ -454,14 +459,13 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
                 vs,
                 cum_x_q_lens,
                 cum_x_kv_lens,
-                x_q_lens.max(),
-                x_kv_lens.max(),
+                x_q_lens.max().item(),
+                x_kv_lens.max().item(),
                 softcap=self.softcap,
                 dropout_p=self.dropout_rate,
             )
         ]

-        # outs = self.dropout( self.proj_out( torch.stack(outs).transpose(1,0).flatten( -2, -1)) )
         outs = self.proj_out(torch.stack(outs).transpose(1, 0).flatten(-2, -1))
         if self.with_residual:
             outs = x_q_in + outs.reshape(x_q_in.shape)
```
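
These hunks touch the variable-length flash-attention path: sequences from all batch entries are packed into one tensor and delimited by cumulative lengths, and the maximum lengths are now passed as Python ints via `.item()`. A small sketch of that bookkeeping (pure torch, no flash-attn call; note that flash-attn's documented convention prepends a leading zero to the cumulative lengths):

```python
import torch

# Packed variable-length sequences: three samples of lengths 3, 5 and 2 stored
# back to back, with int32 cumulative lengths marking the boundaries.
x_lens = torch.tensor([3, 5, 2])
cum_x_lens = torch.cumsum(x_lens, 0, dtype=torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), cum_x_lens])

print(cu_seqlens)           # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(x_lens.max().item())  # 5 -> max_seqlen passed as a plain Python int
```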
```diff
@@ -479,7 +483,9 @@ def __init__(
         dim_head_proj=None,
         dropout_rate=0.0,
         with_qk_lnorm=True,
+        with_residual=True,
         with_flash=True,
+        softcap=0.0,
         norm_type="LayerNorm",
         dim_aux=None,
         norm_eps=1e-5,
@@ -490,7 +496,10 @@ def __init__(
         self.num_heads = num_heads
         self.with_flash = with_flash
         self.dropout_rate = dropout_rate
+        self.with_residual = with_residual
+        self.softcap = softcap

+        assert with_flash, "You have to use flash attention"
         assert dim_embed % num_heads == 0
         self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj

@@ -502,57 +511,136 @@ def __init__(
         if dim_aux is not None:
             self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps)
         else:
-            self.lnorm = norm(dim_embed, eps=norm_eps)
+            self.lnorm = norm(dim_embed)
         self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_out = torch.nn.Linear(dim_embed, dim_embed, bias=False)
-        self.dropout = (
-            torch.nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else torch.nn.Identity()
-        )

         lnorm = norm if with_qk_lnorm else torch.nn.Identity
-        self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
-        self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
+        self.lnorm_q = lnorm(self.dim_head_proj)
+        self.lnorm_k = lnorm(self.dim_head_proj)

         self.dtype = attention_dtype
-        if with_flash:
-            self.att = torch.nn.functional.scaled_dot_product_attention
-        else:
-            self.att = self.attention
-            self.softmax = torch.nn.Softmax(dim=-1)

     #########################################
     def forward(self, x, ada_ln_aux=None):
         x_in = x
-        # x = self.lnorm( x)
-        x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
+        # x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)

         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
-        s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1]
-        qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype)
-        ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
-        vs = self.proj_heads_v(x).reshape(s).to(self.dtype)
+        q_shape = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1]
+        kv_shape = [
+            *([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]),
+            self.num_heads,
+            -1,
+        ]
+        qs = self.lnorm_q(self.proj_heads_q(x).reshape(q_shape)).to(self.dtype)
+        ks = self.lnorm_k(self.proj_heads_k(x).reshape(kv_shape)).to(self.dtype)
+        vs = self.proj_heads_v(x).reshape(kv_shape).to(self.dtype)

         # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt)
-        outs = flash_attn_func(qs, ks, vs, dropout_p=self.dropout_rate)
+        outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=self.dropout_rate)
+
+        if self.with_residual:
+            x = x_in + self.proj_out(outs.flatten(-2, -1))
+        else:
+            x = self.proj_out(outs.flatten(-2, -1))
+
+        return x

-        # return x_in + self.dropout( self.proj_out( outs.flatten( -2, -1)) )
-        return x_in + self.proj_out(outs.flatten(-2, -1))

+####################################################################################################
+class MultiCrossAttentionHead(torch.nn.Module):
     #########################################
-    def attention(self, q, k, v):
-        scaling = 1.0 / torch.sqrt(torch.tensor(q.shape[-1]))
-        return torch.matmul(self.softmax(scaling * self.score(q, k)), v)
+    def __init__(
+        self,
+        dim_embed_q,
+        dim_embed_kv,
+        num_heads,
+        dim_head_proj=None,
+        dropout_rate=0.0,
+        with_qk_lnorm=True,
+        with_residual=True,
+        with_flash=True,
+        softcap=0.0,
+        norm_type="LayerNorm",
+        dim_aux=None,
+        norm_eps=1e-5,
+        attention_dtype=torch.bfloat16,
+    ):
+        super(MultiCrossAttentionHead, self).__init__()
+
+        self.num_heads = num_heads
+        self.with_flash = with_flash
+        self.dropout_rate = dropout_rate
+        self.with_residual = with_residual
+        self.softcap = softcap
+
+        assert with_flash, "You have to use flash attention"
+        assert dim_embed_kv % num_heads == 0
+        self.dim_head_proj_kv = (
+            dim_embed_kv // num_heads if dim_head_proj is None else dim_head_proj
+        )
+        self.dim_head_proj_q = dim_embed_q // num_heads if dim_head_proj is None else dim_head_proj
+
+        if norm_type == "LayerNorm":
+            norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps)
+        else:
+            norm = RMSNorm
+
+        if dim_aux is not None:
+            self.lnorm = AdaLayerNorm(dim_embed_kv, dim_aux, norm_eps=norm_eps)
+        else:
+            self.lnorm = norm(dim_embed_kv)
+        self.proj_heads_q = torch.nn.Linear(
+            dim_embed_q, num_heads * self.dim_head_proj_q, bias=False
+        )
+        self.proj_heads_k = torch.nn.Linear(
+            dim_embed_kv, num_heads * self.dim_head_proj_kv, bias=False
+        )
+        self.proj_heads_v = torch.nn.Linear(
+            dim_embed_kv, num_heads * self.dim_head_proj_kv, bias=False
+        )
+        self.proj_out = torch.nn.Linear(dim_embed_kv, dim_embed_kv, bias=False)
+
+        lnorm = norm if with_qk_lnorm else torch.nn.Identity
+        self.lnorm_q = lnorm(self.dim_head_proj_q)
+        self.lnorm_k = lnorm(self.dim_head_proj_kv)
+
+        self.dtype = attention_dtype

     #########################################
-    def score(self, q, k):
-        return torch.matmul(q, torch.transpose(k, -2, -1))
+    def forward(self, q, x, ada_ln_aux=None):
+        x_in = x
+        # x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
+
+        ## project onto heads and q,k,v and
+        # ensure these are 4D tensors as required for flash attention
+        q_shape = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1]
+        kv_shape = [
+            *([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]),
+            self.num_heads,
+            -1,
+        ]
+        qs = self.lnorm_q(self.proj_heads_q(x).reshape(q_shape)).to(self.dtype)
+        ks = self.lnorm_k(self.proj_heads_k(x).reshape(kv_shape)).to(self.dtype)
+        vs = self.proj_heads_v(x).reshape(kv_shape).to(self.dtype)
+
+        # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt)
+        outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=self.dropout_rate)
+
+        if self.with_residual:
+            x = x_in + self.proj_out(outs.flatten(-2, -1))
+        else:
+            x = self.proj_out(outs.flatten(-2, -1))
+
+        return x


 ####################################################################################################
-class MultiCrossAttentionHead(torch.nn.Module):
+class MultiCrossAttentionHeadSPDA(torch.nn.Module):
     #########################################
     def __init__(
         self,
```
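
One more knob introduced above is `softcap`, forwarded to `flash_attn_func`. Soft-capping squashes the attention logits through a scaled tanh before the softmax, which bounds them and can help training stability; a value of 0.0 disables it. A tiny sketch of the transform (illustrative only; flash-attn applies it internally on the scores):

```python
import torch

# tanh soft-capping of attention logits: values are smoothly bounded to
# (-softcap, +softcap); softcap == 0.0 leaves the logits untouched.
def softcap_logits(scores, softcap):
    if softcap <= 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)

scores = torch.tensor([-40.0, -5.0, 0.0, 5.0, 40.0])
print(softcap_logits(scores, 15.0))  # ~[-14.9, -4.8, 0.0, 4.8, 14.9]
```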

0 commit comments
