Skip to content

Commit 1c0b10c

Browse files
committed
Update dependencyvit.py
1 parent 68c8aaa commit 1c0b10c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

timm/models/dependencyvit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te
7777

7878
p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1)
7979
p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N)
80-
80+
print(m)
8181
m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N)
8282

8383
q = q * self.scale
@@ -179,7 +179,8 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
179179
x = self.patch_drop(x)
180180
x = self.norm_pre(x)
181181
B, N, _ = x.shape
182-
m = torch.ones(B, 1, 1, N).to(x)
182+
#m = torch.ones(B, 1, 1, N).to(x)
183+
m = torch.Tensor([1]).to(x)
183184
if self.grad_checkpointing and not torch.jit.is_scripting():
184185
x, m = checkpoint_seq(self.blocks, (x, m))
185186
else:

0 commit comments

Comments
 (0)