@@ -29,7 +29,7 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
@@ -1181,13 +1181,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             HiDreamImagePipeline.save_lora_weights(
                 output_dir,
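Note on the `if weights:` guard: when the run is launched under DeepSpeed, Accelerate manages the model/optimizer shards itself and the `weights` list handed to the save hook can arrive empty, so an unconditional `weights.pop()` would raise an IndexError. For context, a minimal sketch of how such hooks plug into Accelerate's checkpointing; the registration and `save_state`/`load_state` calls are assumed to live elsewhere in the script, outside the hunks shown here:

    # Hedged sketch: wiring custom hooks into Accelerate's state checkpointing.
    # `save_model_hook` / `load_model_hook` are the callbacks shown in this diff;
    # `accelerator`, `save_path`, and `input_dir` are assumed names.
    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    accelerator.save_state(save_path)   # invokes save_model_hook(models, weights, save_path)
    accelerator.load_state(input_dir)   # invokes load_model_hook(models, input_dir)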
@@ -1197,13 +1199,20 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    transformer_ = model
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)

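In the DeepSpeed branch the load hook does not rely on the `models` list at all: it rebuilds the base transformer from the pretrained checkpoint and re-attaches the LoRA adapter so the saved LoRA state dict can be loaded into it. A small, standalone sketch of the `DistributedType` check itself (the `print` calls are illustrative only):

    # Hedged sketch: how the DistributedType branch behaves at runtime.
    from accelerate import Accelerator, DistributedType

    accelerator = Accelerator()
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        # Reached when launched via `accelerate launch` with a DeepSpeed config.
        print("DeepSpeed run: rebuild the base model inside load_model_hook")
    else:
        print("Non-DeepSpeed run: unwrap the model objects handed to the hook")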
@@ -1655,7 +1664,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
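The widened condition exists because ZeRO-style sharding spreads the optimizer/model state across ranks, so checkpointing cannot be restricted to the main process; every rank has to enter the block when DeepSpeed is active. A minimal sketch of the gate as a helper (the checkpoint path and `save_state` call follow the script's `checkpoint-<step>` convention but sit outside the hunk shown, so they are assumptions):

    # Hedged sketch of the DeepSpeed-aware checkpoint gate.
    import os
    from accelerate import Accelerator, DistributedType

    def should_checkpoint(accelerator: Accelerator, global_step: int, checkpointing_steps: int) -> bool:
        # The main process always checkpoints; under DeepSpeed every rank must join in,
        # because the optimizer/model state is sharded across ranks.
        rank_ok = (
            accelerator.is_main_process
            or accelerator.distributed_type == DistributedType.DEEPSPEED
        )
        return rank_ok and global_step % checkpointing_steps == 0

    # Usage inside the training loop (assumed names):
    # if should_checkpoint(accelerator, global_step, args.checkpointing_steps):
    #     accelerator.save_state(os.path.join(args.output_dir, f"checkpoint-{global_step}"))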