
Commit 4d431ac

Cache to_qkv and to_out in sparse attn, add debug prints
1 parent 1cd8e20 commit 4d431ac

File tree: 3 files changed (+65, −33 lines)


dalle_pytorch/attention.py

Lines changed: 22 additions & 12 deletions
@@ -6,6 +6,8 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat

+from dalle_pytorch.cache import Cached
+
 from rotary_embedding_torch import apply_rotary_emb

 # helpers
@@ -102,14 +104,14 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,

         self.stable = stable

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))

-        self.to_out = nn.Sequential(
+        self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        )
+        ))

-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax

@@ -126,7 +128,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):

         # derive query / keys / values

-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

         if exists(rotary_pos_emb):
@@ -203,11 +205,13 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         out = torch.cat((out_text, out_image), dim = 1)

         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out)
+        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
         return out[:, :n]

 # sparse axial causal attention

+from time import time
+
 class SparseAxialCausalAttention(nn.Module):
     def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
@@ -222,14 +226,14 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head

         self.stable = stable

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))

-        self.to_out = nn.Sequential(
+        self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        )
+        ))

-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax

@@ -246,7 +250,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None):

         # derive queries / keys / values

-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        t = time()
+        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        print(f'Time 1: {time() - t:.5f} sec')
+        t = time()
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

         if exists(rotary_pos_emb):
@@ -317,7 +324,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         out = torch.cat((out_text, out_image), dim = 1)

         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out)
+        print(f'Time 2: {time() - t:.5f} sec')
+        t = time()
+        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        print(f'Time 3: {time() - t:.5f} sec\n')
         return out[:, :n]

 # microsoft sparse attention CUDA kernel
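The effect of this change is that each attention layer now looks up two cache entries, one per wrapped projection, under keys derived from the cache_key it receives. A minimal sketch of that call pattern with a stand-in projection (the sizes and key names below are illustrative, not part of the library):

import torch
import torch.nn as nn

from dalle_pytorch.cache import Cached

# Stand-in for one of the wrapped projections; the sizes are illustrative.
to_qkv = Cached(nn.Linear(512, 512 * 3, bias = False))

cache = {}
x = torch.randn(1, 10, 512)

# Each attention layer derives two cache keys from the cache_key it receives,
# e.g. 'attn_0_qkv' and 'attn_0_out'.
out = to_qkv(x, cache = cache, cache_key = 'attn_0_qkv')

# Note: in this commit Cached.forward still returns early ('# dbg'), so the
# cache dict is not populated yet; with that line removed, cache['attn_0_qkv']
# would hold the projection of every position seen so far.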

dalle_pytorch/cache.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+class Cached(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, cache_key=None, **kwargs):
+        assert exists(cache) and exists(cache_key)
+
+        return self.fn(x, **kwargs) # dbg
+
+        if exists(cache) and cache_key in cache:
+            prefix = cache[cache_key]
+            assert prefix.shape[1] == x.shape[1] or prefix.shape[1] + 1 == x.shape[1], f'{prefix.shape[1]} {x.shape[1]} {cache_key} {cache.keys()}' # TODO: Change to <= for prod
+            suffix = self.fn(x[:, prefix.shape[1]:, :], **kwargs)
+            out = torch.cat([prefix, suffix], dim=1)
+        else:
+            out = self.fn(x, **kwargs)
+
+        if exists(cache):
+            cache[cache_key] = out
+        return out
+
+class FixCacheKey(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, **kwargs):
+        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
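The intended mechanism: Cached keeps the output computed for the prefix of the sequence and, once the '# dbg' early return above is removed, applies the wrapped module only to the positions past that prefix before concatenating; FixCacheKey pins the key so callers only need to pass a shared dict. A minimal sketch of that usage with a stand-in module (sizes and keys are illustrative):

import torch
import torch.nn as nn

from dalle_pytorch.cache import Cached, FixCacheKey

# Sketch, assuming the '# dbg' early return in Cached.forward has been removed.
ff = FixCacheKey('ff_0', Cached(nn.Linear(512, 512)))

cache = {}                                            # shared across decoding steps
x = torch.randn(1, 10, 512)

out1 = ff(x, cache = cache)                           # full pass; prefix would be stored under 'ff_0'
x = torch.cat((x, torch.randn(1, 1, 512)), dim = 1)   # one new token appended
out2 = ff(x, cache = cache)                           # only the new position would be recomputed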

dalle_pytorch/transformer.py

Lines changed: 5 additions & 21 deletions
@@ -9,6 +9,7 @@

 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
+from dalle_pytorch.cache import Cached, FixCacheKey

 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from g_mlp_pytorch import gMLPBlock
@@ -86,24 +87,6 @@ def __init__(self, dim, dropout = 0., mult = 4.):
     def forward(self, x):
         return self.net(x)

-class Cached(nn.Module):
-    def __init__(self, key, fn):
-        super().__init__()
-        self.key = key
-        self.fn = fn
-
-    def forward(self, x, cache=None, **kwargs):
-        if exists(cache) and self.key in cache:
-            prefix = cache[self.key]
-            suffix = self.fn(x[:, prefix.shape[1]:, :], **kwargs)
-            out = torch.cat([prefix, suffix], dim=1)
-        else:
-            out = self.fn(x, **kwargs)
-
-        if exists(cache):
-            cache[self.key] = out
-        return out
-
 # token shift classes

 class PreShiftToken(nn.Module):
@@ -217,7 +200,8 @@ def __init__(
                 ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
                 shared_ff_layers[ff_id] = ff

-            ff = Cached(f'ff_{ind}', ff)
+            attn = FixCacheKey(f'attn_{ind}', attn)
+            ff = FixCacheKey(f'ff_{ind}', Cached(ff))

             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
@@ -229,9 +213,9 @@ def __init__(

         execute_type = ReversibleSequence if reversible else SequentialSequence
         route_attn = ((True, False),) * depth
-        route_ffn = ((False, True),) * depth
+        route_all = ((True, True),) * depth
         attn_route_map = {'mask': route_attn, 'rotary_pos_emb': route_attn,
-                          'cache': route_ffn}
+                          'cache': route_all}

         self.layers = execute_type(layers, args_route = attn_route_map)

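The routing change means the shared cache dict is now delivered to both members of every (attention, feed-forward) pair instead of only the feed-forward block. A rough illustration of how an args_route map of this shape is interpreted; this is a sketch of the idea only, not the actual SequentialSequence code in dalle_pytorch.reversible:

# Illustration only; the real routing lives in dalle_pytorch.reversible.
depth = 2
route_attn = ((True, False),) * depth   # (send to attn, send to ff) for each layer
route_all  = ((True, True),) * depth

args_route = {'mask': route_attn, 'rotary_pos_emb': route_attn, 'cache': route_all}

def kwargs_for(layer_index, position, **kwargs):
    # position 0 = attention block, position 1 = feed-forward block
    return {name: value for name, value in kwargs.items()
            if name in args_route and args_route[name][layer_index][position]}

# With route_all, the feed-forward block now receives 'cache' but still not 'mask'.
print(kwargs_for(0, 1, mask = None, cache = {}))   # -> {'cache': {}}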