@@ -388,7 +388,8 @@ def __init__(
388
388
level_kernel_size = 3 ,
389
389
growth_kernel_size = 3 ,
390
390
seasonal_kernel_size = 3 ,
391
- dropout = 0.
391
+ dropout = 0. ,
392
+ norm_time_features_kv = False
392
393
):
393
394
super ().__init__ ()
394
395
assert isinstance (etsformer , ETSFormer )
@@ -403,12 +404,14 @@ def __init__(
403
404
self .queries = nn .Parameter (torch .randn (heads , dim_head ))
404
405
405
406
self .growth_to_kv = nn .Sequential (
407
+ nn .LayerNorm (model_dim ),
406
408
Rearrange ('b n d -> b d n' ),
407
409
nn .Conv1d (model_dim , inner_dim * 2 , growth_kernel_size , bias = False , padding = growth_kernel_size // 2 ),
408
410
Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads )
409
411
)
410
412
411
413
self .seasonal_to_kv = nn .Sequential (
414
+ nn .LayerNorm (model_dim ),
412
415
Rearrange ('b n d -> b d n' ),
413
416
nn .Conv1d (model_dim , inner_dim * 2 , seasonal_kernel_size , bias = False , padding = seasonal_kernel_size // 2 ),
414
417
Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads )
@@ -417,7 +420,8 @@ def __init__(
417
420
self .level_to_kv = nn .Sequential (
418
421
Rearrange ('b n t -> b t n' ),
419
422
nn .Conv1d (time_features , inner_dim * 2 , level_kernel_size , bias = False , padding = level_kernel_size // 2 ),
420
- Rearrange ('b (kv h d) n -> kv b h n d' , kv = 2 , h = heads )
423
+ Rearrange ('b (kv h d) n -> kv b h n d' , kv = 2 , h = heads ),
424
+ nn .LayerNorm (dim_head ) if norm_time_features_kv else nn .Identity ()
421
425
)
422
426
423
427
self .to_out = nn .Linear (inner_dim , model_dim )
0 commit comments