
Commit 44775fc

Implement weight sharing in transformer
1 parent a34d5d9 commit 44775fc

File tree

1 file changed: +22 −6 lines


dalle_pytorch/transformer.py

Lines changed: 22 additions & 6 deletions
@@ -150,7 +150,9 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = False,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
     ):
         super().__init__()
         layers = nn.ModuleList([])
@@ -160,7 +162,13 @@ def __init__(
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)

-        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+        shared_attn_ids = cycle(default(shared_attn_ids, range(depth)))
+        shared_ff_ids = cycle(default(shared_ff_ids, range(depth)))
+        shared_attn_layers = {}
+        shared_ff_layers = {}
+
+        for (ind, sparse_attn, attn_type, attn_id, ff_id) in \
+            zip(range(depth), sparse_layer, attn_type_layer, shared_attn_ids, shared_ff_ids):
             if attn_type == 'full':
                 attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
@@ -176,12 +184,20 @@ def __init__(
             else:
                 raise ValueError(f'attention type "{attn_type}" is not valid')

-            if attn_type != 'mlp':
-                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+            attn = shared_attn_layers.get(attn_id)
+            if not exists(attn):
+                if attn_type != 'mlp':
+                    attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+                else:
+                    attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                shared_attn_layers[attn_id] = attn
             else:
-                attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                assert isinstance(attn, attn_class), 'attn_types do not match shared_attn_ids'

-            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+            ff = shared_ff_layers.get(ff_id)
+            if not exists(ff):
+                ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+                shared_ff_layers[ff_id] = ff

             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
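
With this change, layers that receive the same id in shared_attn_ids (or shared_ff_ids) reuse a single attention (or feedforward) module: the first layer built for an id is stored in shared_attn_layers / shared_ff_layers and returned for every later layer with that id. Leaving both arguments as None falls back to range(depth), so each layer keeps its own weights as before. A minimal usage sketch follows; the dimensions and id tuples are hypothetical values chosen for illustration, and all other constructor arguments are assumed to keep their existing defaults:

from dalle_pytorch.transformer import Transformer

# Hypothetical 4-layer transformer with weight sharing:
# layers 0 and 2 reuse one attention module, layers 1 and 3 another,
# while all four layers share a single feedforward module.
transformer = Transformer(
    dim = 512,
    depth = 4,
    seq_len = 1024,
    shared_attn_ids = (0, 1, 0, 1),
    shared_ff_ids = (0, 0, 0, 0)
)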
