Skip to content

Commit 4e9f0b2

Browse files
Authored by neph1
Merge pull request #12 from neph1/update-v0.10.0
Settings for fp8 training
2 parents d065821 + 2ef4089 commit 4e9f0b2

10 files changed

+27
-15
lines changed

config/config_categories.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p
2-
Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
1+
Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p, precompute_conditions
2+
Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
33
Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
44
Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
5-
Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
5+
Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
6+
Model: model_name, pretrained_model_name_or_path, text_encoder_dtype, text_encoder_2_dtype, text_encoder_3_dtype, vae_dtype, layerwise_upcasting_modules, layerwise_upcasting_storage_dtype, layerwise_upcasting_granularity

config/config_template.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ gpu_ids: '0'
2020
gradient_accumulation_steps: 4
2121
gradient_checkpointing: true
2222
id_token: afkx
23+
layerwise_upcasting_modules: [none, transformer]
24+
layerwise_upcasting_skip_modules_pattern: 'patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm'
25+
layerwise_upcasting_storage_dtype: [float8_e4m3fn, float8_e5m2]
2326
image_resolution_buckets: 512x768
2427
lora_alpha: 128
2528
lr: 0.0001
2629
lr_num_cycles: 1
2730
lr_scheduler: ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']
2831
lr_warmup_steps: 400
2932
max_grad_norm: 1.0
30-
mixed_precision: [bf16, fp16, 'no']
3133
model_name: ltx_video
3234
nccl_timeout: 1800
3335
num_validation_videos: 0
@@ -45,6 +47,7 @@ text_encoder_dtype: [bf16, fp16, fp32]
4547
text_encoder_2_dtype: [bf16, fp16, fp32]
4648
text_encoder_3_dtype: [bf16, fp16, fp32]
4749
tracker_name: finetrainers
50+
transformer_dtype: [bf16, fp16, fp32]
4851
train_steps: 3000
4952
training_type: lora
5053
use_8bit_bnb: false

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "finetrainers-ui"
3-
version = "0.9.3"
3+
version = "0.10.0"
44
dependencies = [
55
"gradio",
66
"torch>=2.4.1"

run_trainer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,16 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
2020
assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"
2121

2222
model_cmd = ["--model_name", config.get('model_name'),
23-
"--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path')]
23+
"--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
24+
"--text_encoder_dtype", config.get('text_encoder_dtype'),
25+
"--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
26+
"--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
27+
"--vae_dtype", config.get('vae_dtype')]
28+
29+
if config.get('layerwise_upcasting_modules') != 'none':
30+
model_cmd +=["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
31+
"--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
32+
"--layerwise_upcasting_skip_modules_pattern", config.get('layerwise_upcasting_skip_modules_pattern')]
2433

2534
dataset_cmd = ["--data_root", config.get('data_root'),
2635
"--video_column", config.get('video_column'),
@@ -36,6 +45,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
3645
"--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
3746
"--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
3847
"--vae_dtype", config.get('vae_dtype'),
48+
"--transformer_dtype", config.get('transformer_dtype'),
3949
'--precompute_conditions' if config.get('precompute_conditions') else '']
4050
if config.get('dataset_file'):
4151
dataset_cmd += ["--dataset_file", config.get('dataset_file')]
@@ -47,7 +57,6 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
4757

4858
training_cmd = ["--training_type", config.get('training_type'),
4959
"--seed", config.get('seed'),
50-
"--mixed_precision", config.get('mixed_precision'),
5160
"--batch_size", config.get('batch_size'),
5261
"--train_steps", config.get('train_steps'),
5362
"--rank", config.get('rank'),

tabs/general_tab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, title, config_file_path, allow_load=False):
1515

1616
try:
1717
with self.settings_column:
18-
inputs = self.update_form(self.config)
18+
inputs = self.update_form()
1919
self.components = OrderedDict(inputs)
2020
children = []
2121
for child in self.settings_column.children:

tabs/prepare_tab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, title, config_file_path, allow_load=False):
2020

2121
try:
2222
with self.settings_column:
23-
self.components = OrderedDict(self.update_form(self.config))
23+
self.components = OrderedDict(self.update_form())
2424
for i in range(len(self.settings_column.children)):
2525
keys = list(self.components.keys())
2626
properties[keys[i]] = self.settings_column.children[i]

tabs/tab.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def add_buttons(self):
7171
outputs=[self.save_status, self.config_file_box, *self.get_properties().values()]
7272
)
7373

74-
def update_form(self, config):
74+
def update_form(self):
7575
inputs = dict()
7676

77-
for key, value in config.items():
77+
for key, value in self.config.items():
7878
category = 'Other'
7979
for categories in self.config_categories.keys():
8080
if key in self.config_categories[categories]:
@@ -114,6 +114,6 @@ def update_properties(self, *args):
114114

115115
properties_values[index] = value
116116
#properties[key].value = value
117-
return ["Config loaded. Edit below:", config_file_box, *properties_values]
117+
return ["Config loaded.", config_file_box, *properties_values]
118118
except Exception as e:
119119
return [f"Error loading config: {e}", config_file_box, *properties_values]

tabs/training_tab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, title, config_file_path, allow_load=False):
3030

3131
try:
3232
with self.settings_column:
33-
inputs = self.update_form(self.config)
33+
inputs = self.update_form()
3434
self.components = OrderedDict(inputs)
3535
children = []
3636
for child in self.settings_column.children:

tabs/training_tab_legacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, title, config_file_path, allow_load=False):
1717

1818
try:
1919
with self.settings_column:
20-
self.components = OrderedDict(self.update_form(self.config))
20+
self.components = OrderedDict(self.update_form())
2121
for i in range(len(self.settings_column.children)):
2222
keys = list(self.components.keys())
2323
properties[keys[i]] = self.settings_column.children[i]

trainer_config_validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def validate(self):
2929
'lr_scheduler',
3030
'lr_warmup_steps',
3131
'max_grad_norm',
32-
'mixed_precision',
3332
'model_name',
3433
'nccl_timeout',
3534
'optimizer',

0 commit comments

Comments (0)