
Commit d72184e

leisuzz, J石页, sayakpaul, and github-actions[bot] authored
[training] add ds support to lora hidream (#11737)
* [training] add ds support to lora hidream

* Apply style fixes

---------

Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 5ce4814 commit d72184e
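
The commit makes the HiDream LoRA DreamBooth script usable under DeepSpeed. The script's `Accelerator` normally picks DeepSpeed up from the environment (e.g. `accelerate config`), but it can also be enabled programmatically; a minimal sketch, assuming ZeRO stage 2, with plugin arguments that are illustrative rather than part of this commit:

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Illustrative: opt in to DeepSpeed without an `accelerate config` file.
# ZeRO stage 2 shards optimizer state and gradients across ranks.
ds_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```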

File tree

1 file changed: +19 -10 lines changed


examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 19 additions & 10 deletions
@@ -29,7 +29,7 @@
 import numpy as np
 import torch
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
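
The only import change is `DistributedType`; the hooks below branch on it. A minimal sketch of the check in isolation, assuming an `Accelerator` already configured for DeepSpeed:

```python
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()

# accelerate reports the active backend on the accelerator itself, so the
# training code can special-case DeepSpeed without importing deepspeed.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    print("DeepSpeed is active; save/load hooks take the DS code path")
```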
@@ -1181,13 +1181,15 @@ def save_model_hook(models, weights, output_dir):
             transformer_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             HiDreamImagePipeline.save_lora_weights(
                 output_dir,
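
Two fixes here: the hook unwraps `model` before the `isinstance` check, since under DeepSpeed accelerate passes in the engine wrapper rather than the bare `HiDreamImageTransformer2DModel`, and `weights.pop()` is guarded because DeepSpeed manages its own checkpoint state and can hand the hook an empty `weights` list. For context, these hooks take effect through accelerate's hook registration, which the diff leaves unchanged elsewhere in the script; a sketch of that wiring:

```python
# The hooks fire inside accelerator.save_state() / accelerator.load_state(),
# which is where the DeepSpeed-aware branches take effect.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```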
@@ -1197,13 +1199,20 @@ def save_model_hook(models, weights, output_dir):
     def load_model_hook(models, input_dir):
         transformer_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    transformer_ = model
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = HiDreamImageTransformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
 
         lora_state_dict = HiDreamImagePipeline.lora_state_dict(input_dir)
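
Under DeepSpeed the engine restores the wrapped module itself, so accelerate gives the load hook nothing to pop from `models`; the new branch instead rebuilds the base transformer and re-attaches the LoRA adapter so the `lora_state_dict` loaded next finds matching modules. A sketch of how the hook gets exercised on resume, with an illustrative checkpoint name:

```python
import os

# Resuming triggers load_model_hook via accelerator.load_state(); under
# DeepSpeed the engine reloads its own shards, and the hook only restores
# the LoRA layers onto a freshly built transformer.
accelerator.load_state(os.path.join(args.output_dir, "checkpoint-500"))
```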

@@ -1655,7 +1664,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
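
The widened condition matters because ZeRO shards optimizer (and, at stage 3, parameter) state across ranks, so `accelerator.save_state` must run on every process rather than only rank 0. A condensed sketch of the block this hunk touches, with the `checkpoints_total_limit` rotation elided:

```python
import os

if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
    if global_step % args.checkpointing_steps == 0:
        # Under DeepSpeed every rank writes its own shard inside save_state();
        # gating on is_main_process alone would drop the other ranks' shards.
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)
```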

0 commit comments