Skip to content

Commit d7c034e

Browse files
committed
Improve checking reused attn type
1 parent e10096e commit d7c034e

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

dalle_pytorch/transformer.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterable
12
from functools import partial
23
from itertools import islice, cycle
34

@@ -21,9 +22,7 @@ def default(val, d):
2122
return val if exists(val) else d
2223

2324
def cast_tuple(val, depth = 1):
24-
if isinstance(val, list):
25-
val = tuple(val)
26-
return val if isinstance(val, tuple) else (val,) * depth
25+
return val if isinstance(val, Iterable) else (val,) * depth
2726

2827
# classes
2928

@@ -184,15 +183,16 @@ def __init__(
184183
else:
185184
raise ValueError(f'attention type "{attn_type}" is not valid')
186185

187-
attn = shared_attn_layers.get(attn_id)
186+
attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
188187
if not exists(attn):
189188
if attn_type != 'mlp':
190189
attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
191190
else:
192191
attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
193-
shared_attn_layers[attn_id] = attn
194-
else:
195-
assert isinstance(attn, attn_class), 'attn_types do not match shared_attn_ids'
192+
shared_attn_layers[attn_id] = (attn, attn_type)
193+
elif attn_type != reused_attn_type:
194+
raise ValueError('attn_types do not match shared_attn_ids '
195+
f'(ind = {ind}, attn_type = "{attn_type}", reused_attn_type = "{reused_attn_type}")')
196196

197197
ff = shared_ff_layers.get(ff_id)
198198
if not exists(ff):

0 commit comments

Comments (0)