@@ -377,6 +377,18 @@ def forward(
377
377
378
378
# classification wrapper
379
379
380
class MultiheadLayerNorm(nn.Module):
    """Layer normalization with a separate learned affine transform per head.

    Normalizes over the last dimension, then scales/shifts with per-head
    parameters of shape (heads, 1, dim), which broadcast over the sequence
    axis of an input shaped (..., heads, n, dim).
    """

    def __init__(self, dim, heads=1, eps=1e-5):
        super().__init__()
        # eps is added to the std (not the variance) before dividing
        self.eps = eps
        # per-head gain and bias, broadcast across the sequence dimension
        self.g = nn.Parameter(torch.ones(heads, 1, dim))
        self.b = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # population (unbiased=False) statistics over the feature dimension
        mean = x.mean(dim=-1, keepdim=True)
        std = x.var(dim=-1, unbiased=False, keepdim=True).sqrt()
        normed = (x - mean) / (std + self.eps)
        return normed * self.g + self.b
380
392
class ClassificationWrapper (nn .Module ):
381
393
def __init__ (
382
394
self ,
@@ -388,8 +400,7 @@ def __init__(
388
400
level_kernel_size = 3 ,
389
401
growth_kernel_size = 3 ,
390
402
seasonal_kernel_size = 3 ,
391
- dropout = 0. ,
392
- norm_time_features_kv = False
403
+ dropout = 0.
393
404
):
394
405
super ().__init__ ()
395
406
assert isinstance (etsformer , ETSFormer )
@@ -404,24 +415,24 @@ def __init__(
404
415
self .queries = nn .Parameter (torch .randn (heads , dim_head ))
405
416
406
417
self .growth_to_kv = nn .Sequential (
407
- nn .LayerNorm (model_dim ),
408
418
Rearrange ('b n d -> b d n' ),
409
419
nn .Conv1d (model_dim , inner_dim * 2 , growth_kernel_size , bias = False , padding = growth_kernel_size // 2 ),
410
- Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads )
420
+ Rearrange ('... (kv h d) n -> ... (kv h) n d' , kv = 2 , h = heads ),
421
+ MultiheadLayerNorm (dim_head , heads = 2 * heads ),
411
422
)
412
423
413
424
self .seasonal_to_kv = nn .Sequential (
414
- nn .LayerNorm (model_dim ),
415
425
Rearrange ('b n d -> b d n' ),
416
426
nn .Conv1d (model_dim , inner_dim * 2 , seasonal_kernel_size , bias = False , padding = seasonal_kernel_size // 2 ),
417
- Rearrange ('... (kv h d) n -> kv ... h n d' , kv = 2 , h = heads )
427
+ Rearrange ('... (kv h d) n -> ... (kv h) n d' , kv = 2 , h = heads ),
428
+ MultiheadLayerNorm (dim_head , heads = 2 * heads ),
418
429
)
419
430
420
431
self .level_to_kv = nn .Sequential (
421
432
Rearrange ('b n t -> b t n' ),
422
433
nn .Conv1d (time_features , inner_dim * 2 , level_kernel_size , bias = False , padding = level_kernel_size // 2 ),
423
- Rearrange ('b (kv h d) n -> kv b h n d' , kv = 2 , h = heads ),
424
- nn . LayerNorm (dim_head ) if norm_time_features_kv else nn . Identity ()
434
+ Rearrange ('b (kv h d) n -> b (kv h) n d' , kv = 2 , h = heads ),
435
+ MultiheadLayerNorm (dim_head , heads = 2 * heads ),
425
436
)
426
437
427
438
self .to_out = nn .Linear (inner_dim , model_dim )
@@ -441,11 +452,13 @@ def forward(self, timeseries):
441
452
442
453
q = self .queries * self .scale
443
454
444
- k , v = torch .cat ((
455
+ kvs = torch .cat ((
445
456
self .growth_to_kv (latent_growths ),
446
457
self .seasonal_to_kv (latent_seasonals ),
447
458
self .level_to_kv (level_output )
448
- ), dim = - 2 ).unbind (dim = 0 )
459
+ ), dim = - 2 )
460
+
461
+ k , v = kvs .chunk (2 , dim = 1 )
449
462
450
463
# cross attention pooling
451
464
0 commit comments