Commit 2603776

Cache full Attention
1 parent 4d431ac commit 2603776

2 files changed: +44, -19 lines changed

dalle_pytorch/attention.py

Lines changed: 42 additions & 15 deletions
@@ -49,26 +49,35 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
         self.stable = stable
         self.causal = causal

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-        self.to_out = nn.Sequential(
+        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        )
+        ))

-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax

-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

         if exists(rotary_pos_emb):
             q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

         q = q * self.scale

-        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
-        mask_value = max_neg_value(dots)
+        mask_value = max_neg_value(q)
+        dots_key = f'{cache_key}_dots'
+        if exists(cache) and dots_key in cache:
+            topleft = cache[dots_key]
+            top = F.pad(topleft, (0, 1), value=mask_value)
+            bottom = q[..., n - 1:, :] @ k.swapaxes(-1, -2)
+            dots = torch.cat([top, bottom], dim=-2)
+        else:
+            dots = q @ k.swapaxes(-1, -2)
+        if exists(cache):
+            cache[dots_key] = dots

         if exists(mask):
             mask = rearrange(mask, 'b j -> b () () j')
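
The cached branch above reuses the (n - 1) x (n - 1) score block from the previous decoding step, pads a masked column for the newly appended key, and computes only the last query row against all keys. A minimal, self-contained sketch of that update (standalone tensors and illustrative names, not part of the commit):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, n, d = 1, 2, 5, 8
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
mask_value = -torch.finfo(q.dtype).max

# scores cached at the previous step, when only n - 1 tokens existed
dots_prev = q[..., :n - 1, :] @ k[..., :n - 1, :].swapaxes(-1, -2)

top = F.pad(dots_prev, (0, 1), value = mask_value)  # masked column for the new key
bottom = q[..., n - 1:, :] @ k.swapaxes(-1, -2)      # only the new query row
dots_incremental = torch.cat([top, bottom], dim = -2)

dots_full = q @ k.swapaxes(-1, -2)

# identical except for the padded column, which the causal mask would hide anyway
assert torch.allclose(dots_incremental[..., :-1, :-1], dots_full[..., :-1, :-1])
assert torch.allclose(dots_incremental[..., -1, :], dots_full[..., -1, :])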
@@ -82,9 +91,21 @@ def forward(self, x, mask = None, rotary_pos_emb = None):

         attn = softmax(dots, dim=-1)

-        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+        out_key = f'{cache_key}_out'
+        if exists(cache) and out_key in cache:
+            top = cache[out_key]
+            assert top.shape[-2] == n - 1
+
+            bottom = attn[..., n - 1:n, :] @ v
+
+            out = torch.cat([top, bottom], dim=-2)
+        else:
+            out = attn @ v
+        if exists(cache):
+            cache[out_key] = out
+
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
+        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out_proj')
         return out

 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
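
Caching `out` row by row in the same way is sound only because this attention is causal: the padded mask_value column gives the new key roughly zero weight in the old rows, so rows 0 .. n - 2 of softmax(dots) @ v do not change when a token is appended and only the last row needs recomputing. A small illustration of that property (illustrative code, not part of the commit):

import torch

torch.manual_seed(0)
b, h, n, d = 1, 2, 6, 8
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

def causal_out(q, k, v):
    dots = q @ k.swapaxes(-1, -2)
    i, j = dots.shape[-2:]
    causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
    dots = dots.masked_fill(causal_mask, -torch.finfo(dots.dtype).max)
    return torch.softmax(dots, dim = -1) @ v

out_prev = causal_out(q[..., :n - 1, :], k[..., :n - 1, :], v[..., :n - 1, :])
out_full = causal_out(q, k, v)

# the first n - 1 output rows are unchanged by the newly appended token
assert torch.allclose(out_full[..., :n - 1, :], out_prev, atol = 1e-5)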
@@ -265,15 +286,17 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         # text attention

-        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
+        print('shapes 1:', q_text.shape, k_text.swapaxes(-1, -2).shape)
+        dots_text = q_text @ k_text.swapaxes(-1, -2)
         mask_value = max_neg_value(dots_text)

         i, j = dots_text.shape[-2:]
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

         attn_text = softmax(dots_text, dim = -1)
-        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
+        print('shapes 2:', attn_text.shape, v_text.shape)
+        out_text = attn_text @ v_text

         # image attention

@@ -286,8 +309,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         # similarity

-        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
-        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
+        print('shapes 3:', q_img.shape, k_img.swapaxes(-1, -2).shape)
+        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
+        print('shapes 4:', q_img.shape, k_text[:, None].swapaxes(-1, -2).shape)
+        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)

         dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

@@ -310,8 +335,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

-        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
-        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
+        print('shapes 5:', attn_image_to_image.shape, v_img.shape)
+        out_image_to_image = attn_image_to_image @ v_img
+        print('shapes 6:', attn_image_to_text.shape, v_text[:, None].shape)
+        out_image_to_text = attn_image_to_text @ v_text[:, None]

         out_image = out_image_to_image + out_image_to_text

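The einsum calls in this axial attention are rewritten as batched matmuls; broadcasting the text keys and values over the extra axial axis with `[:, None]` reproduces the original contractions. A quick equivalence check (shapes chosen arbitrarily, not part of the commit):

import torch
from torch import einsum

b, x, i, j, d = 2, 3, 4, 5, 8
q_img = torch.randn(b, x, i, d)
k_text = torch.randn(b, j, d)
v_text = torch.randn(b, j, d)

# 'b x i d, b j d -> b x i j' == q_img @ k_text[:, None].swapaxes(-1, -2)
ref = einsum('b x i d, b j d -> b x i j', q_img, k_text)
alt = q_img @ k_text[:, None].swapaxes(-1, -2)
assert torch.allclose(ref, alt, atol = 1e-6)

# 'b x i j, b j d -> b x i d' == attn @ v_text[:, None]
attn_image_to_text = torch.softmax(ref, dim = -1)
ref_out = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
alt_out = attn_image_to_text @ v_text[:, None]
assert torch.allclose(ref_out, alt_out, atol = 1e-6)
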
dalle_pytorch/cache.py

Lines changed: 2 additions & 4 deletions
@@ -12,13 +12,11 @@ def __init__(self, fn):
         self.fn = fn

     def forward(self, x, *, cache=None, cache_key=None, **kwargs):
-        assert exists(cache) and exists(cache_key)
-
-        return self.fn(x, **kwargs) # dbg
+        assert exists(cache) and exists(cache_key) # dbg

         if exists(cache) and cache_key in cache:
             prefix = cache[cache_key]
-            assert prefix.shape[1] == x.shape[1] or prefix.shape[1] + 1 == x.shape[1], f'{prefix.shape[1]} {x.shape[1]} {cache_key} {cache.keys()}' # TODO: Change to <= for prod
+            assert prefix.shape[1] + 1 == x.shape[1], f'{prefix.shape[1]} {x.shape[1]} {cache_key} {cache.keys()}' # TODO: Change to <= for prod
             suffix = self.fn(x[:, prefix.shape[1]:, :], **kwargs)
             out = torch.cat([prefix, suffix], dim=1)
         else:
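
For context, this is roughly how a wrapper like `Cached` behaves for a position-wise layer during incremental decoding: the cached prefix of outputs is reused and only the newly appended position is pushed through the wrapped module. A simplified sketch under that assumption (`CachedSketch` and `exists` are defined here for illustration; the real class lives in dalle_pytorch/cache.py):

import torch
from torch import nn

def exists(val):
    return val is not None

class CachedSketch(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache = None, cache_key = None, **kwargs):
        if exists(cache) and cache_key in cache:
            prefix = cache[cache_key]
            # run the wrapped module only on the token(s) beyond the cached prefix
            suffix = self.fn(x[:, prefix.shape[1]:, :], **kwargs)
            out = torch.cat([prefix, suffix], dim = 1)
        else:
            out = self.fn(x, **kwargs)
        if exists(cache):
            cache[cache_key] = out
        return out

layer = CachedSketch(nn.Linear(16, 16))
cache = {}
x = torch.randn(1, 4, 16)
y4 = layer(x, cache = cache, cache_key = 'proj')    # full pass, seeds the cache
x5 = torch.cat([x, torch.randn(1, 1, 16)], dim = 1)
y5 = layer(x5, cache = cache, cache_key = 'proj')   # recomputes only the new position
assert torch.allclose(y5[:, :4], y4)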
