Commit 8e8dea8

Revert excess changes in attentions
1 parent 2b77018 commit 8e8dea8

4 files changed: +36, -58 lines changed

dalle_pytorch/attention.py

Lines changed: 35 additions & 42 deletions
@@ -6,8 +6,6 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
-from dalle_pytorch.cache import Cached
-
 # helpers
 
 def exists(val):
@@ -41,6 +39,8 @@ def apply_rotary_emb(freqs, t):
     return torch.cat((t, t_right), dim = -1)
 
 def apply_pos_emb(pos_emb, qkv):
+    n = qkv[0].shape[-2]
+    pos_emb = pos_emb[..., :n, :]
     return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
 
 # classes
@@ -65,30 +65,24 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
+        using_cache = exists(cache) and cache_key in cache
 
-        qkv_key = f'{cache_key}_qkv'
-        if exists(cache) and qkv_key in cache:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v)) # FIXME: Fix rotary index here
+        if exists(rotary_pos_emb):
+            if using_cache:
+                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
+            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
 
-            q *= self.scale
+        q = q * self.scale
 
-            k_top, v_top = cache[qkv_key]
+        if using_cache:
+            k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim=-2)
             v = torch.cat([v_top, v], dim=-2)
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
-
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., :n, :], (q, k, v))
-
-            q *= self.scale
         if exists(cache):
-            cache[qkv_key] = (k, v)
+            cache[cache_key] = k, v
 
         dots = q @ k.swapaxes(-1, -2)
         mask_value = max_neg_value(dots)
@@ -98,17 +92,16 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        # if self.causal: # TODO:
-        #     i, j = dots.shape[-2:]
-        #     mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        #     dots.masked_fill_(mask, mask_value)
+        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+            i, j = dots.shape[-2:]
+            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+            dots.masked_fill_(mask, mask_value)
 
         attn = softmax(dots, dim=-1)
 
         out = attn @ v
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
-
         return out
 
 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
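The hunks above restore key/value caching for incremental decoding: on a cached step only the newest token's query, key, and value are computed, the previously cached keys and values are prepended, and the explicit causal mask is skipped, since a single trailing query can only attend to earlier positions anyway. A minimal standalone sketch of that pattern, with hypothetical names and a single head (not the repository's implementation):

import torch

def cached_attention_step(x_new, to_qkv, scale, cache, cache_key):
    # x_new: (batch, 1, dim) -- only the newest token is fed in on cached steps
    q, k, v = to_qkv(x_new).chunk(3, dim = -1)
    q = q * scale

    if cache_key in cache:
        k_top, v_top = cache[cache_key]        # keys / values from all earlier steps
        k = torch.cat([k_top, k], dim = -2)    # prepend cached keys
        v = torch.cat([v_top, v], dim = -2)    # prepend cached values
    cache[cache_key] = (k, v)                  # remember the grown keys / values

    dots = q @ k.swapaxes(-1, -2)              # (batch, 1, seq_so_far)
    attn = dots.softmax(dim = -1)              # no causal mask: q is the last position
    return attn @ v                            # (batch, 1, dim)

With to_qkv = torch.nn.Linear(dim, dim * 3, bias = False), scale = dim ** -0.5, and cache a plain dict, this runs as-is; the real Attention module additionally splits heads with rearrange and applies rotary embeddings, as in the diff.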
@@ -128,14 +121,14 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
 
         self.stable = stable
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -152,7 +145,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive query / keys / values
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -229,7 +222,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]
 
 # sparse axial causal attention
@@ -248,14 +241,14 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
 
         self.stable = stable
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
 
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -272,7 +265,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # derive queries / keys / values
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -284,15 +277,15 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # text attention
 
-        dots_text = q_text @ k_text.swapaxes(-1, -2)
+        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
         mask_value = max_neg_value(dots_text)
 
         i, j = dots_text.shape[-2:]
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)
 
         attn_text = softmax(dots_text, dim = -1)
-        out_text = attn_text @ v_text
+        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
 
         # image attention
 
@@ -305,8 +298,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         # similarity
 
-        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
-        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)
+        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
+        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
 
         dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
 
@@ -329,8 +322,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]
 
-        out_image_to_image = attn_image_to_image @ v_img
-        out_image_to_text = attn_image_to_text @ v_text[:, None]
+        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
+        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
 
         out_image = out_image_to_image + out_image_to_text
 
@@ -343,7 +336,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]
 
 # microsoft sparse attention CUDA kernel
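The last few hunks in this file swap batched @ products for explicit einsum contractions. The two spellings are numerically equivalent; einsum just spells out the broadcast over the extra axial dimension instead of relying on k_text[:, None] plus swapaxes. A quick equivalence check with arbitrary illustrative shapes:

import torch
from torch import einsum

b, x, i, j, d = 2, 4, 8, 8, 16   # batch*heads, axial slices, queries, keys, head dim
q_img = torch.randn(b, x, i, d)
k_img = torch.randn(b, x, j, d)
k_text = torch.randn(b, j, d)

# image-to-image scores: einsum vs. plain batched matmul
assert torch.allclose(einsum('b x i d, b x j d -> b x i j', q_img, k_img),
                      q_img @ k_img.swapaxes(-1, -2), atol = 1e-5)

# image-to-text scores: einsum vs. matmul with a broadcasted text axis
assert torch.allclose(einsum('b x i d, b j d -> b x i j', q_img, k_text),
                      q_img @ k_text[:, None].swapaxes(-1, -2), atol = 1e-5)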

dalle_pytorch/cache.py

Lines changed: 0 additions & 14 deletions
@@ -1,19 +1,5 @@
-import torch
 import torch.nn as nn
 
-# helpers
-
-def exists(val):
-    return val is not None
-
-class Cached(nn.Module):
-    def __init__(self, fn):
-        super().__init__()
-        self.fn = fn
-
-    def forward(self, x, *, cache=None, cache_key=None, **kwargs):
-        return self.fn(x, **kwargs)
-
 class FixCacheKey(nn.Module):
     def __init__(self, cache_key, fn):
         super().__init__()
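Only the generic Cached pass-through wrapper is deleted here; FixCacheKey stays behind and is still imported by transformer.py below. Its forward is not shown in this hunk, but judging from the constructor it presumably binds one fixed cache_key to the wrapped module, roughly as in this sketch (an assumption, not the file's actual body):

import torch.nn as nn

class FixCacheKey(nn.Module):
    # sketch / assumption: pin a cache key so callers don't have to thread it through
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache = None, **kwargs):
        # assumed behaviour: forward the call with the pinned key
        return self.fn(x, cache = cache, cache_key = self.cache_key, **kwargs)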

dalle_pytorch/dalle_pytorch.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 from einops import rearrange
 
 from dalle_pytorch import distributed_utils
-from dalle_pytorch.cache import Cached, FixCacheKey
 from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
 from dalle_pytorch.transformer import Transformer, DivideMax
 

dalle_pytorch/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
-from dalle_pytorch.cache import Cached, FixCacheKey
+from dalle_pytorch.cache import FixCacheKey
 
 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from g_mlp_pytorch import gMLPBlock
