
Add comprehensive input validation for training hyperparameters in ControlNet Flux training script #11980


Open
wants to merge 1 commit into main
100 changes: 100 additions & 0 deletions examples/controlnet/train_controlnet_flux.py
@@ -686,6 +686,106 @@ def parse_args(input_args=None):
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
)

    # Additional comprehensive parameter validation
    if args.learning_rate <= 0:
        raise ValueError("`--learning_rate` must be positive")

    if args.train_batch_size <= 0:
        raise ValueError("`--train_batch_size` must be positive")

    if args.num_train_epochs <= 0:
        raise ValueError("`--num_train_epochs` must be positive")

    if args.gradient_accumulation_steps <= 0:
        raise ValueError("`--gradient_accumulation_steps` must be positive")

    if args.max_train_steps is not None and args.max_train_steps <= 0:
        raise ValueError("`--max_train_steps` must be positive when specified")

    if args.checkpointing_steps <= 0:
        raise ValueError("`--checkpointing_steps` must be positive")

    if args.validation_steps <= 0:
        raise ValueError("`--validation_steps` must be positive")

    if args.num_validation_images <= 0:
        raise ValueError("`--num_validation_images` must be positive")

    if args.lr_warmup_steps < 0:
        raise ValueError("`--lr_warmup_steps` must be non-negative")

    if args.lr_num_cycles <= 0:
        raise ValueError("`--lr_num_cycles` must be positive")

    if args.lr_power <= 0:
        raise ValueError("`--lr_power` must be positive")

    if args.dataloader_num_workers < 0:
        raise ValueError("`--dataloader_num_workers` must be non-negative")

    if not (0.0 <= args.adam_beta1 < 1.0):
        raise ValueError("`--adam_beta1` must be in the range [0.0, 1.0)")

    if not (0.0 <= args.adam_beta2 < 1.0):
        raise ValueError("`--adam_beta2` must be in the range [0.0, 1.0)")

    if args.adam_weight_decay < 0:
        raise ValueError("`--adam_weight_decay` must be non-negative")

    if args.adam_epsilon <= 0:
        raise ValueError("`--adam_epsilon` must be positive")

    if args.max_grad_norm <= 0:
        raise ValueError("`--max_grad_norm` must be positive")

    if args.max_train_samples is not None and args.max_train_samples <= 0:
        raise ValueError("`--max_train_samples` must be positive when specified")

    if args.num_double_layers <= 0:
        raise ValueError("`--num_double_layers` must be positive")

    if args.num_single_layers <= 0:
        raise ValueError("`--num_single_layers` must be positive")

    if args.guidance_scale < 0:
        raise ValueError("`--guidance_scale` must be non-negative")

    if args.logit_std <= 0:
        raise ValueError("`--logit_std` must be positive")

    if args.mode_scale <= 0:
        raise ValueError("`--mode_scale` must be positive")

    if args.checkpoints_total_limit is not None and args.checkpoints_total_limit <= 0:
        raise ValueError("`--checkpoints_total_limit` must be positive when specified")

    # Validate resolution is reasonable (not too small or excessively large)
    if args.resolution < 64:
        raise ValueError("`--resolution` must be at least 64 pixels")

    if args.resolution > 4096:
        raise ValueError("`--resolution` should not exceed 4096 pixels for memory efficiency")

    # Validate crop coordinates are non-negative
    if args.crops_coords_top_left_h < 0:
        raise ValueError("`--crops_coords_top_left_h` must be non-negative")

    if args.crops_coords_top_left_w < 0:
        raise ValueError("`--crops_coords_top_left_w` must be non-negative")

    # Warn about potentially problematic combinations
    if args.gradient_accumulation_steps > 1 and args.train_batch_size > 32:
        logger.warning(
            f"Per-device batch size ({args.train_batch_size}) is large and, combined with gradient accumulation "
            f"({args.gradient_accumulation_steps} steps), results in a very large effective batch size. "
            "This may cause out-of-memory errors; consider reducing `--train_batch_size`."
        )

    if args.learning_rate > 1e-2:
        logger.warning(
            f"Learning rate ({args.learning_rate}) is quite high and may cause training instability. "
            "Consider using a lower learning rate (e.g., 1e-4 to 1e-5)."
        )

    return args


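A minimal way to exercise the new checks locally (a sketch, not part of this PR's diff): it assumes the script is importable as a module from examples/controlnet/ and that the placeholder `--pretrained_model_name_or_path` and `--dataset_name` values below satisfy the script's existing required flags. Each invalid value should fail fast inside parse_args, before any model or data is downloaded.

import pytest

from train_controlnet_flux import parse_args

# Placeholder values only; parse_args never downloads them.
BASE_ARGS = [
    "--pretrained_model_name_or_path", "black-forest-labs/FLUX.1-dev",
    "--dataset_name", "fusing/fill50k",
    "--output_dir", "flux-controlnet-out",
]

def test_rejects_non_positive_learning_rate():
    # New check: the learning rate must be strictly positive.
    with pytest.raises(ValueError, match="learning_rate"):
        parse_args(BASE_ARGS + ["--learning_rate", "0"])

def test_rejects_negative_warmup_steps():
    # New check: warmup steps must be non-negative.
    with pytest.raises(ValueError, match="lr_warmup_steps"):
        parse_args(BASE_ARGS + ["--lr_warmup_steps", "-1"])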