Commit 112ea05

Remove debug outputs
1 parent 2603776 commit 112ea05

File tree: 1 file changed (+1, -18 lines)

dalle_pytorch/attention.py

Lines changed: 1 addition & 18 deletions
@@ -72,7 +72,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         if exists(cache) and dots_key in cache:
             topleft = cache[dots_key]
             top = F.pad(topleft, (0, 1), value=mask_value)
-            bottom = q[..., n - 1:, :] @ k.swapaxes(-1, -2)
+            bottom = q[..., n - 1:n, :] @ k.swapaxes(-1, -2)
             dots = torch.cat([top, bottom], dim=-2)
         else:
             dots = q @ k.swapaxes(-1, -2)
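
A quick, self-contained check of what the one-line change above amounts to (the tensor shapes below are made up for illustration): with exactly n rows the two slices are identical, but an explicit upper bound guarantees a single query row even if the tensor carries positions beyond n, whereas the open-ended slice would grab everything from n - 1 onward.

    import torch

    # with exactly n rows the two slices are identical ...
    q = torch.randn(2, 8, 5, 64)            # (batch, heads, rows, dim_head)
    n = 5
    assert torch.equal(q[..., n - 1:, :], q[..., n - 1:n, :])

    # ... but with extra rows past n only the bounded slice keeps a single row
    q_long = torch.randn(2, 8, 7, 64)
    print(q_long[..., n - 1:, :].shape)      # torch.Size([2, 8, 3, 64])
    print(q_long[..., n - 1:n, :].shape)     # torch.Size([2, 8, 1, 64])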
@@ -94,10 +94,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out_key = f'{cache_key}_out'
         if exists(cache) and out_key in cache:
             top = cache[out_key]
-            assert top.shape[-2] == n - 1
-
             bottom = attn[..., n - 1:n, :] @ v
-
             out = torch.cat([top, bottom], dim=-2)
         else:
             out = attn @ v
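
For context, both cached branches above implement the same incremental decoding trick: when the cache already holds results for the first n - 1 positions, only the newest query row's scores and output row are computed and concatenated onto the cached block. Below is a minimal standalone sketch of that pattern under simplified assumptions (the function name cached_causal_attention_step, the plain 'dots'/'out' cache keys, and the explicit causal mask are illustrative, not the repo's exact API, which derives its keys from cache_key):

    import torch
    import torch.nn.functional as F

    def cached_causal_attention_step(q, k, v, cache, mask_value = -torch.finfo(torch.float32).max):
        # q, k, v: (batch * heads, n, dim_head); `cache` may hold results for the first n - 1 rows
        n = q.shape[-2]

        if 'dots' in cache:
            # reuse the cached (n-1) x (n-1) score block: pad one masked column for the new key
            # (older queries must not attend to it), then score only the newest query row against all keys
            top = F.pad(cache['dots'], (0, 1), value = mask_value)
            bottom = q[..., n - 1:n, :] @ k.swapaxes(-1, -2)
            dots = torch.cat([top, bottom], dim = -2)
        else:
            dots = q @ k.swapaxes(-1, -2)
            causal_mask = torch.ones(n, n, device = q.device).triu(1).bool()
            dots = dots.masked_fill(causal_mask, mask_value)

        cache['dots'] = dots
        attn = dots.softmax(dim = -1)

        if 'out' in cache:
            # likewise, only the newest output row needs a fresh matmul
            bottom = attn[..., n - 1:n, :] @ v
            out = torch.cat([cache['out'], bottom], dim = -2)
        else:
            out = attn @ v

        cache['out'] = out
        return out

The first call runs the else branches once; each later single-token step then adds only one new score row and one new output row instead of recomputing the full n x n attention.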
@@ -231,8 +228,6 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
 # sparse axial causal attention
 
-from time import time
-
 class SparseAxialCausalAttention(nn.Module):
     def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
@@ -271,10 +266,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive queries / keys / values
 
-        t = time()
         qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
-        print(f'Time 1: {time() - t:.5f} sec')
-        t = time()
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -286,7 +278,6 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # text attention
 
-        print('shapes 1:', q_text.shape, k_text.swapaxes(-1, -2).shape)
         dots_text = q_text @ k_text.swapaxes(-1, -2)
         mask_value = max_neg_value(dots_text)
 
@@ -295,7 +286,6 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         dots_text.masked_fill_(text_causal_mask, mask_value)
 
         attn_text = softmax(dots_text, dim = -1)
-        print('shapes 2:', attn_text.shape, v_text.shape)
         out_text = attn_text @ v_text
 
         # image attention
@@ -309,9 +299,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # similarity
 
-        print('shapes 3:', q_img.shape, k_img.swapaxes(-1, -2).shape)
         dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
-        print('shapes 4:', q_img.shape, k_text[:, None].swapaxes(-1, -2).shape)
         dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)
 
         dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
@@ -335,9 +323,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
 
-        print('shapes 5:', attn_image_to_image.shape, v_img.shape)
         out_image_to_image = attn_image_to_image @ v_img
-        print('shapes 6:', attn_image_to_text.shape, v_text[:, None].shape)
         out_image_to_text = attn_image_to_text @ v_text[:, None]
 
         out_image = out_image_to_image + out_image_to_text
@@ -351,10 +337,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        print(f'Time 2: {time() - t:.5f} sec')
-        t = time()
         out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
-        print(f'Time 3: {time() - t:.5f} sec\n')
         return out[:, :n]
 
 # microsoft sparse attention CUDA kernel
