Skip to content

Commit 6ba4cb6

Browse files
committed
Cache pre-logits MLP
1 parent 112ea05 commit 6ba4cb6

File tree

1 file changed: +6 additions, -4 deletions

dalle_pytorch/dalle_pytorch.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from einops import rearrange
99

1010
from dalle_pytorch import distributed_utils
11+
from dalle_pytorch.cache import Cached, FixCacheKey
1112
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
1213
from dalle_pytorch.transformer import Transformer, DivideMax
1314

@@ -400,12 +401,12 @@ def __init__(
400401

401402
self.to_logits = nn.Sequential(
402403
nn.LayerNorm(dim),
403-
nn.Linear(dim, self.total_tokens),
404+
FixCacheKey('to_logits_linear', Cached(nn.Linear(dim, self.total_tokens))),
404405
)
405406

406407
if share_input_output_emb:
407-
self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
408-
self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
408+
self.text_emb = SharedEmbedding(self.to_logits[1].fn.fn, 0, num_text_tokens)
409+
self.image_emb = SharedEmbedding(self.to_logits[1].fn.fn, num_text_tokens, total_tokens)
409410
else:
410411
self.text_emb = nn.Embedding(num_text_tokens, dim)
411412
self.image_emb = nn.Embedding(num_image_tokens, dim)
@@ -591,7 +592,8 @@ def forward(
591592
if self.stable:
592593
out = self.norm_by_max(out)
593594

594-
logits = self.to_logits(out)
595+
out = self.to_logits[0](out)
596+
logits = self.to_logits[1](out, cache=cache)
595597

596598
# mask logits to make sure text predicts text (except last token), and image predicts image
597599

0 commit comments

Comments (0)