@@ -33,6 +33,22 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
                      "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
                      "--transformer_dtype", config.get('transformer_dtype'),
                      "--vae_dtype", config.get('vae_dtype')]
+        if config.get('text_encoder_id'):
+            model_cmd += ["--text_encoder_id", config.get('text_encoder_id')]
+        if config.get('text_encoder_2_id'):
+            model_cmd += ["--text_encoder_2_id", config.get('text_encoder_2_id')]
+        if config.get('text_encoder_3_id'):
+            model_cmd += ["--text_encoder_3_id", config.get('text_encoder_3_id')]
+        if config.get('transformer_id'):
+            model_cmd += ["--transformer_id", config.get('transformer_id')]
+        if config.get('vae_id'):
+            model_cmd += ["--vae_id", config.get('vae_id')]
+        if config.get('tokenizer_id'):
+            model_cmd += ["--tokenizer_id", config.get('tokenizer_id')]
+        if config.get('tokenizer_2_id'):
+            model_cmd += ["--tokenizer_2_id", config.get('tokenizer_2_id')]
+        if config.get('tokenizer_3_id'):
+            model_cmd += ["--tokenizer_3_id", config.get('tokenizer_3_id')]
 
         if config.get('layerwise_upcasting_modules') != 'none':
             model_cmd += ["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
@@ -55,10 +71,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         training_cmd = ["--training_type", config.get('training_type'),
                         "--seed", config.get('seed'),
                         "--batch_size", config.get('batch_size'),
-                        "--train_steps", config.get('train_steps'),
-                        "--rank", config.get('rank'),
-                        "--lora_alpha", config.get('lora_alpha'),
-                        "--target_modules"]
+                        "--train_steps", config.get('train_steps')]
         training_cmd += config.get('target_modules').split(' ')
         training_cmd += ["--gradient_accumulation_steps", config.get('gradient_accumulation_steps'),
                          '--gradient_checkpointing' if config.get('gradient_checkpointing') else '',
@@ -87,6 +100,12 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         validation_cmd = ["--validation_dataset_file" if config.get('validation_dataset_file') else '',
                           "--num_validation_videos", config.get('num_validation_videos'),
                           "--validation_steps", config.get('validation_steps')]
+
+        control_cmd = ["--rank", config.get('rank'),
+                       "--lora_alpha", config.get('lora_alpha'),
+                       "--control_type", config.get('control_type'),
+                       "--frame_conditioning_index", config.get('frame_conditioning_index'),
+                       "--frame_conditioning_type", config.get('frame_conditioning_type')]
 
         miscellaneous_cmd = ["--tracker_name", config.get('tracker_name'),
                              "--output_dir", config.get('output_dir'),
@@ -105,7 +124,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
             pre_command = ["accelerate", "launch", "--config_file", f"{finetrainers_path}/accelerate_configs/{config.get('accelerate_config')}", "--gpu_ids", config.get('gpu_ids')]
         elif parallel_backend == 'ptd':
             pre_command = ["torchrun", "--standalone", "--nnodes", num_gpus, "--nproc_per_node", config.get('nproc_per_node'), "--rdzv_backend", "c10d", "--rdzv_endpoint", f"{address}:{port}"]
-        cmd = pre_command + [f"{finetrainers_path}/train.py"] + parallel_cmd + model_cmd + dataset_cmd + dataloader_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd
+        cmd = pre_command + [f"{finetrainers_path}/train.py"] + parallel_cmd + model_cmd + dataset_cmd + dataloader_cmd + training_cmd + optimizer_cmd + validation_cmd + miscellaneous_cmd + control_cmd
         fixed_cmd = []
         for i in range(len(cmd)):
             if cmd[i] != '':
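In the final hunk, cmd concatenates every argument segment (now including control_cmd), and the loop that follows strips the empty strings left behind by conditional entries such as '--gradient_checkpointing' if config.get('gradient_checkpointing') else ''. A minimal sketch of that assemble-filter-launch flow, assuming a subprocess-based launch and append-to-log behaviour that sit outside this diff; the names below are illustrative:

import subprocess

def assemble_and_launch(pre_command, train_script, segments, log_file):
    # Illustrative only: mirrors the assembly and empty-string filtering shown above.
    cmd = list(pre_command) + [train_script]
    for segment in segments:                       # model_cmd, dataset_cmd, ..., control_cmd
        cmd += segment
    fixed_cmd = [arg for arg in cmd if arg != '']  # same effect as the filtering loop in run()
    with open(log_file, 'a') as log:               # assumed logging behaviour, not shown in the diff
        return subprocess.run(fixed_cmd, stdout=log, stderr=subprocess.STDOUT)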