@@ -48,7 +48,7 @@ def __init__(
48
48
self .scale = self .head_dim ** - 0.5
49
49
self .track_dependency_mask = False
50
50
self .dependency_mask = None
51
- self .head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0
51
+ self .head_selector_temperature = 0.1 # appendix D.1. NOTE(review): earlier note says 0.1 caused NaN and 10.0 gave 0 -- confirm 0.1 is stable now
52
52
53
53
self .head_selector = nn .Linear (dim , num_heads )
54
54
@@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
78
78
p = (self .head_selector (x ) / self .head_selector_temperature ).softmax (dim = - 1 )
79
79
p = p .transpose (- 2 , - 1 ).reshape (B , self .num_heads , 1 , N )
80
80
81
- m = self .message_controller (x ).sigmoid ().reshape (B , 1 , 1 , N )
81
+ m = m * self .message_controller (x ).sigmoid ().reshape (B , 1 , 1 , N )
82
82
83
83
q = q * self .scale
84
84
attn = q @ k .transpose (- 2 , - 1 )
@@ -187,7 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
187
187
#x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm
188
188
189
189
x = self .norm (x )
190
- # x = x * m.transpose(1, 3).squeeze(-1)
190
+ x = x * m .transpose (1 , 3 ).squeeze (- 1 )
191
191
return x
192
192
193
193
0 commit comments