Skip to content

Commit ce6da7b

Browse files
authored
Merge pull request #10 from neph1/update-v0.8.0
Update v0.8.0
2 parents cff511d + 431f992 commit ce6da7b

10 files changed

+201
-35
lines changed

app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class App:
1717

1818
def __init__(self):
    # Directory the UI reads its YAML config files from.
    self.configs_path = "config/"
    # Holds Tab instances (the deleted inline comment said "Type Tab");
    # NOTE(review): keys presumably tab names — confirm in setup_views().
    self.tabs = dict()
    self.setup_views()
def setup_views(self):

editor_factory.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

gradio_functions.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[project]
2+
name = "finetrainers-ui"
3+
version = "0.8.0"
4+
dependencies = [
5+
"gradio",
6+
"torch>=2.4.1"
7+
]
8+
9+
10+
[project.urls]
11+
Homepage = "https://github.com/neph1/finetrainers-ui"

run_trainer.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,15 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
9494
return "Unknown result"
9595

9696
def stop(self):
    """Stop the training process and report the outcome.

    Clears the running flag, asks the child process to terminate, and
    escalates to kill() if it is still alive after a short grace period.

    Returns:
        A human-readable status string; errors are reported as a string
        rather than raised.
    """
    try:
        self.running = False
        if self.process:
            self.process.terminate()
            # Grace period for the trainer to shut down cleanly.
            time.sleep(3)
            if self.process.poll() is None:
                self.process.kill()
    except Exception as e:
        return f"Error stopping training: {e}"
    finally:
        # Bug fix: self.process may be None (stop() before/without a run);
        # the unguarded wait() previously raised AttributeError here and
        # also masked the except-branch's return value. wait() reaps the
        # terminated child so it does not linger as a zombie.
        if self.process:
            self.process.wait()
    return "Training forcibly stopped"

tabs/tab.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import OrderedDict
33
import gradio as gr
44
import yaml
5-
import editor_factory
65

76
class Tab(ABC):
87

tabs/training_tab.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import os
44
import gradio as gr
55
from typing import OrderedDict
6-
from config import Config, global_config
6+
from config import Config
77

88
from run_trainer import RunTrainer
99
from tabs import general_tab
1010
from tabs.tab import Tab
11+
from trainer_config_validator import TrainerValidator
1112

1213
properties = OrderedDict()
1314

@@ -69,16 +70,21 @@ def run_trainer(self, *args):
6970
key = keys_list[index]
7071
properties[key].value = properties_values[index]
7172
config.set(key, properties_values[index])
73+
config.set('path_to_finetrainers', general_tab.properties['path_to_finetrainers'].value)
74+
75+
config_validator = TrainerValidator(config)
76+
try:
77+
config_validator.validate()
78+
except Exception as e:
79+
return str(e), None
7280

7381
output_path = os.path.join(properties['output_dir'].value, "config")
7482
os.makedirs(output_path, exist_ok=True)
7583
self.save_edits(os.path.join(output_path, "config_{}.yaml".format(time)), *properties_values)
7684

7785
log_file = os.path.join(output_path, "log_{}.txt".format(time))
7886

79-
if not general_tab.properties['path_to_finetrainers'].value:
80-
return "Please set the path to finetrainers in General Settings"
81-
result = self.trainer.run(config, general_tab.properties['path_to_finetrainers'].value, log_file)
87+
result = self.trainer.run(config, config.get('path_to_finetrainers'), log_file)
8288
self.trainer.running = False
8389
if isinstance(result, str):
8490
return result, log_file

test/__init__.py

Whitespace-only changes.

test/test_trainer_config_validator.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import os
2+
import pytest
3+
from unittest.mock import patch
4+
5+
from trainer_config_validator import TrainerValidator
6+
7+
@pytest.fixture
def valid_config():
    # A fully-populated configuration dict: every key TrainerValidator's
    # required-settings check looks for, with plausible stand-in values.
    # Paths here do not exist on disk; tests that reach filesystem checks
    # patch os.path accordingly.
    return {
        'path_to_finetrainers': '/path/to/finetrainers',
        'accelerate_config': 'config1',
        'batch_size': 32,
        'beta1': 0.9,
        'beta2': 0.999,
        'caption_column': 'captions.txt',
        'caption_dropout_p': 0.1,
        'checkpointing_limit': 5,
        'checkpointing_steps': 1000,
        'data_root': '/path/to/data',
        'dataloader_num_workers': 0,
        'epsilon': 1e-8,
        'gpu_ids': '0,1',
        'gradient_accumulation_steps': 2,
        'gradient_checkpointing': True,
        'id_token': 'token123',
        'lora_alpha': 128,
        'lr': 0.001,
        'lr_num_cycles': 10,
        'lr_scheduler': 'scheduler1',
        'lr_warmup_steps': 500,
        'max_grad_norm': 1.0,
        'mixed_precision': 'fp16',
        'model_name': 'model_v1',
        'nccl_timeout': 60,
        'optimizer': 'adam',
        'pretrained_model_name_or_path': 'pretrained_model',
        'rank': 64,
        'seed': 42,
        'target_modules': 'module1',
        'tracker_name': 'tracker',
        'train_steps': 10000,
        'training_type': 'type1',
        'validation_steps': 100,
        'video_column': 'videos.txt',
        'video_resolution_buckets': '24x480x720',
        'weight_decay': 0.01
    }
48+
49+
@pytest.fixture
def trainer_validator(valid_config):
    # Validator wired to the fully-populated valid_config fixture; tests
    # mutate trainer_validator.config in place to create invalid cases.
    return TrainerValidator(valid_config)
52+
53+
def test_valid_config(valid_config):
    # With every filesystem probe patched to succeed, a complete config
    # must pass full validation without raising.
    trainer_validator = TrainerValidator(valid_config)
    with patch('os.path.isfile', return_value=True), patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
        trainer_validator.validate()
57+
58+
def test_validate_data_root_invalid(trainer_validator):
    # A data_root that does not exist on disk is rejected with a
    # descriptive ValueError naming the offending path.
    trainer_validator.config['data_root'] = '/invalid/path'
    with pytest.raises(ValueError, match="data_root path /invalid/path does not exist"):
        trainer_validator.validate_data_root()
62+
63+
def test_validate_data_root_valid(trainer_validator):
    # When the path exists and is a directory (both patched true),
    # validation completes silently.
    with patch('os.path.exists', return_value=True), patch('os.path.isdir', return_value=True):
        trainer_validator.config['data_root'] = '/path/to/data'
        trainer_validator.validate_data_root()
67+
68+
def test_validate_video_resolution_buckets_invalid(trainer_validator):
    # Comma-separated resolution names do not match the required
    # '<frames>x<height>x<width>' bucket format and must raise.
    trainer_validator.config['video_resolution_buckets'] = '720p,1080p,4k'
    with pytest.raises(ValueError, match=f"Each bucket must have the format '<frames>x<height>x<width>', but got {trainer_validator.config['video_resolution_buckets']}"):
        trainer_validator.validate_video_resolution_buckets()
72+
73+
def test_validate_video_resolution_buckets_valid(trainer_validator):
    # A single bucket in <frames>x<height>x<width> form is accepted.
    trainer_validator.config['video_resolution_buckets'] = '24x480x720'
    trainer_validator.validate_video_resolution_buckets()

    # A space-separated list of such buckets is accepted as well.
    trainer_validator.config['video_resolution_buckets'] = '8x320x512 24x480x720 30x720x1280'
    trainer_validator.validate_video_resolution_buckets()

trainer_config_validator.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import os
2+
import re
3+
4+
class TrainerValidator:
    """Validates a finetrainers training configuration before a run.

    The config may be any mapping-like object exposing ``.get(key)``
    (a plain dict works). ``validate()`` raises ``ValueError`` on the
    first problem found; the per-field ``validate_*`` methods can also
    be called individually.
    """

    def __init__(self, config):
        # Mapping-like object; only .get() is used so dict and the
        # project's Config class are both accepted.
        self.config = config

    def validate(self):
        """Check presence of all required settings, then run field checks.

        Raises:
            ValueError: if a required setting is missing/empty, or any
                per-field validator rejects its value.
        """
        required_settings = [
            'path_to_finetrainers',
            'accelerate_config',
            'batch_size',
            'beta1',
            'beta2',
            'caption_column',
            'caption_dropout_p',
            'checkpointing_limit',
            'checkpointing_steps',
            'data_root',
            'dataloader_num_workers',
            'epsilon',
            'gpu_ids',
            'gradient_accumulation_steps',
            'id_token',
            'lora_alpha',
            'lr',
            'lr_num_cycles',
            'lr_scheduler',
            'lr_warmup_steps',
            'max_grad_norm',
            'mixed_precision',
            'model_name',
            'nccl_timeout',
            'optimizer',
            'pretrained_model_name_or_path',
            'rank',
            'seed',
            'target_modules',
            'train_steps',
            'training_type',
            'validation_steps',
            'video_column',
            'video_resolution_buckets',
            'weight_decay'
        ]

        for setting in required_settings:
            # 0 is a legitimate value (e.g. dataloader_num_workers), so a
            # falsy-but-zero setting must not be treated as missing.
            if not self.config.get(setting) and self.config.get(setting) != 0:
                raise ValueError(f"{setting} is required")

        self.validate_finetrainers_path()
        self.validate_data_root()
        self.validate_caption_column()
        self.validate_video_column()
        self.validate_video_resolution_buckets()

    def validate_finetrainers_path(self):
        """Ensure train.py exists under path_to_finetrainers."""
        train_script_path = os.path.join(self.config.get('path_to_finetrainers'), 'train.py')
        if not os.path.isfile(train_script_path):
            raise ValueError(f"train.py does not exist at {self.config.get('path_to_finetrainers')}")

    def validate_caption_column(self):
        """Ensure the caption file exists inside data_root (if both set)."""
        data_root = self.config.get('data_root')
        caption_column = self.config.get('caption_column')
        if data_root and caption_column:
            file_path = os.path.join(data_root, caption_column)
            if not os.path.isfile(file_path):
                raise ValueError(f"File {caption_column} does not exist at {data_root}")

    def validate_data_root(self):
        """Ensure data_root (if set) is an existing directory."""
        data_root = self.config.get('data_root')
        if data_root and not os.path.isdir(data_root):
            raise ValueError(f"data_root path {data_root} does not exist")

    def validate_video_column(self):
        """Ensure the video list file exists inside data_root (if both set)."""
        data_root = self.config.get('data_root')
        video_column = self.config.get('video_column')
        if data_root and video_column:
            file_path = os.path.join(data_root, video_column)
            if not os.path.isfile(file_path):
                raise ValueError(f"File {video_column} does not exist at {data_root}")

    def validate_video_resolution_buckets(self):
        """Ensure every bucket matches '<frames>x<height>x<width>'.

        Raises:
            ValueError: naming the first malformed bucket.
        """
        buckets = self.config.get('video_resolution_buckets')
        # Bug fix: split() with no argument tolerates repeated/stray
        # whitespace; split(' ') yielded empty-string "buckets" that
        # spuriously failed the format check. (str.split always returns
        # strings, so the old isinstance() guard was dead code.)
        for bucket in buckets.split():
            if not re.match(r'^\d+x\d+x\d+$', bucket):
                raise ValueError(f"Each bucket must have the format '<frames>x<height>x<width>', but got {bucket}")

0 commit comments

Comments
 (0)