
Commit 0c4d190

add cross attention pooling as a means for classification off of etsformer intermediates

1 parent 561d4f0 commit 0c4d190

File tree

4 files changed, +118 -2 lines changed

README.md

Lines changed: 30 additions & 0 deletions

@@ -31,6 +31,36 @@ timeseries = torch.randn(1, 1024, 4)
 pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)
 ```
 
+To use ETSFormer for classification, with cross attention pooling over all latents and the level output:
+
+```python
+import torch
+from etsformer_pytorch import ETSFormer, ClassificationWrapper
+
+etsformer = ETSFormer(
+    time_features = 1,
+    model_dim = 512,
+    embed_kernel_size = 3,
+    layers = 2,
+    heads = 8,
+    K = 4,
+    dropout = 0.2
+)
+
+adapter = ClassificationWrapper(
+    etsformer = etsformer,
+    dim_head = 32,
+    heads = 16,
+    dropout = 0.2,
+    level_kernel_size = 5,
+    num_classes = 10
+)
+
+timeseries = torch.randn(1, 1024)
+
+logits = adapter(timeseries) # (1, 10)
+```
+
 ## Citation
 
 ```bibtex
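For context, the pooling this commit adds boils down to one learned query per attention head attending over a small set of summary tokens. Below is a minimal, self-contained sketch of just that attention step; the sizes and tensor names are illustrative stand-ins, not the library's API:

```python
import torch
from torch import einsum

# illustrative sizes: 16 heads of dimension 32, pooling over 6 summary tokens
batch, heads, dim_head, num_tokens = 1, 16, 32, 6

# one learned query per head - a parameter, not a function of the input
queries = torch.randn(heads, dim_head)
scale = dim_head ** -0.5

# stand-ins for keys / values derived from the growth, seasonal and level outputs
k = torch.randn(batch, heads, num_tokens, dim_head)
v = torch.randn(batch, heads, num_tokens, dim_head)

sim = einsum('h d, b h j d -> b h j', queries * scale, k)  # per-token similarity
attn = sim.softmax(dim = -1)                               # attention over tokens

out = einsum('b h j, b h j d -> b h d', attn, v)           # weighted sum of values
out = out.reshape(batch, heads * dim_head)                 # (1, 512), ready for a classifier head
```

Because the queries are parameters rather than projections of the input, this acts as a learned weighted average over the ETSFormer intermediates, collapsing a variable number of tokens into one fixed-size vector.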

etsformer_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-from etsformer_pytorch.etsformer_pytorch import ETSFormer
+from etsformer_pytorch.etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer_pytorch/etsformer_pytorch.py

Lines changed: 86 additions & 0 deletions

@@ -295,6 +295,9 @@ def __init__(
     ):
         super().__init__()
         assert (model_dim % heads) == 0, 'model dimension must be divisible by number of heads'
+        self.model_dim = model_dim
+        self.time_features = time_features
+
         self.embed = InputEmbedding(time_features, model_dim, kernel_size = embed_kernel_size, dropout = dropout)
 
         self.encoder_layers = nn.ModuleList([])
@@ -365,3 +368,86 @@ def forward(
         forecasted = rearrange(forecasted, 'b n 1 -> b n')
 
         return forecasted
+
+# classification wrapper
+
+class ClassificationWrapper(nn.Module):
+    def __init__(
+        self,
+        *,
+        etsformer,
+        num_classes = 10,
+        heads = 16,
+        dim_head = 32,
+        level_kernel_size = 3,
+        dropout = 0.
+    ):
+        super().__init__()
+        assert isinstance(etsformer, ETSFormer)
+        self.etsformer = etsformer
+        model_dim = etsformer.model_dim
+        time_features = etsformer.time_features
+
+        inner_dim = dim_head * heads
+        self.scale = dim_head ** -0.5
+        self.dropout = nn.Dropout(dropout)
+
+        self.type_growth = nn.Parameter(torch.randn(model_dim) * 1e-5)
+        self.type_seasonal = nn.Parameter(torch.randn(model_dim) * 1e-5)
+
+        self.queries = nn.Parameter(torch.randn(heads, dim_head))
+
+        self.growth_and_seasonal_to_kv = nn.Sequential(
+            nn.Linear(model_dim, inner_dim * 2, bias = False),
+            Rearrange('... n (kv h d) -> kv ... h n d', kv = 2, h = heads)
+        )
+
+        self.level_to_kv = nn.Sequential(
+            Rearrange('b n t -> b t n'),
+            nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),
+            Rearrange('b (kv h d) n -> kv b h n d', kv = 2, h = heads)
+        )
+
+        self.to_out = nn.Linear(inner_dim, model_dim)
+
+        self.to_logits = nn.Sequential(
+            nn.LayerNorm(model_dim),
+            nn.Linear(model_dim, num_classes)
+        )
+
+    def forward(self, timeseries):
+        latent_growths, latent_seasonals, level_output = self.etsformer(timeseries)
+
+        latent_growths = latent_growths.mean(dim = -2)
+        latent_seasonals = latent_seasonals.mean(dim = -2)
+
+        # differentiate between growth and seasonal
+
+        latent_growths = latent_growths + self.type_growth
+        latent_seasonals = latent_seasonals + self.type_seasonal
+
+        # queries, keys, values
+
+        q = self.queries * self.scale
+
+        k, v = torch.cat((
+            self.growth_and_seasonal_to_kv(torch.cat((latent_growths, latent_seasonals), dim = -2)),
+            self.level_to_kv(level_output)
+        ), dim = -2).unbind(dim = 0)
+
+        # cross attention pooling
+
+        sim = einsum('h d, b h j d -> b h j', q, k)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+
+        attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
+
+        out = einsum('b h j, b h j d -> b h d', attn, v)
+        out = rearrange(out, 'b ... -> b (...)')
+
+        out = self.to_out(out)
+
+        # project to logits
+
+        return self.to_logits(out)
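One pattern worth calling out in the wrapper above: a single projection emits keys and values fused along the channel axis, and an einops Rearrange peels the pair apart so `unbind` can split them. A minimal sketch of that trick, assuming the default sizes from the README example:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

heads, dim_head, model_dim = 16, 32, 512   # illustrative, matching the defaults above
inner_dim = heads * dim_head

# project tokens to keys and values in one matmul, then unfold the
# fused (kv h d) channel axis into separate kv / head / dim axes
to_kv = nn.Sequential(
    nn.Linear(model_dim, inner_dim * 2, bias = False),
    Rearrange('... n (kv h d) -> kv ... h n d', kv = 2, h = heads)
)

tokens = torch.randn(1, 4, model_dim)    # (batch, num summary tokens, model dim)
k, v = to_kv(tokens).unbind(dim = 0)     # each is (1, 16, 4, 32) - (batch, heads, tokens, dim head)
```

The level path in the wrapper follows the same shape convention but swaps the Linear for a Conv1d, so each key/value token there also sees its neighboring timesteps.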

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'ETSformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'ETSTransformer - Exponential Smoothing Transformer for Time-Series Forecasting - Pytorch',
   author = 'Phil Wang',
