
Commit 1cd8e20

Add FFN caching
1 parent c9f462a

2 files changed (+27 lines, -3 lines)

dalle_pytorch/dalle_pytorch.py

Lines changed: 4 additions & 2 deletions
@@ -503,12 +503,13 @@ def generate_images(
         indices = indices[:, :num_img_tokens]
         out = torch.cat((out, indices), dim = -1)
 
+        cache = {}
         for cur_len in range(out.shape[1], total_len):
             is_image = cur_len >= text_seq_len
 
             text, image = out[:, :text_seq_len], out[:, text_seq_len:]
 
-            logits = self(text, image, mask = mask)[:, -1, :]
+            logits = self(text, image, mask = mask, cache = cache)[:, -1, :]
 
             filtered_logits = top_k(logits, thres = filter_thres)
             probs = F.softmax(filtered_logits / temperature, dim = -1)
@@ -536,6 +537,7 @@ def forward(
         text,
         image = None,
         mask = None,
+        cache = None,
         return_loss = False
     ):
         assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
@@ -584,7 +586,7 @@ def forward(
             alpha = 0.1
             tokens = tokens * alpha + tokens.detach() * (1 - alpha)
 
-        out = self.transformer(tokens)
+        out = self.transformer(tokens, cache=cache)
 
         if self.stable:
             out = self.norm_by_max(out)
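
Why caching the feed-forward output over an unchanged prefix is exact: a position-wise block processes each sequence position independently, so running it on a prefix gives the same values as slicing the full-sequence output. A quick check of that identity, using an arbitrary MLP as a stand-in for the repo's FeedForward class (the module below is illustrative only, not the repo's code):

import torch
import torch.nn as nn

# arbitrary position-wise block standing in for the repo's FeedForward
f = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))

x = torch.randn(1, 7, 8)
with torch.no_grad():
    full = f(x)            # run on the whole sequence
    prefix = f(x[:, :4])   # run on the first four positions only

# position-wise, so the prefix outputs match: cached prefix + fresh suffix is exact
assert torch.allclose(full[:, :4], prefix, atol = 1e-6)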

dalle_pytorch/transformer.py

Lines changed: 23 additions & 1 deletion
@@ -86,6 +86,24 @@ def __init__(self, dim, dropout = 0., mult = 4.):
     def forward(self, x):
         return self.net(x)
 
+class Cached(nn.Module):
+    def __init__(self, key, fn):
+        super().__init__()
+        self.key = key
+        self.fn = fn
+
+    def forward(self, x, cache=None, **kwargs):
+        if exists(cache) and self.key in cache:
+            prefix = cache[self.key]
+            suffix = self.fn(x[:, prefix.shape[1]:, :], **kwargs)
+            out = torch.cat([prefix, suffix], dim=1)
+        else:
+            out = self.fn(x, **kwargs)
+
+        if exists(cache):
+            cache[self.key] = out
+        return out
+
 # token shift classes
 
 class PreShiftToken(nn.Module):
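
A minimal usage sketch of the new Cached wrapper, assuming this commit is installed so the class can be imported from dalle_pytorch.transformer (the nn.Linear stand-in and the 'ff_0' key are illustrative). The reuse is only valid when the wrapped module is position-wise and the prefix inputs are identical between calls, which is why the commit wraps the feed-forward block and not attention:

import torch
import torch.nn as nn
from dalle_pytorch.transformer import Cached   # class added in this commit

ff = nn.Linear(8, 8)                  # any position-wise module is safe to wrap
cached_ff = Cached('ff_0', ff)

cache = {}
x = torch.randn(2, 5, 8)
y1 = cached_ff(x, cache = cache)      # full pass; output stored under 'ff_0'

x2 = torch.cat([x, torch.randn(2, 1, 8)], dim = 1)
y2 = cached_ff(x2, cache = cache)     # only the new sixth position runs through ff

assert torch.allclose(y1, y2[:, :5])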
@@ -199,6 +217,8 @@ def __init__(
                 ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
                 shared_ff_layers[ff_id] = ff
 
+            ff = Cached(f'ff_{ind}', ff)
+
             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
 
@@ -209,7 +229,9 @@ def __init__(
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
-        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn}
+        route_ffn = ((False, True),) * depth
+        attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
+                          'cache': route_ffn}
 
         self.layers = execute_type(layers, args_route = attn_route_map)
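
For orientation, each args_route entry holds one boolean per branch of each layer, so route_ffn = ((False, True),) * depth sends the cache kwarg only to the feed-forward half of every (attention, feed-forward) pair. Below is a simplified, standalone illustration of that routing convention; it is not the repo's actual SequentialSequence implementation, just a sketch of the idea the mapping assumes:

def route_kwargs(args_route, kwargs, depth):
    # one (attn_kwargs, ff_kwargs) pair per layer
    routed = [({}, {}) for _ in range(depth)]
    for name, routes in args_route.items():
        if name not in kwargs:
            continue
        for layer_idx, (to_attn, to_ff) in enumerate(routes):
            if to_attn:
                routed[layer_idx][0][name] = kwargs[name]
            if to_ff:
                routed[layer_idx][1][name] = kwargs[name]
    return routed

depth = 2
route_attn = ((True, False),) * depth
route_ffn = ((False, True),) * depth
args_route = {'mask': route_attn, 'cache': route_ffn}

per_layer = route_kwargs(args_route, {'mask': 'M', 'cache': {}}, depth)
# per_layer[0] == ({'mask': 'M'}, {'cache': {}}):
# the mask reaches only attention, the cache reaches only the feed-forward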
