Skip to content

Commit 68f9ed9

Browse files
committed
fix sub-gn in mhesa
1 parent 781e763 commit 68f9ed9

File tree

3 files changed

+8
-15
lines changed

3 files changed

+8
-15
lines changed

README.md

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -73,14 +73,3 @@ logits = adapter(timeseries) # (1, 10)
7373
primaryClass = {cs.LG}
7474
}
7575
```
76-
77-
```bibtex
78-
@article{Sun2023RetentiveNA,
79-
title = {Retentive Network: A Successor to Transformer for Large Language Models},
80-
author = {Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
81-
journal = {ArXiv},
82-
year = {2023},
83-
volume. = {abs/2307.08621},
84-
url = {https://api.semanticscholar.org/CorpusID:259937453}
85-
}
86-
```

etsformer_pytorch/etsformer_pytorch.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,11 @@ def __init__(
103103
self.dropout = nn.Dropout(dropout)
104104
self.alpha = nn.Parameter(torch.randn(heads))
105105

106-
self.norm_heads = nn.GroupNorm(heads, dim) if norm_heads else nn.Identity()
106+
self.norm_heads = nn.Sequential(
107+
Rearrange('b n (h d) -> b (h d) n', h = heads),
108+
nn.GroupNorm(heads, dim),
109+
Rearrange('b (h d) n -> b n (h d)', h = heads)
110+
) if norm_heads else nn.Identity()
107111

108112
self.project_in = nn.Linear(dim, dim)
109113
self.project_out = nn.Linear(dim, dim)
@@ -178,10 +182,10 @@ def forward(self, x, naive = False):
178182

179183
output = rearrange(output, 'b h n d -> b n (h d)')
180184

181-
# maybe groupnorm
182-
# borrowing a trick from retnet paper
185+
# maybe sub-ln from https://arxiv.org/abs/2210.06423 - retnet used groupnorm
183186

184187
output = self.norm_heads(output)
188+
185189
return self.project_out(output)
186190

187191
## frequency attention

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'ETSformer-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
99
long_description_content_type = 'text/markdown',

0 commit comments

Comments (0)