Commit 0dec414

[train_dreambooth_lora_sdxl.py] Fix the LR Schedulers when num_train_epochs is passed in a distributed training env (#11240)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent 44eeba0 commit 0dec414

File tree

1 file changed (+19 -7 lines changed)

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 19 additions & 7 deletions
@@ -1523,17 +1523,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1550,7 +1555,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
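
For context on the change, here is a minimal standalone sketch (not part of the commit) of the step arithmetic the new code performs when --max_train_steps is not passed. All concrete values below (process count, dataloader length, epochs, warmup) are hypothetical and chosen only to make the numbers easy to follow; the full rationale is in the PR linked from the diff (https://github.com/huggingface/diffusers/pull/8312).

import math

# Hypothetical values, for illustration only.
num_processes = 4                # processes/GPUs in the distributed run
len_train_dataloader = 1000      # batches before accelerator.prepare shards the dataloader
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 500

# Estimate the per-process dataloader length before accelerator.prepare is called,
# then derive the number of optimizer updates per epoch from that estimate.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)  # 250
num_update_steps_per_epoch = math.ceil(
    len_train_dataloader_after_sharding / gradient_accumulation_steps
)  # 125

# The script scales both the warmup and the total step budget by the number of
# processes (see the linked PR for why the prepared scheduler needs this).
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes  # 2000
num_training_steps_for_scheduler = (
    num_train_epochs * num_processes * num_update_steps_per_epoch
)  # 1500

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)  # 2000 1500

With these hypothetical numbers, the old code would have sized the scheduler from the unsharded dataloader (ceil(1000 / 2) = 500 updates per epoch, i.e. 3 * 500 * 4 = 6000 scheduler steps) even though training only advances it about 3 * 125 * 4 = 1500 times, so the learning rate schedule would never run to completion; estimating the post-sharding dataloader length up front is what this commit adds.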
