Commit c333ea7

Further optimize attention caching

1 parent 6ba4cb6 · commit c333ea7
File tree

  dalle_pytorch/attention.py
  dalle_pytorch/cache.py

2 files changed: +42 -30

dalle_pytorch/attention.py

Lines changed: 42 additions & 28 deletions
@@ -8,8 +8,6 @@
 
 from dalle_pytorch.cache import Cached
 
-from rotary_embedding_torch import apply_rotary_emb
-
 # helpers
 
 def exists(val):
@@ -31,9 +29,18 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
     t = t - torch.amax(t, dim = dim, keepdim = True)
     return (t * alpha).softmax(dim = dim)
 
+def rotate_half(x):
+    d = x.shape[-1] // 2
+    return torch.cat([-x[..., d:], x[..., :d]], dim=-1)
+
+def apply_rotary_emb(freqs, t):
+    rot_dim = freqs.shape[-1]
+    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
+    t, t_right = t[..., :rot_dim], t[..., rot_dim:]
+    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+    return torch.cat((t, t_right), dim = -1)
+
 def apply_pos_emb(pos_emb, qkv):
-    n = qkv[0].shape[-2]
-    pos_emb = pos_emb[..., :n, :]
     return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
 
 # classes
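The two helpers added above inline what was previously imported from rotary_embedding_torch. As a quick sanity check of the expected shapes, here is a small hypothetical usage; the tensor sizes and the random freqs are illustrative only and not taken from this repo:

import torch
from dalle_pytorch.attention import apply_rotary_emb  # the helper lives here after this commit

# illustrative shapes: batch 1, 8 heads, 4 tokens, head dim 64, rotary dim 32
freqs = torch.randn(1, 1, 4, 32)    # stand-in for precomputed rotary frequencies
q = torch.randn(1, 8, 4, 64)
q_rot = apply_rotary_emb(freqs, q)  # first 32 features are rotated, the last 32 pass through
assert q_rot.shape == q.shape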
@@ -49,7 +56,7 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
         self.stable = stable
         self.causal = causal
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
@@ -59,42 +66,49 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        qkv_key = f'{cache_key}_qkv'
+        if exists(cache) and qkv_key in cache:
+            qkv = self.to_qkv(x[..., n - 1:n, :]).chunk(3, dim = -1)
+            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
-        if exists(rotary_pos_emb):
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            if exists(rotary_pos_emb):
+                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v))
 
-        q = q * self.scale
+            q *= self.scale
 
-        mask_value = max_neg_value(q)
-        dots_key = f'{cache_key}_dots'
-        if exists(cache) and dots_key in cache:
-            topleft = cache[dots_key]
-            top = F.pad(topleft, (0, 1), value=mask_value)
-            bottom = q[..., n - 1:n, :] @ k.swapaxes(-1, -2)
-            dots = torch.cat([top, bottom], dim=-2)
+            k_top, v_top = cache[qkv_key]
+            k = torch.cat([k_top, k], dim=-2)
+            v = torch.cat([v_top, v], dim=-2)
         else:
-            dots = q @ k.swapaxes(-1, -2)
+            qkv = self.to_qkv(x).chunk(3, dim = -1)
+            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+            if exists(rotary_pos_emb):
+                q, k, v = apply_pos_emb(rotary_pos_emb[..., :n, :], (q, k, v))
+
+            q *= self.scale
         if exists(cache):
-            cache[dots_key] = dots
+            cache[qkv_key] = (k, v)
 
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b () () j')
-            dots.masked_fill_(~mask, mask_value)
-            del mask
+        # mask_value = max_neg_value(q)
+        dots = q @ k.swapaxes(-1, -2)
 
-        if self.causal:
-            i, j = dots.shape[-2:]
-            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-            dots.masked_fill_(mask, mask_value)
+        # if exists(mask): # TODO:
+        #     mask = rearrange(mask, 'b j -> b () () j')
+        #     dots.masked_fill_(~mask, mask_value)
+        #     del mask
+
+        # if self.causal: # TODO:
+        #     i, j = dots.shape[-2:]
+        #     mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        #     dots.masked_fill_(mask, mask_value)
 
         attn = softmax(dots, dim=-1)
 
         out_key = f'{cache_key}_out'
         if exists(cache) and out_key in cache:
             top = cache[out_key]
-            bottom = attn[..., n - 1:n, :] @ v
+            bottom = attn @ v
             out = torch.cat([top, bottom], dim=-2)
         else:
             out = attn @ v
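The net effect of the forward() changes is a standard incremental-decoding key/value cache: the first call projects every position and stores (k, v), while each later call projects only the newest token and appends its key/value to the cached ones. A minimal standalone sketch of that pattern, with hypothetical names and independent of this class:

import torch

def cached_kv_attention(q_new, k_new, v_new, cache, cache_key):
    # q_new, k_new, v_new: (b, h, 1, d) projections for the newest token only
    # cache maps cache_key -> (k, v) covering all previously seen tokens
    if cache_key in cache:
        k_prev, v_prev = cache[cache_key]
        k = torch.cat([k_prev, k_new], dim = -2)   # (b, h, n, d)
        v = torch.cat([v_prev, v_new], dim = -2)
    else:
        k, v = k_new, v_new
    cache[cache_key] = (k, v)                      # grows by one position per step
    dots = q_new @ k.swapaxes(-1, -2)              # (b, h, 1, n)
    attn = dots.softmax(dim = -1)
    return attn @ v                                # (b, h, 1, d)

On the cached path the single new query may attend to every stored position, so no causal mask is needed there; masking for the full-sequence path is what this commit leaves commented out as a TODO.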

dalle_pytorch/cache.py

Lines changed: 0 additions & 2 deletions
@@ -12,8 +12,6 @@ def __init__(self, fn):
         self.fn = fn
 
     def forward(self, x, *, cache=None, cache_key=None, **kwargs):
-        assert exists(cache) and exists(cache_key) # dbg
-
         if exists(cache) and cache_key in cache:
            prefix = cache[cache_key]
            assert prefix.shape[1] + 1 == x.shape[1], f'{prefix.shape[1]} {x.shape[1]} {cache_key} {cache.keys()}' # TODO: Change to <= for prod
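For orientation, the context lines above suggest that Cached wraps a sub-module and, on a cache hit, reuses the stored prefix of its output so only the newest position is recomputed. A hypothetical reconstruction inferred only from these visible lines; the actual dalle_pytorch.cache.Cached may differ:

import torch
from torch import nn

def exists(val):
    return val is not None

class Cached(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache=None, cache_key=None, **kwargs):
        if exists(cache) and cache_key in cache:
            prefix = cache[cache_key]              # output already computed for x[:, :-1]
            assert prefix.shape[1] + 1 == x.shape[1]
            out = torch.cat([prefix, self.fn(x[:, -1:], **kwargs)], dim = 1)
        else:
            out = self.fn(x, **kwargs)
        if exists(cache):
            cache[cache_key] = out                 # saved for the next decoding step
        return out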
