Skip to content

Commit 60f10fb

Browse files
committed
Apply multi-headed layernorm after the projection to the growth, seasonal, and level keys / values
1 parent a762af8 commit 60f10fb

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

etsformer_pytorch/etsformer_pytorch.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,18 @@ def forward(
377377

378378
# classification wrapper
379379

380+
class MultiheadLayerNorm(nn.Module):
    """Layer normalization applied independently for each head.

    Normalizes the last (feature) dimension of the input using the biased
    (population) variance, then applies a learned per-head affine transform.
    The gain and bias have shape (heads, 1, dim) and broadcast over the
    sequence axis, so each head carries its own scale and shift.
    """

    def __init__(self, dim, heads = 1, eps = 1e-5):
        super().__init__()
        self.eps = eps
        # one learned gain / bias per head, broadcast across sequence positions
        self.g = nn.Parameter(torch.ones(heads, 1, dim))
        self.b = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # fused single pass for variance and mean over the feature axis
        var, mu = torch.var_mean(x, dim = -1, unbiased = False, keepdim = True)
        # eps is added to the std (not the variance), matching the original
        normed = (x - mu) / (var.sqrt() + self.eps)
        return normed * self.g + self.b
391+
380392
class ClassificationWrapper(nn.Module):
381393
def __init__(
382394
self,
@@ -388,8 +400,7 @@ def __init__(
388400
level_kernel_size = 3,
389401
growth_kernel_size = 3,
390402
seasonal_kernel_size = 3,
391-
dropout = 0.,
392-
norm_time_features_kv = False
403+
dropout = 0.
393404
):
394405
super().__init__()
395406
assert isinstance(etsformer, ETSFormer)
@@ -404,24 +415,24 @@ def __init__(
404415
self.queries = nn.Parameter(torch.randn(heads, dim_head))
405416

406417
self.growth_to_kv = nn.Sequential(
407-
nn.LayerNorm(model_dim),
408418
Rearrange('b n d -> b d n'),
409419
nn.Conv1d(model_dim, inner_dim * 2, growth_kernel_size, bias = False, padding = growth_kernel_size // 2),
410-
Rearrange('... (kv h d) n -> kv ... h n d', kv = 2, h = heads)
420+
Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),
421+
MultiheadLayerNorm(dim_head, heads = 2 * heads),
411422
)
412423

413424
self.seasonal_to_kv = nn.Sequential(
414-
nn.LayerNorm(model_dim),
415425
Rearrange('b n d -> b d n'),
416426
nn.Conv1d(model_dim, inner_dim * 2, seasonal_kernel_size, bias = False, padding = seasonal_kernel_size // 2),
417-
Rearrange('... (kv h d) n -> kv ... h n d', kv = 2, h = heads)
427+
Rearrange('... (kv h d) n -> ... (kv h) n d', kv = 2, h = heads),
428+
MultiheadLayerNorm(dim_head, heads = 2 * heads),
418429
)
419430

420431
self.level_to_kv = nn.Sequential(
421432
Rearrange('b n t -> b t n'),
422433
nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),
423-
Rearrange('b (kv h d) n -> kv b h n d', kv = 2, h = heads),
424-
nn.LayerNorm(dim_head) if norm_time_features_kv else nn.Identity()
434+
Rearrange('b (kv h d) n -> b (kv h) n d', kv = 2, h = heads),
435+
MultiheadLayerNorm(dim_head, heads = 2 * heads),
425436
)
426437

427438
self.to_out = nn.Linear(inner_dim, model_dim)
@@ -441,11 +452,13 @@ def forward(self, timeseries):
441452

442453
q = self.queries * self.scale
443454

444-
k, v = torch.cat((
455+
kvs = torch.cat((
445456
self.growth_to_kv(latent_growths),
446457
self.seasonal_to_kv(latent_seasonals),
447458
self.level_to_kv(level_output)
448-
), dim = -2).unbind(dim = 0)
459+
), dim = -2)
460+
461+
k, v = kvs.chunk(2, dim = 1)
449462

450463
# cross attention pooling
451464

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'ETSformer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.15',
6+
version = '0.0.16',
77
license='MIT',
88
description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)