@@ -1523,17 +1523,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )
@@ -1550,7 +1555,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
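
A minimal standalone sketch of the step-count arithmetic introduced by this diff, assuming accelerate steps a prepared scheduler once per process per optimizer step; the concrete numbers and names (dataset_len, batch_size, etc.) below are placeholders, not values from the training script.

# Illustrative sketch only: reproduces the scheduler step-count arithmetic from the diff
# with hypothetical placeholder numbers; not part of the patch.
import math

num_processes = 2                     # accelerator.num_processes
dataset_len, batch_size = 1000, 4     # placeholder dataset/batch sizes
gradient_accumulation_steps = 2
num_train_epochs = 3
lr_warmup_steps = 500

# After accelerator.prepare, each process sees roughly 1/num_processes of the batches.
len_train_dataloader_after_sharding = math.ceil(math.ceil(dataset_len / batch_size) / num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)

# The prepared scheduler is advanced once per process for each optimizer step, so both
# the warmup and total horizons are scaled by num_processes to stay aligned with training.
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes
num_training_steps_for_scheduler = num_train_epochs * num_processes * num_update_steps_per_epoch

print(num_warmup_steps_for_scheduler, num_training_steps_for_scheduler)  # 1000 378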