Skip to content

Commit e6625cd

Browse files
committed
fix small issue loading models
1 parent b5b8220 commit e6625cd

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

backends/model_converter/convert_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
for k in torch_weights['state_dict']:
3636
if k not in SD_SHAPES and k not in extra_keys:
3737
continue
38+
if 'model_ema' in k:
39+
continue
3840
np_arr = torch_weights['state_dict'][k]
3941
key_bytes = np_arr.tobytes()
4042
shape = list(np_arr.shape)
@@ -44,7 +46,7 @@
4446
if dtype == 'int64':
4547
np_arr = np_arr.astype('float32')
4648
dtype = 'float32'
47-
assert dtype in ['float16' , 'float32']
49+
assert dtype in ['float16' , 'float32'] , (dtype, k)
4850
e = s + len(key_bytes)
4951
out_file.write(key_bytes)
5052
keys_info[k] = {"start": s , "end" : e , "shape": shape , "dtype" : dtype }

0 commit comments

Comments
 (0)