Commit 9febeec

Save and use cache['num_cached']
1 parent df89951 commit 9febeec

2 files changed: +8 -7 lines changed

dalle_pytorch/attention.py

Lines changed: 2 additions & 2 deletions

@@ -72,7 +72,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        if exists(rotary_pos_emb):
            if using_cache:
-                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
+                rotary_pos_emb = rotary_pos_emb[..., cache['num_cached']:, :]
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q = q * self.scale

@@ -92,7 +92,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
            dots.masked_fill_(~mask, mask_value)
            del mask

-        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+        if self.causal and not using_cache: # causality is naturally enforced for the cached inference
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)
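
For context, a minimal sketch of what the new slice computes, assuming rotary_pos_emb covers every position of the full sequence with shape (..., seq_len, rot_dim); the helper name below is hypothetical and not part of the library. During cached decoding only the not-yet-cached positions pass through the layer, so their embeddings start at index cache['num_cached'].

import torch

# Hypothetical helper (not dalle_pytorch code) illustrating the new slice.
def slice_rotary_for_new_tokens(rotary_pos_emb, num_cached):
    # Keep only the embeddings for positions that are not yet in the cache.
    return rotary_pos_emb[..., num_cached:, :]

full = torch.randn(1, 10, 64)                # embeddings for positions 0..9
step = slice_rotary_for_new_tokens(full, 7)  # 7 positions already cached
assert step.shape[-2] == 3                   # positions 7, 8 and 9 remain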

dalle_pytorch/dalle_pytorch.py

Lines changed: 6 additions & 5 deletions

@@ -586,7 +586,7 @@ def forward(
            alpha = 0.1
            tokens = tokens * alpha + tokens.detach() * (1 - alpha)

-        if cache is not None and 'decoding' in cache:
+        if exists(cache) and cache.get('num_cached'):
            tokens = tokens[:, -1:]
        out = self.transformer(tokens, cache=cache)

@@ -598,13 +598,14 @@ def forward(
        # mask logits to make sure text predicts text (except last token), and image predicts image

        logits_mask = self.logits_mask[:, :seq_len]
-        if cache is not None:
-            if 'decoding' in cache:
-                logits_mask = logits_mask[:, -1:]
-            cache['decoding'] = True
+        if exists(cache) and cache.get('num_cached'):
+            logits_mask = logits_mask[:, -1:]
        max_neg_value = -torch.finfo(logits.dtype).max
        logits.masked_fill_(logits_mask, max_neg_value)

+        if exists(cache):
+            cache['num_cached'] = cache.get('num_cached', 0) + logits.shape[1]
+
        if not return_loss:
            return logits
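
To see how the counter behaves across incremental calls, a minimal toy sketch; toy_forward is hypothetical and only mirrors the cache bookkeeping above, not the model's actual forward. The first call processes the whole prefix, later calls process only the newest token, and 'num_cached' grows by the number of positions processed each time.

# Hypothetical toy (not DALLE-pytorch's API) mirroring the cache bookkeeping.
def toy_forward(tokens, cache=None):
    if cache is not None and cache.get('num_cached'):
        tokens = tokens[-1:]  # mirrors tokens = tokens[:, -1:] above
    if cache is not None:
        # mirrors cache['num_cached'] = cache.get('num_cached', 0) + logits.shape[1]
        cache['num_cached'] = cache.get('num_cached', 0) + len(tokens)
    return tokens

cache = {}
toy_forward([5, 2, 9], cache)      # first call: whole prefix, 3 positions
assert cache['num_cached'] == 3
toy_forward([5, 2, 9, 4], cache)   # next call: only the newest token
assert cache['num_cached'] == 4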
