Skip to content

Commit a441390

Browse files
committed
Update dependencyvit.py
1 parent 8d588bb commit a441390

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

timm/models/dependencyvit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
self.scale = self.head_dim ** -0.5
4949
self.track_dependency_mask = False
5050
self.dependency_mask = None
51-
self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0
51+
self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0
5252

5353
self.head_selector = nn.Linear(dim, num_heads)
5454

@@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
7878
p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1)
7979
p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N)
8080

81-
m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N)
81+
m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N)
8282

8383
q = q * self.scale
8484
attn = q @ k.transpose(-2, -1)
@@ -187,7 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
187187
#x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm
188188

189189
x = self.norm(x)
190-
#x = x * m.transpose(1, 3).squeeze(-1)
190+
x = x * m.transpose(1, 3).squeeze(-1)
191191
return x
192192

193193

0 commit comments

Comments
 (0)