@@ -150,7 +150,9 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = False,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
     ):
         super().__init__()
         layers = nn.ModuleList([])
@@ -160,7 +162,13 @@ def __init__(
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)

-        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+        shared_attn_ids = cycle(default(shared_attn_ids, range(depth)))
+        shared_ff_ids = cycle(default(shared_ff_ids, range(depth)))
+        shared_attn_layers = {}
+        shared_ff_layers = {}
+
+        for (ind, sparse_attn, attn_type, attn_id, ff_id) in \
+            zip(range(depth), sparse_layer, attn_type_layer, shared_attn_ids, shared_ff_ids):
             if attn_type == 'full':
                 attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
@@ -176,12 +184,20 @@ def __init__(
             else:
                 raise ValueError(f'attention type "{attn_type}" is not valid')

-            if attn_type != 'mlp':
-                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+            attn = shared_attn_layers.get(attn_id)
+            if not exists(attn):
+                if attn_type != 'mlp':
+                    attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+                else:
+                    attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                shared_attn_layers[attn_id] = attn
             else:
-                attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                assert isinstance(attn, attn_class), 'attn_types do not match shared_attn_ids'

-            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+            ff = shared_ff_layers.get(ff_id)
+            if not exists(ff):
+                ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+                shared_ff_layers[ff_id] = ff

             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
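A minimal, self-contained sketch (not part of the commit) of the id-based sharing pattern introduced above: layer slots whose id repeats reuse one module instance, so their parameters are tied. `nn.Linear` stands in for the real Attention/FeedForward blocks, and the ids used here are hypothetical.

# Sketch of the id -> module sharing pattern; assumes torch is installed.
from itertools import cycle
from torch import nn

depth = 4
shared_ff_ids = cycle([0, 0, 1, 1])   # hypothetical ids: layers 0/1 share one block, layers 2/3 another
shared_ff_layers = {}                 # id -> module cache, mirroring `shared_ff_layers` in the diff

layers = nn.ModuleList([])
for ind, ff_id in zip(range(depth), shared_ff_ids):
    ff = shared_ff_layers.get(ff_id)
    if ff is None:                    # stands in for `not exists(ff)`
        ff = nn.Linear(64, 64)        # stands in for FeedForward(dim, ...)
        shared_ff_layers[ff_id] = ff
    layers.append(ff)

assert layers[0] is layers[1] and layers[2] is layers[3]   # same instances, shared weights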