
Commit 2b77018

Don't cache MLPs since we can just pass only last item
1 parent 1fd45ca commit 2b77018
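The reasoning behind the message: apart from attention, every module that was being cached here (the attention output projection, the feed-forward blocks, the final logits projection) is position-wise, so its output at the newest position does not depend on earlier positions. Once the caller passes only the last token, caching those outputs buys nothing. A minimal check of that equivalence (an illustrative sketch, not code from this repository):

import torch
import torch.nn as nn

torch.manual_seed(0)
# stand-in for any position-wise block (e.g. a transformer feed-forward layer)
ff = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)).eval()

x = torch.randn(1, 10, 64)            # embeddings for the whole sequence so far
with torch.no_grad():
    full = ff(x)[:, -1:]              # run everything, keep only the last position
    last = ff(x[:, -1:])              # run only the last position
assert torch.allclose(full, last, atol=1e-6)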

File tree

4 files changed: +19 -34 lines changed


dalle_pytorch/attention.py

Lines changed: 7 additions & 15 deletions
@@ -57,22 +57,22 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
         self.causal = causal

         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )

     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax

         qkv_key = f'{cache_key}_qkv'
         if exists(cache) and qkv_key in cache:
-            qkv = self.to_qkv(x[..., n - 1:n, :]).chunk(3, dim = -1)
+            qkv = self.to_qkv(x).chunk(3, dim = -1)
             q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

             if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v))
+                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v)) # FIXME: Fix rotary index here

             q *= self.scale

@@ -105,18 +105,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         attn = softmax(dots, dim=-1)

-        out_key = f'{cache_key}_out'
-        if exists(cache) and out_key in cache:
-            top = cache[out_key]
-            bottom = attn @ v
-            out = torch.cat([top, bottom], dim=-2)
-        else:
-            out = attn @ v
-        if exists(cache):
-            cache[out_key] = out
-
+        out = attn @ v
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out_proj')
+        out = self.to_out(out)
+
         return out

 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
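Attention is the one sublayer that still needs the full history, which is why its qkv cache stays while the output-side cache above is removed. A minimal sketch of single-token attention against cached keys and values (illustrative names and shapes; the repository's qkv-cache handling outside these hunks differs in detail):

import torch
from einops import rearrange

def attend_one_token(x_new, to_qkv, cached_k, cached_v, heads, scale):
    # x_new: (b, 1, dim) -- only the newest token is passed through the projections
    qkv = to_qkv(x_new).chunk(3, dim = -1)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = heads), qkv)

    # the history lives in the cache, not in recomputed activations
    k = torch.cat([cached_k, k], dim = -2)
    v = torch.cat([cached_v, v], dim = -2)

    dots = (q * scale) @ k.transpose(-1, -2)   # (b, h, 1, n) -- one query row, no causal mask needed
    attn = dots.softmax(dim = -1)
    out = attn @ v                             # (b, h, 1, d)
    return rearrange(out, 'b h n d -> b n (h d)'), k, v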

dalle_pytorch/cache.py

Lines changed: 1 addition & 11 deletions
@@ -12,17 +12,7 @@ def __init__(self, fn):
         self.fn = fn

     def forward(self, x, *, cache=None, cache_key=None, **kwargs):
-        if exists(cache) and cache_key in cache:
-            prefix = cache[cache_key]
-            assert prefix.shape[1] + 1 == x.shape[1], f'{prefix.shape[1]} {x.shape[1]} {cache_key} {cache.keys()}' # TODO: Change to <= for prod
-            suffix = self.fn(x[:, prefix.shape[1]:, :], **kwargs)
-            out = torch.cat([prefix, suffix], dim=1)
-        else:
-            out = self.fn(x, **kwargs)
-
-        if exists(cache):
-            cache[cache_key] = out
-        return out
+        return self.fn(x, **kwargs)

 class FixCacheKey(nn.Module):
     def __init__(self, cache_key, fn):
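With callers now responsible for slicing the input down to the newest token, the output-prefix bookkeeping that Cached used to perform is redundant for position-wise modules, and its forward collapses to a plain call. For reference, the strategy the deleted branch implemented, shown in isolation with illustrative names:

import torch

def prefix_cached_call(fn, x, cache, cache_key):
    # keep the module's previous output and recompute only the unseen tail
    if cache is not None and cache_key in cache:
        prefix = cache[cache_key]                    # outputs for already-processed positions
        suffix = fn(x[:, prefix.shape[1]:, :])       # recompute only the new positions
        out = torch.cat([prefix, suffix], dim = 1)
    else:
        out = fn(x)
    if cache is not None:
        cache[cache_key] = out
    return out

Passing only the last token achieves the same result for these modules without storing any of their activations, which is the trade this commit makes.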

dalle_pytorch/dalle_pytorch.py

Lines changed: 10 additions & 5 deletions
@@ -401,12 +401,12 @@ def __init__(

         self.to_logits = nn.Sequential(
             nn.LayerNorm(dim),
-            FixCacheKey('to_logits_linear', Cached(nn.Linear(dim, self.total_tokens))),
+            nn.Linear(dim, self.total_tokens),
         )

         if share_input_output_emb:
-            self.text_emb = SharedEmbedding(self.to_logits[1].fn.fn, 0, num_text_tokens)
-            self.image_emb = SharedEmbedding(self.to_logits[1].fn.fn, num_text_tokens, total_tokens)
+            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
+            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
         else:
             self.text_emb = nn.Embedding(num_text_tokens, dim)
             self.image_emb = nn.Embedding(num_image_tokens, dim)

@@ -587,17 +587,22 @@ def forward(
             alpha = 0.1
             tokens = tokens * alpha + tokens.detach() * (1 - alpha)

+        if cache is not None and 'decoding' in cache:
+            tokens = tokens[:, -1:]
         out = self.transformer(tokens, cache=cache)

         if self.stable:
             out = self.norm_by_max(out)

-        out = self.to_logits[0](out)
-        logits = self.to_logits[1](out, cache=cache)
+        logits = self.to_logits(out)

         # mask logits to make sure text predicts text (except last token), and image predicts image

         logits_mask = self.logits_mask[:, :seq_len]
+        if cache is not None:
+            if 'decoding' in cache:
+                logits_mask = logits_mask[:, -1:]
+            cache['decoding'] = True
         max_neg_value = -torch.finfo(logits.dtype).max
         logits.masked_fill_(logits_mask, max_neg_value)

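The 'decoding' key turns the cache into a two-phase switch: the first forward call runs the full prompt and sets the flag; every later call slices both the token embeddings and the logits mask down to the last position. A standalone sketch of those mechanics, detached from the DALLE API (the stand-in computation and names are illustrative):

import torch

def forward_sketch(tokens, logits_mask, cache = None):
    if cache is not None and 'decoding' in cache:
        tokens = tokens[:, -1:]                  # later calls: only the newest position
    out = tokens * 2.0                           # stand-in for transformer + to_logits

    if cache is not None:
        if 'decoding' in cache:
            logits_mask = logits_mask[:, -1:]    # mask must match the sliced sequence
        cache['decoding'] = True                 # every call after the first is a decode step
    return out.masked_fill(logits_mask, float('-inf'))

cache = {}
tokens = torch.randn(1, 5, 1)
mask = torch.zeros(1, 5, 1, dtype = torch.bool)
full_pass = forward_sketch(tokens, mask, cache)    # first call: full prompt, full mask
decode_step = forward_sketch(tokens, mask, cache)  # second call: collapses to the last position
assert full_pass.shape[1] == 5 and decode_step.shape[1] == 1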

dalle_pytorch/transformer.py

Lines changed: 1 addition & 3 deletions
@@ -201,7 +201,6 @@ def __init__(
                 shared_ff_layers[ff_id] = ff

             attn = FixCacheKey(f'attn_{ind}', attn)
-            ff = FixCacheKey(f'ff_{ind}', Cached(ff))

             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))

@@ -213,9 +212,8 @@

         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
-        route_all = ((True, True),) * depth
         attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
-                          'cache': route_all}
+                          'cache': route_attn}

         self.layers = execute_type(layers, args_route = attn_route_map)
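Dropping route_all means the 'cache' keyword is now routed with route_attn, i.e. only the attention sublayer of each (attention, feed-forward) pair receives the cache, since the feed-forward path no longer takes one. A simplified sketch of how that per-layer routing works (illustrative, not dalle_pytorch's exact implementation):

def route_args(args_route, kwargs, depth):
    # args_route maps a kwarg name to ((to_attn, to_ff),) * depth
    routed = [({}, {}) for _ in range(depth)]
    for name, value in kwargs.items():
        for layer_idx, (to_attn, to_ff) in enumerate(args_route[name]):
            if to_attn:
                routed[layer_idx][0][name] = value
            if to_ff:
                routed[layer_idx][1][name] = value
    return routed

# With 'cache' mapped to route_attn = ((True, False),) * depth, the cache dict
# reaches every attention sublayer but none of the feed-forward sublayers.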
