
Commit d8e9411

Merge pull request #15 from neph1/update-v0.9.3
add more config options
2 parents: e1cd838 + 69418bb

File tree

4 files changed: +11 -2 lines changed


config/config_categories.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-Dataset: data_root, video_column, caption_column, id_token, video_resolution_buckets, caption_dropout_p
+Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p
 Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
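For illustration only, a minimal Python sketch (not taken from the finetrainers-ui source) of how a categories file in this format could be turned into per-category lists of option names. The comma-splitting and the use of PyYAML are assumptions based purely on the file contents shown above.

# Sketch: group option names by category from config/config_categories.yaml.
# Assumes each value is a single comma-separated string, as in the diff above.
import yaml

with open("config/config_categories.yaml") as f:
    raw = yaml.safe_load(f)  # e.g. {"Dataset": "data_root, video_column, ..."}

categories = {
    name: [opt.strip() for opt in opts.split(",")]
    for name, opts in raw.items()
}

print(categories["Dataset"])
# ['data_root', 'video_column', 'caption_column', 'dataset_file', 'id_token',
#  'image_resolution_buckets', 'video_resolution_buckets', 'caption_dropout_p']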

config/config_template.yaml

Lines changed: 3 additions & 0 deletions
@@ -10,14 +10,17 @@ checkpointing_limit: 102
 checkpointing_steps: 500
 data_root: ''
 dataloader_num_workers: 0
+dataset_file: ''
 diffusion_options: ''
+enable_model_cpu_offload: false
 enable_slicing: true
 enable_tiling: true
 epsilon: 1e-8
 gpu_ids: '0'
 gradient_accumulation_steps: 4
 gradient_checkpointing: true
 id_token: afkx
+image_resolution_buckets: 512x768
 lora_alpha: 128
 lr: 0.0001
 lr_num_cycles: 1
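As a usage sketch only, the template can be read as plain YAML and the three fields added in this commit filled in before launching a run. File paths and the example values (the extra 768x1152 bucket, the metadata path) are illustrative assumptions, not part of the commit.

import yaml

# Load the shipped template and override the options added in this commit.
with open("config/config_template.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["dataset_file"] = "data/metadata.csv"             # illustrative path; empty string disables it
cfg["image_resolution_buckets"] = "512x768 768x1152"  # space-separated WxH buckets, as run_trainer.py splits on ' '
cfg["enable_model_cpu_offload"] = True                # presumably lowers VRAM use at some speed cost

with open("config/my_run.yaml", "w") as f:
    yaml.safe_dump(cfg, f)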

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "finetrainers-ui"
-version = "0.9.1"
+version = "0.9.3"
 dependencies = [
     "gradio",
     "torch>=2.4.1"

run_trainer.py

Lines changed: 6 additions & 0 deletions
@@ -28,13 +28,17 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
                        "--id_token", config.get('id_token'),
                        "--video_resolution_buckets"]
         dataset_cmd += config.get('video_resolution_buckets').split(' ')
+        dataset_cmd += ["--image_resolution_buckets"]
+        dataset_cmd += config.get('image_resolution_buckets').split(' ')
         dataset_cmd += ["--caption_dropout_p", config.get('caption_dropout_p'),
                         "--caption_dropout_technique", config.get('caption_dropout_technique'),
                         "--text_encoder_dtype", config.get('text_encoder_dtype'),
                         "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
                         "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
                         "--vae_dtype", config.get('vae_dtype'),
                         '--precompute_conditions' if config.get('precompute_conditions') else '']
+        if config.get('dataset_file'):
+            dataset_cmd += ["--dataset_file", config.get('dataset_file')]

         dataloader_cmd = ["--dataloader_num_workers", config.get('dataloader_num_workers')]

@@ -56,6 +60,8 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
                         "--checkpointing_limit", config.get('checkpointing_limit'),
                         '--enable_slicing' if config.get('enable_slicing') else '',
                         '--enable_tiling' if config.get('enable_tiling') else '']
+        if config.get('enable_model_cpu_offload'):
+            training_cmd += ["--enable_model_cpu_offload"]

         if config.get('resume_from_checkpoint'):
             training_cmd += ["--resume_from_checkpoint", config.get('resume_from_checkpoint')]
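The run_trainer.py changes follow the existing pattern of building a subprocess argument list: the bucket string is split on spaces into separate argv tokens, while dataset_file and enable_model_cpu_offload are appended only when set. A small self-contained sketch of that pattern, using an illustrative dict in place of the Config object:

# Stand-in for the Config object used in run_trainer.py (values are illustrative).
config = {
    "image_resolution_buckets": "512x768 768x1152",
    "dataset_file": "",                 # empty string -> flag is omitted
    "enable_model_cpu_offload": True,
}

dataset_cmd = ["--image_resolution_buckets"]
dataset_cmd += config["image_resolution_buckets"].split(" ")

if config["dataset_file"]:
    dataset_cmd += ["--dataset_file", config["dataset_file"]]

training_cmd = []
if config["enable_model_cpu_offload"]:
    training_cmd += ["--enable_model_cpu_offload"]

print(dataset_cmd)   # ['--image_resolution_buckets', '512x768', '768x1152']
print(training_cmd)  # ['--enable_model_cpu_offload']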
