@@ -405,19 +405,22 @@ def __init__(
405
405
self .growth_to_kv = nn .Sequential (
406
406
Rearrange ('b n d -> b d n' ),
407
407
nn .Conv1d (model_dim , inner_dim * 2 , growth_kernel_size , bias = False , padding = growth_kernel_size // 2 ),
408
- Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads )
408
+ Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads ),
409
+ nn .LayerNorm (dim_head ),
409
410
)
410
411
411
412
self .seasonal_to_kv = nn .Sequential (
412
413
Rearrange ('b n d -> b d n' ),
413
414
nn .Conv1d (model_dim , inner_dim * 2 , seasonal_kernel_size , bias = False , padding = seasonal_kernel_size // 2 ),
414
- Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads )
415
+ Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads ),
416
+ nn .LayerNorm (dim_head ),
415
417
)
416
418
417
419
self .level_to_kv = nn .Sequential (
418
420
Rearrange ('b n t -> b t n' ),
419
421
nn .Conv1d (time_features , inner_dim * 2 , level_kernel_size , bias = False , padding = level_kernel_size // 2 ),
420
- Rearrange ('b (kv h d) n -> kv b h n d' , kv = 2 , h = heads )
422
+ Rearrange ('b (kv h d) n -> kv b h n d' , kv = 2 , h = heads ),
423
+ nn .LayerNorm (dim_head ),
421
424
)
422
425
423
426
self .to_out = nn .Linear (inner_dim , model_dim )
0 commit comments