
Commit df89951

Rename FixCacheKey -> CachedAs
1 parent 8e8dea8

File tree

2 files changed: +10 -12 lines changed


dalle_pytorch/cache.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

dalle_pytorch/transformer.py

Lines changed: 10 additions & 2 deletions
@@ -9,7 +9,6 @@
 
 from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
 from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
-from dalle_pytorch.cache import FixCacheKey
 
 from rotary_embedding_torch import RotaryEmbedding, broadcat
 from g_mlp_pytorch import gMLPBlock
@@ -36,6 +35,15 @@ def forward(self, x):
         maxes = x.amax(dim = self.dim, keepdim = True)
         return x / maxes
 
+class CachedAs(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache=None, **kwargs):
+        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)
+
 # https://arxiv.org/abs/2103.17239
 class LayerScale(nn.Module):
     def __init__(self, dim, depth, fn):
@@ -200,7 +208,7 @@ def __init__(
                 ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
                 shared_ff_layers[ff_id] = ff
 
-            attn = FixCacheKey(f'attn_{ind}', attn)
+            attn = CachedAs(f'attn_{ind}', attn)
 
             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
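For readers skimming the diff: `CachedAs` does nothing but bind a fixed `cache_key` to the module it wraps, so one cache dict can be threaded through the whole network and each wrapped module writes to its own slot. Below is a hypothetical usage sketch, not part of the commit; the `CachedAs` class is copied from the hunk above, while `ToyLayer` is an invented stand-in for an attention layer that accepts `cache` and `cache_key` keyword arguments.

```python
import torch
from torch import nn

# The wrapper introduced in this commit: it injects its fixed cache_key
# into every call, so the wrapped module never has to know its own name.
class CachedAs(nn.Module):
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache=None, **kwargs):
        return self.fn(x, cache=cache, cache_key=self.cache_key, **kwargs)

# Hypothetical stand-in for an attention layer: it concatenates the
# tokens cached under its key with the new tokens, mimicking the
# incremental-decoding pattern a layer cache serves.
class ToyLayer(nn.Module):
    def forward(self, x, *, cache=None, cache_key=None):
        if cache is not None:
            if cache_key in cache:
                x = torch.cat((cache[cache_key], x), dim=1)
            cache[cache_key] = x
        return x

layer = CachedAs('attn_0', ToyLayer())

cache = {}
out1 = layer(torch.randn(1, 3, 8), cache=cache)  # seeds cache['attn_0'] with 3 tokens
out2 = layer(torch.randn(1, 1, 8), cache=cache)  # sees 3 cached tokens + 1 new one
assert out2.shape == (1, 4, 8)
```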

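The `f'attn_{ind}'` key in the last hunk gives each depth its own slot in that shared dict, so a single decoding cache can serve the whole stack without layers overwriting each other. Continuing the hypothetical sketch above (reusing `CachedAs` and `ToyLayer`):

```python
# Each layer in the stack gets a distinct key, so their cached states
# land in separate entries of the one cache dict.
layers = nn.ModuleList([CachedAs(f'attn_{ind}', ToyLayer()) for ind in range(3)])

cache = {}
x = torch.randn(1, 2, 8)
for layer in layers:
    x = layer(x, cache=cache)

print(sorted(cache.keys()))  # ['attn_0', 'attn_1', 'attn_2']
```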
0 commit comments