  8    8      from einops import rearrange
  9    9
 10   10      from dalle_pytorch import distributed_utils
      11    + from dalle_pytorch.cache import Cached, FixCacheKey
 11   12      from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE
 12   13      from dalle_pytorch.transformer import Transformer, DivideMax
 13   14

@@ -400,12 +401,12 @@ def __init__(

400  401
401  402              self.to_logits = nn.Sequential(
402  403                  nn.LayerNorm(dim),
403         -             nn.Linear(dim, self.total_tokens),
     404    +             FixCacheKey('to_logits_linear', Cached(nn.Linear(dim, self.total_tokens))),
404  405              )
405  406
406  407              if share_input_output_emb:
407         -             self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
408         -             self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
     408    +             self.text_emb = SharedEmbedding(self.to_logits[1].fn.fn, 0, num_text_tokens)
     409    +             self.image_emb = SharedEmbedding(self.to_logits[1].fn.fn, num_text_tokens, total_tokens)
409  410              else:
410  411                  self.text_emb = nn.Embedding(num_text_tokens, dim)
411  412                  self.image_emb = nn.Embedding(num_image_tokens, dim)

@@ -591,7 +592,8 @@ def forward(

591  592              if self.stable:
592  593                  out = self.norm_by_max(out)
593  594
594         -         logits = self.to_logits(out)
     595    +         out = self.to_logits[0](out)
     596    +         logits = self.to_logits[1](out, cache=cache)
595  597
596  598              # mask logits to make sure text predicts text (except last token), and image predicts image
597  599
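For readers who have not seen dalle_pytorch/cache.py, below is a minimal sketch, assuming plausible semantics for the two wrappers used above: the cache is taken to be a plain dict, Cached is taken to reuse the outputs already computed for earlier sequence positions during incremental decoding, and FixCacheKey is taken to bind a stable string key into that dict. Only the constructor nesting FixCacheKey('to_logits_linear', Cached(...)), the cache= keyword on the forward call, and the .fn.fn attribute chain (used above to reach the bare nn.Linear for SharedEmbedding) come from the diff itself; everything else is illustrative, not the PR's implementation.

import torch
from torch import nn

class Cached(nn.Module):
    # sketch: run the wrapped layer only on positions not seen before and
    # reuse cached outputs for the already-computed prefix (assumed semantics)
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache = None, cache_key = None):
        if cache is None:
            return self.fn(x)
        prefix = cache.get(cache_key)
        prefix_len = 0 if prefix is None else prefix.shape[1]
        new_out = self.fn(x[:, prefix_len:])                 # compute only the new positions
        out = new_out if prefix is None else torch.cat((prefix, new_out), dim = 1)
        cache[cache_key] = out                               # remember for the next step
        return out

class FixCacheKey(nn.Module):
    # sketch: bind a fixed string key so the wrapped Cached module always
    # reads and writes the same slot of the cache dict
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache = None):
        return self.fn(x, cache = cache, cache_key = self.cache_key)

# usage mirroring the diff (dimensions here are hypothetical)
to_logits_linear = FixCacheKey('to_logits_linear', Cached(nn.Linear(512, 1024)))
cache = {}
x = torch.randn(1, 4, 512)
y1 = to_logits_linear(x, cache = cache)                      # all 4 positions computed
x = torch.cat((x, torch.randn(1, 1, 512)), dim = 1)
y2 = to_logits_linear(x, cache = cache)                      # only the 5th position runs through the linear
assert torch.equal(y2[:, :4], y1)                            # cached prefix reused verbatim
raw_linear = to_logits_linear.fn.fn                          # unwrap to the bare nn.Linear, as SharedEmbedding does above

The forward-pass change follows from this design: nn.Sequential cannot pass an extra cache keyword through to its children, so the LayerNorm (self.to_logits[0]) and the cached projection (self.to_logits[1]) are applied separately, with the cache handed only to the wrapped linear layer.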