Skip to content

Commit a762af8

Browse files
committed
layernorm key / values for seasonal, growth, level, after key value projection in classification wrapper
1 parent 2561053 commit a762af8

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

etsformer_pytorch/etsformer_pytorch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ def __init__(
388388
level_kernel_size = 3,
389389
growth_kernel_size = 3,
390390
seasonal_kernel_size = 3,
391-
dropout = 0.
391+
dropout = 0.,
392+
norm_time_features_kv = False
392393
):
393394
super().__init__()
394395
assert isinstance(etsformer, ETSFormer)
@@ -403,12 +404,14 @@ def __init__(
403404
self.queries = nn.Parameter(torch.randn(heads, dim_head))
404405

405406
self.growth_to_kv = nn.Sequential(
407+
nn.LayerNorm(model_dim),
406408
Rearrange('b n d -> b d n'),
407409
nn.Conv1d(model_dim, inner_dim * 2, growth_kernel_size, bias = False, padding = growth_kernel_size // 2),
408410
Rearrange('... (kv h d) n -> kv ... h n d', kv = 2, h = heads)
409411
)
410412

411413
self.seasonal_to_kv = nn.Sequential(
414+
nn.LayerNorm(model_dim),
412415
Rearrange('b n d -> b d n'),
413416
nn.Conv1d(model_dim, inner_dim * 2, seasonal_kernel_size, bias = False, padding = seasonal_kernel_size // 2),
414417
Rearrange('... (kv h d) n -> kv ... h n d', kv = 2, h = heads)
@@ -417,7 +420,8 @@ def __init__(
417420
self.level_to_kv = nn.Sequential(
418421
Rearrange('b n t -> b t n'),
419422
nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),
420-
Rearrange('b (kv h d) n -> kv b h n d', kv = 2, h = heads)
423+
Rearrange('b (kv h d) n -> kv b h n d', kv = 2, h = heads),
424+
nn.LayerNorm(dim_head) if norm_time_features_kv else nn.Identity()
421425
)
422426

423427
self.to_out = nn.Linear(inner_dim, model_dim)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'ETSformer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.12',
6+
version = '0.0.15',
77
license='MIT',
88
description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)