@@ -99,8 +99,8 @@ def forward(
         self,
         x,
         context = None,
-        return_attn = False,
-        attn = None
+        return_qk_sim = False,
+        qk_sim = None
     ):
         x = self.norm(x)

@@ -119,20 +119,21 @@ def forward(
             q, k = tuple(self.split_heads(t) for t in qk)

             q = q * self.scale
-            sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+            qk_sim = einsum(q, k, 'b h i d, b h j d -> b h i j')

-            attn = self.attend(sim)
-            attn = self.dropout(attn)
         else:
-            assert exists(attn), 'attention matrix must be passed in for reusing previous attention'
+            assert exists(qk_sim), 'qk sim matrix must be passed in for reusing previous attention'
+
+        attn = self.attend(qk_sim)
+        attn = self.dropout(attn)

         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

         out = self.to_out(out)

-        if not return_attn:
+        if not return_qk_sim:
             return out

-        return out, attn
+        return out, qk_sim

 # LookViT
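The shape of the refactor: softmax and dropout are hoisted out of the branch, so `attend` always runs on `qk_sim` whether the similarities were computed fresh or passed in, and callers now receive the pre-softmax matrix instead of the post-softmax attention. Below is a minimal, self-contained sketch of that pattern, not the repository's exact module; the `exists` helper, the head split/merge, and all hyperparameters are assumptions for illustration, and it branches on whether `qk_sim` was passed rather than on the module flag implied by the diff's else branch:

import torch
from torch import nn

def exists(val):
    return val is not None

class ReuseAttention(nn.Module):
    # minimal sketch: hyperparameters are illustrative, not the repo's defaults
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def split_heads(self, t):
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)  # b n (h d) -> b h n d

    def merge_heads(self, t):
        b, h, n, d = t.shape
        return t.transpose(1, 2).reshape(b, n, h * d)           # b h n d -> b n (h d)

    def forward(self, x, context = None, return_qk_sim = False, qk_sim = None):
        x = self.norm(x)
        context = x if context is None else context

        v = self.split_heads(self.to_v(context))

        if not exists(qk_sim):
            # fresh pass: compute the pre-softmax similarities
            q = self.split_heads(self.to_q(x)) * self.scale
            k = self.split_heads(self.to_k(context))
            qk_sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

        # softmax + dropout run unconditionally, on fresh or reused similarities
        attn = self.dropout(self.attend(qk_sim))

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = self.to_out(self.merge_heads(out))

        return (out, qk_sim) if return_qk_sim else out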
@@ -228,17 +229,17 @@ def forward(self, img):

         # main tokens cross attends (lookup) on the high res tokens

-        lookup_out, lookup_attn = lookup_cross_attn(tokens, highres_tokens, return_attn = True) # return attention as they reuse the attention matrix
+        lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True) # return the qk similarities, as they are reused for the reverse cross attention

         tokens = lookup_out + tokens

         tokens = attn(tokens) + tokens
         tokens = mlp(tokens) + tokens

         # attention-reuse
-        lookup_attn = rearrange(lookup_attn, 'b h i j -> b h j i') # transpose for reverse cross attention
+        qk_sim = rearrange(qk_sim, 'b h i j -> b h j i') # transpose for reverse cross attention

-        highres_tokens = highres_attn(highres_tokens, tokens, attn = lookup_attn) + highres_tokens
+        highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens
         highres_tokens = highres_norm(highres_tokens)

         highres_tokens = highres_mlp(highres_tokens) + highres_tokens
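And the caller-side reuse from the LookViT forward, reduced to a sketch on the `ReuseAttention` module above. Because the module re-applies softmax to whatever `qk_sim` it receives, transposing the pre-softmax matrix renormalizes the reversed attention over the correct axis, which is why the similarities are returned instead of the attention weights. Batch size, token counts, and `dim` are made up for illustration; the two modules must share `heads` for the reuse to typecheck:

import torch
from einops import rearrange

lookup_cross_attn = ReuseAttention(dim = 256)  # main tokens attend to high res tokens
highres_attn = ReuseAttention(dim = 256)       # high res tokens attend to main tokens

tokens = torch.randn(1, 64, 256)            # main (coarse) tokens
highres_tokens = torch.randn(1, 1024, 256)  # high res tokens

# main tokens cross attend to high res tokens, returning pre-softmax similarities
lookup_out, qk_sim = lookup_cross_attn(tokens, highres_tokens, return_qk_sim = True)
tokens = lookup_out + tokens

# attention-reuse: transpose the similarities instead of recomputing q @ k
qk_sim = rearrange(qk_sim, 'b h i j -> b h j i')  # (64 x 1024) -> (1024 x 64)

highres_tokens = highres_attn(highres_tokens, tokens, qk_sim = qk_sim) + highres_tokens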