
Commit 059fe1b

Save the current offset in cache
1 parent: df89951

File tree
2 files changed (+10 lines, -11 lines)
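For orientation, here is a sketch of the cache contract this commit settles on, pieced together from the diff below as an assumption rather than taken from the repository: a single dict is threaded through the forward passes, DALLE.forward tracks how many positions have already been decoded under the 'offset' key, and each Attention layer stashes its running keys/values under its own cache_key.

# Hedged sketch of the assumed cache layout after this commit (illustrative only)
cache = {}

# DALLE.forward bumps the decoded-position counter on every call:
#   cache['offset'] = cache.get('offset', 0) + logits.shape[1]
#
# Each Attention layer keeps its running key/value tensors, e.g.:
#   cache[cache_key] = (k, v)   # shapes roughly (batch, heads, positions_so_far, dim_head)
#
# offset == 0  -> first, full-prefix pass: apply the causal mask, rotary emb from position 0
# offset  > 0  -> incremental pass: only the newest token is fed in, cached k/v are prepended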

dalle_pytorch/attention.py

Lines changed: 4 additions & 6 deletions
@@ -65,19 +65,17 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
-        using_cache = exists(cache) and cache_key in cache
+        offset = cache.get('offset', 0) if exists(cache) else 0
 
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
-            if using_cache:
-                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :] # FIXME: Fix rotary index here
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            q, k, v = apply_pos_emb(rotary_pos_emb[..., offset:, :], (q, k, v))
 
         q = q * self.scale
 
-        if using_cache:
+        if offset > 0:
             k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim=-2)
             v = torch.cat([v_top, v], dim=-2)
@@ -92,7 +90,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask
 
-        if self.causal and not using_cache: # causality is naturally enforced if we run the cached inference
+        if self.causal and offset == 0: # causality is naturally enforced for the cached inference
             i, j = dots.shape[-2:]
             mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, mask_value)
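In other words, a single integer offset now drives the cached attention path: the rotary embedding table is sliced from the offset onward (rotary_pos_emb[..., offset:, :], which also resolves the removed FIXME), previously cached keys/values are prepended, and the causal mask is only applied on the first, full-prefix pass. A minimal, self-contained sketch of that logic (single head, no rotary embedding, hypothetical helper names; not the library's exact code) might look like this:

import torch

def cached_attention_step(x, to_qkv, scale, cache, cache_key):
    # x holds only the tokens not yet processed: (batch, n_new, dim)
    offset = cache.get('offset', 0) if cache is not None else 0

    q, k, v = to_qkv(x).chunk(3, dim = -1)

    if offset > 0:
        # prepend keys/values computed on earlier decoding steps
        k_top, v_top = cache[cache_key]
        k = torch.cat([k_top, k], dim = -2)
        v = torch.cat([v_top, v], dim = -2)
    if cache is not None:
        # assumption: the layer re-stores the grown k/v for the next step
        cache[cache_key] = (k, v)

    dots = torch.einsum('b i d, b j d -> b i j', q * scale, k)

    if offset == 0:
        # the causal mask is only needed on the full-prefix pass; during cached
        # decoding the single new query may attend to every cached position
        i, j = dots.shape[-2:]
        causal_mask = torch.ones(i, j, device = x.device).triu_(j - i + 1).bool()
        dots.masked_fill_(causal_mask, torch.finfo(dots.dtype).min)

    attn = dots.softmax(dim = -1)
    return torch.einsum('b i j, b j d -> b i d', attn, v)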

dalle_pytorch/dalle_pytorch.py

Lines changed: 6 additions & 5 deletions
@@ -586,7 +586,7 @@ def forward(
             alpha = 0.1
             tokens = tokens * alpha + tokens.detach() * (1 - alpha)
 
-        if cache is not None and 'decoding' in cache:
+        if exists(cache) and cache.get('offset'):
             tokens = tokens[:, -1:]
         out = self.transformer(tokens, cache=cache)
 
@@ -598,13 +598,14 @@ def forward(
         # mask logits to make sure text predicts text (except last token), and image predicts image
 
         logits_mask = self.logits_mask[:, :seq_len]
-        if cache is not None:
-            if 'decoding' in cache:
-                logits_mask = logits_mask[:, -1:]
-            cache['decoding'] = True
+        if exists(cache) and cache.get('offset'):
+            logits_mask = logits_mask[:, -1:]
         max_neg_value = -torch.finfo(logits.dtype).max
         logits.masked_fill_(logits_mask, max_neg_value)
 
+        if exists(cache):
+            cache['offset'] = cache.get('offset', 0) + logits.shape[1]
+
         if not return_loss:
             return logits
 
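On the dalle_pytorch.py side, forward now both consumes and advances cache['offset'], so a generation loop pays for only one new token per step after the first call. A hypothetical driver loop under that contract is sketched below; model, text_tokens and image_seq_len are placeholders for illustration, not the repo's generate_images code.

import torch

cache = {}                  # shared dict: attention layers add k/v, forward adds 'offset'
tokens = text_tokens        # (batch, prefix_len) prompt tokens

for _ in range(image_seq_len):
    # first call processes the full prefix; later calls slice tokens[:, -1:] inside forward
    logits = model(tokens, cache = cache)
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    tokens = torch.cat([tokens, next_token], dim = -1)
    # after each call forward has done cache['offset'] += logits.shape[1],
    # so the next call only processes the newly appended token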
