Commit 59cfc49: Add NonCached wrapper

1 parent adfce34

File tree (2 files changed: +33, -9 lines)

  dalle_pytorch/attention.py
  dalle_pytorch/transformer.py

2 files changed

+33
-9
lines changed

dalle_pytorch/attention.py

Lines changed: 2 additions & 8 deletions
@@ -122,13 +122,7 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
-        n0 = x.shape[1]
-        if exists(cache):
-            if cache_key in cache:
-                x = torch.cat([cache[cache_key], x], dim=-2)
-            cache[cache_key] = x
-
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -223,7 +217,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
         out = self.to_out(out)
-        return out[:, n - n0:n]
+        return out[:, :n]
 
 # sparse axial causal attention
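For context, a small sketch (not part of the commit): the removed lines used to prepend a cached prefix to the layer's input and then slice out only the outputs for the newly appended positions. The shapes and the stand-in computation below are illustrative assumptions; after this change the same reconstruct-then-slice step happens outside the layer, in the NonCached wrapper added to transformer.py below.

# illustrative sketch only; shapes and the stand-in computation are assumptions
import torch

cached_prefix = torch.randn(1, 6, 8)    # positions processed on earlier steps (batch, seq, dim)
new_tokens    = torch.randn(1, 2, 8)    # positions arriving on the current step

n0 = new_tokens.shape[1]                              # what the old forward called n0
x  = torch.cat([cached_prefix, new_tokens], dim = -2)
n  = x.shape[1]

out = x * 1.0                                         # stand-in for the attention computation

# the old in-layer slice and a plain suffix slice select the same positions
assert torch.equal(out[:, n - n0:n], out[:, -n0:])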

dalle_pytorch/transformer.py

Lines changed: 31 additions & 1 deletion
@@ -36,7 +36,33 @@ def forward(self, x):
         maxes = x.amax(dim = self.dim, keepdim = True)
         return x / maxes
 
+class NonCached(nn.Module):
+    """
+    A wrapper for layers that don't support the inference cache themselves.
+    Reconstructs the full sequence before the layer and
+    cuts the suffix of the outputs after the layer.
+    """
+
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, *, cache = None, cache_key = None, **kwargs):
+        n = x.shape[-2]
+        if exists(cache):
+            if cache_key in cache:
+                x = torch.cat([cache[cache_key], x], dim=-2)
+            cache[cache_key] = x
+
+        out = self.fn(x, **kwargs)
+
+        return out[:, -n:]
+
 class CachedAs(nn.Module):
+    """
+    A wrapper that defines a key for the inference cache.
+    """
+
     def __init__(self, cache_key, fn):
         super().__init__()
         self.cache_key = cache_key
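To illustrate how the new wrapper behaves, here is a usage sketch under the assumption that this commit is installed; the Doubler module and the 'toy' cache key are made up for illustration and stand in for a sparse attention layer that has no cache support of its own.

# illustrative sketch; Doubler and the 'toy' key are hypothetical
import torch
from torch import nn
from dalle_pytorch.transformer import NonCached

class Doubler(nn.Module):
    # toy layer: wants to see the full sequence, knows nothing about caches
    def forward(self, x):
        return x * 2

layer = NonCached(Doubler())
cache = {}

x1 = torch.randn(1, 4, 8)                         # first 4 positions
y1 = layer(x1, cache = cache, cache_key = 'toy')  # outputs for all 4 positions
assert y1.shape == (1, 4, 8)

x2 = torch.randn(1, 1, 8)                         # one newly generated position
y2 = layer(x2, cache = cache, cache_key = 'toy')  # the full 5-token sequence is rebuilt,
                                                  # only the last output is returned
assert y2.shape == (1, 1, 8)
assert cache['toy'].shape == (1, 5, 8)

When no cache is passed, the wrapper reduces to a plain pass-through over the tokens it is given.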
@@ -251,7 +277,11 @@ def __init__(
             ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
             shared_ff_layers[ff_id] = ff
 
-            attn = CachedAs(f'attn_{ind}', attn)
+            if isinstance(attn, Attention):
+                attn = CachedAs(f'attn_{ind}', attn)
+            else:
+                # at the moment, other Attention classes don't support cache
+                attn = NonCached(attn)
 
             if shift_tokens:
                 attn = CachedAs(f'preshift_attn_{ind}', PreShiftToken(attn, image_size = image_fmap_size, seq_len = seq_len))
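The dispatch above can be read as the small helper below; this is a sketch under the assumption that only the full Attention class manages the cache itself (via the key given to CachedAs), while the sparse convolutional and axial variants get the generic NonCached treatment. wrap_attention is a hypothetical name, not part of the library.

# illustrative sketch; wrap_attention is not part of dalle_pytorch
from dalle_pytorch.attention import Attention
from dalle_pytorch.transformer import CachedAs, NonCached

def wrap_attention(attn, ind):
    if isinstance(attn, Attention):
        # full causal attention reads and writes the shared cache itself,
        # under the key assigned here
        return CachedAs(f'attn_{ind}', attn)
    # other attention variants: rebuild the full sequence around the layer instead
    return NonCached(attn)

Since every layer gets a distinct key (attn_{ind}, preshift_attn_{ind}, ...), a single cache dict can be threaded through the whole stack during incremental decoding.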
