1
+ from collections .abc import Iterable
1
2
from functools import partial
2
3
from itertools import islice , cycle
3
4
@@ -21,9 +22,7 @@ def default(val, d):
21
22
return val if exists (val ) else d
22
23
23
24
def cast_tuple (val , depth = 1 ):
24
- if isinstance (val , list ):
25
- val = tuple (val )
26
- return val if isinstance (val , tuple ) else (val ,) * depth
25
+ return val if isinstance (val , Iterable ) else (val ,) * depth
27
26
28
27
# classes
29
28
@@ -184,15 +183,16 @@ def __init__(
184
183
else :
185
184
raise ValueError (f'attention type "{ attn_type } " is not valid' )
186
185
187
- attn = shared_attn_layers .get (attn_id )
186
+ attn , reused_attn_type = shared_attn_layers .get (attn_id , ( None , None ) )
188
187
if not exists (attn ):
189
188
if attn_type != 'mlp' :
190
189
attn = attn_class (dim , causal = causal , seq_len = seq_len , heads = heads , dim_head = dim_head , dropout = attn_dropout )
191
190
else :
192
191
attn = attn_class (dim = dim , causal = causal , dim_ff = dim * 4 )
193
- shared_attn_layers [attn_id ] = attn
194
- else :
195
- assert isinstance (attn , attn_class ), 'attn_types do not match shared_attn_ids'
192
+ shared_attn_layers [attn_id ] = (attn , attn_type )
193
+ elif attn_type != reused_attn_type :
194
+ raise ValueError ('attn_types do not match shared_attn_ids '
195
+ f'(ind = { ind } , attn_type = "{ attn_type } ", reused_attn_type = "{ reused_attn_type } ")' )
196
196
197
197
ff = shared_ff_layers .get (ff_id )
198
198
if not exists (ff ):
0 commit comments