@@ -295,6 +295,9 @@ def __init__(
295
295
):
296
296
super ().__init__ ()
297
297
assert (model_dim % heads ) == 0 , 'model dimension must be divisible by number of heads'
298
+ self .model_dim = model_dim
299
+ self .time_features = time_features
300
+
298
301
self .embed = InputEmbedding (time_features , model_dim , kernel_size = embed_kernel_size , dropout = dropout )
299
302
300
303
self .encoder_layers = nn .ModuleList ([])
@@ -365,3 +368,86 @@ def forward(
365
368
forecasted = rearrange (forecasted , 'b n 1 -> b n' )
366
369
367
370
return forecasted
371
+
372
+ # classification wrapper
373
+
374
class ClassificationWrapper(nn.Module):
    """Classification head over a trained ETSFormer.

    Pools the wrapped model's latent growth / seasonal representations and its
    level output with a single round of cross-attention (one learned query per
    head, no sequence axis on the query side), then projects the attended
    vector to class logits.

    Args:
        etsformer: a trained ``ETSFormer`` instance whose forward pass yields
            ``(latent_growths, latent_seasonals, level_output)``.
        num_classes: number of output classes for the final linear projection.
        heads: number of cross-attention pooling heads.
        dim_head: per-head query/key/value dimension.
        level_kernel_size: kernel size of the Conv1d that lifts the level
            output (still in raw time-feature space) to key/value space.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        *,
        etsformer,
        num_classes = 10,
        heads = 16,
        dim_head = 32,
        level_kernel_size = 3,
        dropout = 0.
    ):
        super().__init__()

        # explicit raise instead of a bare `assert`: asserts are stripped under
        # `python -O`, which would let a wrong-typed argument slip through and
        # fail later with an opaque attribute error
        if not isinstance(etsformer, ETSFormer):
            raise AssertionError('etsformer must be an instance of ETSFormer')

        self.etsformer = etsformer
        model_dim = etsformer.model_dim
        time_features = etsformer.time_features

        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5   # standard 1/sqrt(d) attention scaling
        self.dropout = nn.Dropout(dropout)

        # small learned type embeddings so growth and seasonal tokens remain
        # distinguishable after being concatenated along the token axis
        self.type_growth = nn.Parameter(torch.randn(model_dim) * 1e-5)
        self.type_seasonal = nn.Parameter(torch.randn(model_dim) * 1e-5)

        # one learned query vector per attention head
        self.queries = nn.Parameter(torch.randn(heads, dim_head))

        # project pooled growth/seasonal latents to (key, value) pairs
        self.growth_and_seasonal_to_kv = nn.Sequential(
            nn.Linear(model_dim, inner_dim * 2, bias = False),
            Rearrange('... n (kv h d) -> kv ... h n d', kv = 2, h = heads)
        )

        # the level output is still in time-feature space, so it gets its own
        # conv projection (same-padded along the time axis) to (key, value)
        self.level_to_kv = nn.Sequential(
            Rearrange('b n t -> b t n'),
            nn.Conv1d(time_features, inner_dim * 2, level_kernel_size, bias = False, padding = level_kernel_size // 2),
            Rearrange('b (kv h d) n -> kv b h n d', kv = 2, h = heads)
        )

        self.to_out = nn.Linear(inner_dim, model_dim)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(model_dim),
            nn.Linear(model_dim, num_classes)
        )

    def forward(self, timeseries):
        """Classify a timeseries.

        Args:
            timeseries: passed straight through to the wrapped ETSFormer —
                presumably shaped (batch, time, time_features); TODO confirm
                against ``ETSFormer.forward``.

        Returns:
            Logits of shape (batch, num_classes).
        """
        latent_growths, latent_seasonals, level_output = self.etsformer(timeseries)

        # mean-pool each latent stack over its second-to-last axis
        latent_growths = latent_growths.mean(dim = -2)
        latent_seasonals = latent_seasonals.mean(dim = -2)

        # differentiate between growth and seasonal tokens

        latent_growths = latent_growths + self.type_growth
        latent_seasonals = latent_seasonals + self.type_seasonal

        # queries, keys, values

        q = self.queries * self.scale

        # keys/values from growth + seasonal tokens and from the level output
        # are concatenated along the token axis, then split into k and v
        k, v = torch.cat((
            self.growth_and_seasonal_to_kv(torch.cat((latent_growths, latent_seasonals), dim = -2)),
            self.level_to_kv(level_output)
        ), dim = -2).unbind(dim = 0)

        # cross attention pooling

        sim = einsum('h d, b h j d -> b h j', q, k)
        # subtract the (detached) row max for softmax numerical stability
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h j, b h j d -> b h d', attn, v)
        out = rearrange(out, 'b ... -> b (...)')

        out = self.to_out(out)

        # project to logits

        return self.to_logits(out)
0 commit comments