
Commit 9992a61

attention re-use in lookup vit should use pre-softmax attention matrix
1 parent 4b2c00c commit 9992a61

File tree: 2 files changed, +13 -12 lines

setup.py
vit_pytorch/look_vit.py

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.3',
+  version = '1.7.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/look_vit.py

Lines changed: 12 additions & 11 deletions
@@ -99,8 +99,8 @@ def forward(
         self,
         x,
         context = None,
-        return_attn = False,
-        attn = None
+        return_qk_sim = False,
+        qk_sim = None
     ):
         x = self.norm(x)

@@ -119,20 +119,21 @@ def forward(
             q, k = tuple(self.split_heads(t) for t in qk)

             q = q * self.scale
-            sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+            qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

-            attn = self.attend(sim)
-            attn = self.dropout(attn)
         else:
-            assert exists(attn), 'attention matrix must be passed in for reusing previous attention'
+            assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
+
+        attn = self.attend(qk_sim)
+        attn = self.dropout(attn)

         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
         out = self.to_out(out)

-        if not return_attn:
+        if not return_qk_sim:
            return out

-        return out, attn
+        return out, qk_sim

 # LookViT

@@ -228,17 +229,17 @@ def forward(self, img):

             # main tokens cross attends (lookup) on the high res tokens

-            lookup_out, lookup_attn = lookup_cross_attn(tokens, highres_tokens, return_attn = True) # return attention as they reuse the attention matrix
+            lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return the pre-softmax attention matrix (qk sim), as it is reused for the reverse cross attention below
            tokens = lookup_out + tokens

            tokens = attn(tokens) + tokens
            tokens = mlp(tokens) + tokens

            # attention-reuse

-            lookup_attn = rearrange(lookup_attn, 'b h i j -> b h j i') # transpose for reverse cross attention
+            qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention

-            highres_tokens = highres_attn(highres_tokens, tokens, attn = lookup_attn) + highres_tokens
+            highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
            highres_tokens = highres_norm(highres_tokens)

            highres_tokens = highres_mlp(highres_tokens) + highres_tokens
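
Note on why the pre-softmax matrix is the right thing to reuse: row-wise softmax and transposition do not commute, so a post-softmax attention matrix that is transposed for the reverse cross attention no longer rows-sums to 1, while transposing the raw qk similarities and applying softmax afterwards (as the updated code does) gives valid attention weights. The snippet below is a minimal standalone sketch, not code from the repository; the tensor shapes are arbitrary and torch is assumed as in the library.

# sketch only: compare transposing before vs. after the softmax
import torch

b, h, i, j = 1, 2, 4, 6                      # arbitrary batch, heads, query and key counts
qk_sim = torch.randn(b, h, i, j)             # pre-softmax similarity matrix

# reuse as in the commit: transpose the raw similarities, then softmax over the last dim
reverse_attn = qk_sim.transpose(-2, -1).softmax(dim = -1)

# roughly what the previous code did: softmax first, then transpose the result
transposed_attn = qk_sim.softmax(dim = -1).transpose(-2, -1)

print(reverse_attn.sum(dim = -1))            # rows sum to 1: valid attention weights
print(transposed_attn.sum(dim = -1))         # rows generally do not sum to 1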
