35 changes: 27 additions & 8 deletions fine_tune.py
@@ -272,18 +272,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# the accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -350,6 +363,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
m.train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(*training_models):
with torch.no_grad():
@@ -425,9 +440,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
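The hunks above implement the usage contract of the schedulefree package: a schedule-free optimizer replaces the external LR schedule, so lr_scheduler is neither passed to accelerator.prepare() nor stepped, and the optimizer must be put into train mode before each update and back into eval mode afterwards. A minimal standalone sketch of that contract, assuming only torch and schedulefree are installed; the model and data here are placeholders, not part of this PR:

import torch
import schedulefree

model = torch.nn.Linear(8, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for _ in range(100):
    optimizer.train()  # schedule-free optimizers must be in train mode before step()
    x, y = torch.randn(4, 8), torch.randn(4, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # no lr_scheduler.step(): the schedule lives inside the optimizer
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()  # eval mode before any sampling, validation, or checkpointing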
27 changes: 21 additions & 6 deletions flux_train.py
@@ -416,9 +416,14 @@ def train(args):
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
@@ -427,7 +432,10 @@ def train(args):
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -643,6 +651,8 @@ def optimizer_hook(parameter: torch.Tensor):
m.train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step

if args.blockwise_fused_optimizers:
@@ -746,15 +756,20 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if not args.optimizer_schedulefree_wrapper:
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
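The guard args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper is repeated in every training script touched by this PR. A sketch of a helper that would express it once; the function is hypothetical and does not exist in the PR:

def is_schedule_free(args) -> bool:
    # True when the optimizer manages its own schedule, either natively
    # (AdamWScheduleFree / SGDScheduleFree) or via --optimizer_schedulefree_wrapper
    return args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper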
49 changes: 47 additions & 2 deletions library/train_util.py
@@ -3303,6 +3303,20 @@ def int_or_float(value):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
)

parser.add_argument(
"--optimizer_schedulefree_wrapper",
action="store_true",
help="use schedulefree_wrapper any optimizer / 任意のオプティマイザにschedulefree_wrapperを使用",
)

parser.add_argument(
"--schedulefree_wrapper_args",
type=str,
default=None,
nargs="*",
help='additional arguments for schedulefree_wrapper (like "momentum=0.9 weight_decay_at_y=0.1 ...") / schedulefree_wrapperの追加引数(例: "momentum=0.9 weight_decay_at_y=0.1 ...")',
)

parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module / 使用するスケジューラ")
parser.add_argument(
"--lr_scheduler_args",
@@ -4361,6 +4375,8 @@ def get_optimizer(args, trainable_params):
optimizer_kwargs[key] = value
# logger.info(f"optkwargs {optimizer}_{kwargs}")

schedulefree_wrapper_kwargs = {}

lr = args.learning_rate
optimizer = None

@@ -4581,20 +4597,49 @@ def get_optimizer(args, trainable_params):
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

if optimizer is None:
# use an arbitrary optimizer
optimizer_type = args.optimizer_type # the non-lowercased name (a bit awkward)
logger.info(f"use {optimizer_type} | {optimizer_kwargs}")
if "." not in optimizer_type:
optimizer_module = torch.optim
optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
if args.optimizer_schedulefree_wrapper and not optimizer_type.endswith("schedulefree"):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")

if args.schedulefree_wrapper_args is not None and len(args.schedulefree_wrapper_args) > 0:
for arg in args.schedulefree_wrapper_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
schedulefree_wrapper_kwargs[key] = value
optimizer = sf.ScheduleFreeWrapper(optimizer, **schedulefree_wrapper_kwargs)
else:
values = optimizer_type.split(".")
optimizer_module = importlib.import_module(".".join(values[:-1]))
optimizer_type = values[-1]

optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
optimizer_class = getattr(optimizer_module, optimizer_type)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
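To summarize the two paths added to get_optimizer above: --optimizer_type AdamWScheduleFree or SGDScheduleFree instantiates the native schedulefree classes, while --optimizer_schedulefree_wrapper wraps any other optimizer in sf.ScheduleFreeWrapper, with keyword arguments parsed from --schedulefree_wrapper_args using ast.literal_eval. A condensed, hedged sketch of the wrapper path; the argument values are illustrative, taken from the help text, not defaults asserted by the PR:

import ast

import torch
import schedulefree as sf

params = [torch.nn.Parameter(torch.randn(4, 4))]
base_optimizer = torch.optim.SGD(params, lr=1e-2)

# mirrors how --schedulefree_wrapper_args "momentum=0.9" "weight_decay_at_y=0.1" is parsed
wrapper_kwargs = {}
for arg in ["momentum=0.9", "weight_decay_at_y=0.1"]:
    key, value = arg.split("=")
    wrapper_kwargs[key] = ast.literal_eval(value)

optimizer = sf.ScheduleFreeWrapper(base_optimizer, **wrapper_kwargs)
optimizer.train()  # required before step(), exactly as in the training loops above
(params[0] ** 2).sum().backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
optimizer.eval()  # required before evaluating or saving the wrapped parameters

A command line exercising this path would look roughly like --optimizer_type AdamW --optimizer_schedulefree_wrapper --schedulefree_wrapper_args "momentum=0.9" "weight_decay_at_y=0.1"; the flag spellings come from the parser additions above, and the rest of the invocation is unchanged.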
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.43.3
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.7
tensorboard
safetensors==0.4.4
# gradio==3.16.2
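The pinned schedulefree==1.2.7 is the dependency behind the try/except ImportError guards added in library/train_util.py. A quick environment check, not part of the PR, using only the standard library:

from importlib.metadata import version

assert version("schedulefree") == "1.2.7"  # matches the pin in requirements.txt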
27 changes: 21 additions & 6 deletions sdxl_train.py
@@ -484,9 +484,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
@@ -497,7 +502,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

# move to CPU when caching the TextEncoder outputs
if args.cache_text_encoder_outputs:
@@ -630,6 +638,8 @@ def optimizer_hook(parameter: torch.Tensor):
m.train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step

if args.fused_optimizer_groups:
@@ -749,15 +759,20 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if not args.optimizer_schedulefree_wrapper:
lr_scheduler.step()
if args.fused_optimizer_groups:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
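For context on why lr_scheduler.step() is skipped in these branches: schedule-free methods run at a constant base step size and replace the decaying schedule with interpolation and averaging. Roughly, following the schedule-free SGD formulation of Defazio et al. ("The Road Less Scheduled"), with constant step size \gamma and interpolation constant \beta:

y_t = (1 - \beta)\, z_t + \beta\, x_t
z_{t+1} = z_t - \gamma\, \nabla f(y_t)
x_{t+1} = \left(1 - \tfrac{1}{t+1}\right) x_t + \tfrac{1}{t+1}\, z_{t+1}

Gradients are taken at y_t, which is what the parameters hold in train mode, while the averaged iterate x_t is the point the package exposes in eval mode for sampling and checkpointing; hence the optimizer.train() / optimizer.eval() toggles around each step and the absence of any scheduler step.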
18 changes: 16 additions & 2 deletions sdxl_train_control_net_lllite.py
@@ -307,14 +307,22 @@ def train(args):
unet.to(weight_dtype)

# the accelerator apparently takes care of this for us
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter

if args.gradient_checkpointing:
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
unet.train() # according to TI example in Diffusers, train is required -> since this now uses the original U-Net, this could actually be removed

else:
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()
unet.eval()

# move to CPU when caching the TextEncoder outputs
@@ -416,6 +424,8 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
@@ -510,9 +520,13 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
23 changes: 19 additions & 4 deletions sdxl_train_control_net_lllite_old.py
@@ -254,15 +254,24 @@ def train(args):
network.to(weight_dtype)

# the accelerator apparently takes care of this for us
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
unet, network, optimizer, train_dataloader = accelerator.prepare(
unet, network, optimizer, train_dataloader
)
else:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> since this now uses the original U-Net, this could actually be removed
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
else:
unet.eval()
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

network.prepare_grad_etc()

@@ -357,6 +366,8 @@ def remove_model(old_ckpt_name):
network.on_epoch_start() # train()

for step, batch in enumerate(train_dataloader):
if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
@@ -449,9 +460,13 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not (args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if args.optimizer_type.lower().endswith("schedulefree") or args.optimizer_schedulefree_wrapper:
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)