
Commit 2168826

offer way to turn off the linear attention
1 parent 8d0de9a commit 2168826

File tree

2 files changed: 13 additions, 7 deletions


rin_pytorch/rin_pytorch.py

Lines changed: 12 additions & 6 deletions
@@ -286,6 +286,7 @@ def __init__(
         latent_self_attn_depth,
         dim_latent = None,
         final_norm = True,
+        patches_self_attn = True,
         **attn_kwargs
     ):
         super().__init__()
@@ -304,8 +305,11 @@ def __init__(
         self.latent_final_norm = LayerNorm(dim_latent) if final_norm else nn.Identity()

         self.patches_peg = PEG(dim)
-        self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
-        self.patches_self_attn_ff = FeedForward(dim)
+        self.patches_self_attn = patches_self_attn
+
+        if patches_self_attn:
+            self.patches_self_attn = LinearAttention(dim, norm = True, **attn_kwargs)
+            self.patches_self_attn_ff = FeedForward(dim)

         self.patches_attend_to_latents = Attention(dim, dim_context = dim_latent, norm = True, norm_context = True, **attn_kwargs)
         self.patches_cross_attn_ff = FeedForward(dim)
@@ -325,10 +329,11 @@ def forward(self, patches, latents, t):
             latents = attn(latents, time = t) + latents
             latents = ff(latents, time = t) + latents

-        # additional patches self attention with linear attention
+        if self.patches_self_attn:
+            # additional patches self attention with linear attention

-        patches = self.patches_self_attn(patches, time = t) + patches
-        patches = self.patches_self_attn_ff(patches) + patches
+            patches = self.patches_self_attn(patches, time = t) + patches
+            patches = self.patches_self_attn_ff(patches) + patches

         # patches attend to the latents

@@ -353,6 +358,7 @@ def __init__(
         learned_sinusoidal_dim = 16,
         latent_token_time_cond = False, # whether to use 1 latent token as time conditioning, or do it the adaptive layernorm way (which is highly effective as shown by some other papers "Paella" - Dominic Rampas et al.)
         dual_patchnorm = True,
+        patches_self_attn = True,       # the self attention in this repository is not strictly with the design proposed in the paper. offer way to remove it, in case it is the source of instability
         **attn_kwargs
     ):
         super().__init__()
@@ -436,7 +442,7 @@ def __init__(
         if not latent_token_time_cond:
             attn_kwargs = {**attn_kwargs, 'time_cond_dim': time_dim}

-        self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, **attn_kwargs) for _ in range(depth)])
+        self.blocks = nn.ModuleList([RINBlock(dim, dim_latent = dim_latent, latent_self_attn_depth = latent_self_attn_depth, patches_self_attn = patches_self_attn, **attn_kwargs) for _ in range(depth)])

     @property
     def device(self):
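
For reference, a minimal usage sketch of the new flag (not part of this commit): passing patches_self_attn = False to RIN now skips the LinearAttention / FeedForward pair over patches in every RINBlock. Only dim, dim_latent, latent_self_attn_depth, depth and patches_self_attn appear in the diff above; the remaining constructor arguments are assumptions taken from the repository README and may differ.

# minimal sketch, not part of this commit: disable the extra linear
# attention over patches via the new flag. arguments marked "assumed"
# are not shown in this diff and may differ from the actual signature.
from rin_pytorch import RIN

model = RIN(
    dim = 256,
    image_size = 128,             # assumed
    patch_size = 8,               # assumed
    depth = 6,                    # number of RINBlocks; each now receives patches_self_attn
    num_latents = 256,            # assumed
    dim_latent = 512,
    latent_self_attn_depth = 4,
    patches_self_attn = False     # turn off the (non-paper) linear attention over patches
)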

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.7',
+  version = '0.7.8',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',

0 commit comments
