
Commit 44eeba0

[Flux LoRAs] fix lr scheduler bug in distributed scenarios (#11242)
* add fix
* add fix
* Apply style fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent: 5873377

File tree: 3 files changed, +54 additions, -22 deletions

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
Lines changed: 18 additions & 8 deletions

@@ -1915,17 +1915,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         free_memory()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )

@@ -1949,7 +1954,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                lr_scheduler,
            )
        else:
-            print("I SHOULD BE HERE")
            transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
            )

@@ -1961,8 +1965,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

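The first hunk above replaces the old overrode_max_train_steps bookkeeping with scheduler step counts derived from the sharded dataloader length. Below is a minimal sketch of that arithmetic, not taken from the commit, using assumed values (1,000 batches in the un-sharded dataloader, 4 processes, 2 gradient-accumulation steps, 10 epochs, 500 warmup steps) and no Accelerate dependency:

import math

# Assumed example values; the real scripts read these from args and accelerator.
len_train_dataloader = 1000        # batches before sharding
num_processes = 4                  # accelerator.num_processes
gradient_accumulation_steps = 2
num_train_epochs = 10
lr_warmup_steps = 500
max_train_steps = None             # i.e. --max_train_steps was not passed

# Each process only sees its shard of the dataloader after accelerator.prepare.
len_train_dataloader_after_sharding = math.ceil(len_train_dataloader / num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)

# Warmup and total steps handed to get_scheduler are scaled by num_processes,
# exactly as the diff above does (see PR #8312 for the rationale).
num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes
if max_train_steps is None:
    num_training_steps_for_scheduler = num_train_epochs * num_processes * num_update_steps_per_epoch
else:
    num_training_steps_for_scheduler = max_train_steps * num_processes

print(num_update_steps_per_epoch)         # 125 optimizer steps per epoch, per process
print(num_warmup_steps_for_scheduler)     # 2000
print(num_training_steps_for_scheduler)   # 5000 (125 * 10 * 4)

The key change is that the per-epoch update count is computed from the post-sharding dataloader length, so the scheduler's total no longer shrinks as more processes are added.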
examples/dreambooth/train_dreambooth_flux.py
Lines changed: 18 additions & 7 deletions

@@ -1407,17 +1407,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )

@@ -1444,8 +1449,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

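The second hunk in each file adds a sanity check after accelerator.prepare: max_train_steps is recomputed from the prepared dataloader, and a warning is logged if the scheduler was built with a different expectation. Below is a standalone sketch of that check; the function name, signature, and the example numbers are hypothetical, but the recomputation and warning mirror the diff:

import logging
import math
from typing import Optional

logger = logging.getLogger(__name__)


def recheck_max_train_steps(
    len_prepared_dataloader: int,              # len(train_dataloader) after accelerator.prepare
    len_train_dataloader_after_sharding: int,  # pre-prepare estimate used to build the scheduler
    gradient_accumulation_steps: int,
    num_train_epochs: int,
    num_training_steps_for_scheduler: int,
    max_train_steps: Optional[int],
) -> int:
    """Recompute max_train_steps from the prepared dataloader and warn on mismatch
    (hypothetical helper mirroring the hunk above)."""
    num_update_steps_per_epoch = math.ceil(len_prepared_dataloader / gradient_accumulation_steps)
    if max_train_steps is None:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
        if num_training_steps_for_scheduler != max_train_steps:
            logger.warning(
                "The length of the 'train_dataloader' after 'accelerator.prepare' (%d) does not match "
                "the expected length (%d) when the learning rate scheduler was created. "
                "This inconsistency may result in the learning rate scheduler not functioning properly.",
                len_prepared_dataloader,
                len_train_dataloader_after_sharding,
            )
    return max_train_steps


# Assumed single-process example: the scheduler was built for 250 batches
# (125 updates/epoch, 10 epochs -> 1250 steps), but only 240 batches survive
# prepare (120 updates/epoch -> 1200 steps), so a warning is logged.
print(recheck_max_train_steps(240, 250, 2, 10, 1250, None))  # 1200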
examples/dreambooth/train_dreambooth_lora_flux.py
Lines changed: 18 additions & 7 deletions

@@ -1524,17 +1524,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         free_memory()
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
         num_cycles=args.lr_num_cycles,
         power=args.lr_power,
     )

@@ -1561,8 +1566,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

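For reference, a minimal usage sketch of diffusers' get_scheduler with precomputed step counts, as all three scripts now call it. The linear layer, optimizer, scheduler name, and numeric values are placeholders (reusing the assumed numbers from the first sketch), not part of this commit:

import torch
from diffusers.optimization import get_scheduler

# Toy stand-ins for the Flux transformer LoRA parameters and their optimizer.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

num_warmup_steps_for_scheduler = 500 * 4    # lr_warmup_steps * num_processes (assumed)
num_training_steps_for_scheduler = 5000     # epochs * num_processes * updates per epoch (assumed)

lr_scheduler = get_scheduler(
    "cosine_with_restarts",                 # the scripts pass args.lr_scheduler here
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps_for_scheduler,
    num_training_steps=num_training_steps_for_scheduler,
    num_cycles=1,                           # args.lr_num_cycles in the scripts
    power=1.0,                              # args.lr_power in the scripts
)

# One scheduler step per optimizer step, as in the training loops of these scripts.
for _ in range(3):
    optimizer.step()
    lr_scheduler.step()
    print(lr_scheduler.get_last_lr())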