from dalle_pytorch.reversible import ReversibleSequence, SequentialSequence
from dalle_pytorch.attention import Attention, SparseAttention, SparseConvCausalAttention, SparseAxialCausalAttention
-from dalle_pytorch.cache import FixCacheKey

from rotary_embedding_torch import RotaryEmbedding, broadcat
from g_mlp_pytorch import gMLPBlock
@@ -36,6 +35,15 @@ def forward(self, x):
        maxes = x.amax(dim = self.dim, keepdim = True)
        return x / maxes

+class CachedAs(nn.Module):
+    def __init__(self, cache_key, fn):
+        super().__init__()
+        self.cache_key = cache_key
+        self.fn = fn
+
+    def forward(self, x, *, cache = None, **kwargs):
+        return self.fn(x, cache = cache, cache_key = self.cache_key, **kwargs)
+
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
@@ -200,7 +208,7 @@ def __init__(
                ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
                shared_ff_layers[ff_id] = ff

-            attn = FixCacheKey(f'attn_{ind}', attn)
+            attn = CachedAs(f'attn_{ind}', attn)

            if shift_tokens:
                attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
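Below is a minimal, self-contained usage sketch (not part of this commit) of the `CachedAs` wrapper introduced above: it binds a fixed `cache_key` to a wrapped module and forwards that key, together with a shared `cache` dict, on every call. The `ToyCachedLayer` here is a hypothetical stand-in for the `Attention` modules being wrapped in the real code.

import torch
from torch import nn

# Copied from the diff above: binds a fixed cache_key to a wrapped module.
class CachedAs(nn.Module):
    def __init__(self, cache_key, fn):
        super().__init__()
        self.cache_key = cache_key
        self.fn = fn

    def forward(self, x, *, cache = None, **kwargs):
        return self.fn(x, cache = cache, cache_key = self.cache_key, **kwargs)

# Hypothetical stand-in for an attention layer that uses the per-layer cache.
class ToyCachedLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, *, cache = None, cache_key = None, **kwargs):
        out = self.proj(x)
        if cache is not None and cache_key is not None:
            # accumulate this layer's outputs under its own key across calls
            prev = cache.get(cache_key)
            cache[cache_key] = out if prev is None else torch.cat((prev, out), dim = 1)
        return out

layer = CachedAs('attn_0', ToyCachedLayer(dim = 64))

cache = {}
out = layer(torch.randn(1, 4, 64), cache = cache)   # CachedAs injects cache_key = 'attn_0'
assert 'attn_0' in cache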