Commit e4e101f

add sandwich norm, from the CogView paper, for stabilizing training even more, hidden behind feature flag

1 parent 15d2f35 · commit e4e101f
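
For context: sandwich normalization (Sandwich-LN in the CogView paper) places a LayerNorm on both sides of each sublayer, so a residual block computes x + LN(f(LN(x))) rather than the usual pre-norm x + f(LN(x)). The extra output norm bounds the scale of each sublayer's contribution, which the paper reports stabilizes training of large text-to-image transformers.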

File tree

3 files changed: 11 additions, 5 deletions

dalle_pytorch/dalle_pytorch.py

Lines changed: 2 additions & 0 deletions

@@ -324,6 +324,7 @@ def __init__(
         attn_types = None,
         loss_img_weight = 7,
         stable = False,
+        sandwich_norm = False,
         shift_tokens = True,
         rotary_emb = True
     ):
@@ -371,6 +372,7 @@ def __init__(
             image_fmap_size = image_fmap_size,
             sparse_attn = sparse_attn,
             stable = stable,
+            sandwich_norm = sandwich_norm,
             shift_tokens = shift_tokens,
             rotary_emb = rotary_emb
         )
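
If the public constructor simply forwards this flag, enabling the feature is a matter of passing sandwich_norm = True to DALLE. A minimal sketch, assuming the usual dalle-pytorch setup; every argument besides sandwich_norm is illustrative, not prescriptive:

from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 8192,
    codebook_dim = 512
)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 12,
    heads = 16,
    stable = True,
    sandwich_norm = True   # the new feature flag from this commit
)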

dalle_pytorch/transformer.py

Lines changed: 8 additions & 4 deletions

@@ -56,13 +56,16 @@ def forward(self, x, **kwargs):
 # layer norm

 class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
+    def __init__(self, dim, fn, sandwich = False):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
+        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
         self.fn = fn

     def forward(self, x, **kwargs):
-        return self.fn(self.norm(x), **kwargs)
+        x = self.norm(x)
+        x = self.fn(x, **kwargs)
+        return self.norm_out(x)

 # feed forward

@@ -145,6 +148,7 @@ def __init__(
         image_fmap_size = None,
         sparse_attn = False,
         stable = False,
+        sandwich_norm = False,
         shift_tokens = False,
         rotary_emb = True
     ):
@@ -183,8 +187,8 @@ def __init__(
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))

             layers.append(nn.ModuleList([
-                LayerScale(dim, ind + 1, PreNorm(dim, attn)),
-                LayerScale(dim, ind + 1, PreNorm(dim, ff))
+                LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich = sandwich_norm)),
+                LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich = sandwich_norm))
             ]))

         execute_type = ReversibleSequence if reversible else SequentialSequence
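
For reference, here is the PreNorm module as it reads after this hunk, lifted into a standalone, runnable sketch; only the imports and the shape check at the bottom are additions, the module body matches the diff:

import torch
from torch import nn

class PreNorm(nn.Module):
    def __init__(self, dim, fn, sandwich = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # sandwich norm adds a second LayerNorm after the wrapped sublayer,
        # so the residual branch contributes LN(fn(LN(x))) instead of fn(LN(x))
        self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        x = self.fn(x, **kwargs)
        return self.norm_out(x)

# quick check: output shape is unchanged, and the flag defaults off
block = PreNorm(512, nn.Linear(512, 512), sandwich = True)
assert block(torch.randn(2, 16, 512)).shape == (2, 16, 512)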

setup.py

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@
   name = 'dalle-pytorch',
   packages = find_packages(),
   include_package_data = True,
-  version = '1.0.8',
+  version = '1.1.0',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
