Skip to content

Commit d80b243

Browse files
authored
Merge pull request #25 from neph1/update-v0.11.2
optional delayed start
2 parents 3b7f034 + 3c68c40 commit d80b243

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

tabs/training_tab.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11

22
import datetime
3+
import time
34
import os
5+
from threading import Thread
46
import gradio as gr
57
from typing import OrderedDict
68
from config import Config
@@ -48,10 +50,11 @@ def __init__(self, title, config_file_path, allow_load=False):
4850
with gr.Row(equal_height=True):
4951
run_button = gr.Button("Start Training", key='run_trainer')
5052
stop_button = gr.Button("Stop", key='stop_trainer')
53+
delay_box = gr.Number(value=0, label="Delay start (minutes)", minimum=0)
5154

5255
log_output = gr.File(label="Log File", interactive=False)
5356
run_button.click(self.run_trainer,
54-
inputs=[*properties.values()],
57+
inputs=[delay_box, *properties.values()],
5558
outputs=[self.output_box, log_output]
5659
)
5760

@@ -60,8 +63,8 @@ def __init__(self, title, config_file_path, allow_load=False):
6063
def get_properties(self) -> OrderedDict:
6164
return properties
6265

63-
def run_trainer(self, *args):
64-
time = datetime.datetime.now()
66+
def run_trainer(self, delay, *args):
67+
current_time = datetime.datetime.now()
6568
properties_values = list(args)
6669
keys_list = list(properties.keys())
6770

@@ -80,17 +83,22 @@ def run_trainer(self, *args):
8083

8184
output_path = os.path.join(properties['output_dir'].value, "config")
8285
os.makedirs(output_path, exist_ok=True)
83-
self.save_edits(os.path.join(output_path, "config_{}.yaml".format(time)), *properties_values)
84-
85-
log_file = os.path.join(output_path, "log_{}.txt".format(time))
86-
87-
result = self.trainer.run(config, config.get('path_to_finetrainers'), log_file)
88-
self.trainer.running = False
89-
if isinstance(result, str):
90-
return result, log_file
91-
if result.returncode == 0:
92-
return "Training finished. Please see the log file for more details.", log_file
93-
return "Training failed. Please see the log file for more details.", log_file
86+
self.save_edits(os.path.join(output_path, "config_{}.yaml".format(current_time)), *properties_values)
87+
88+
log_file = os.path.join(output_path, "log_{}.txt".format(current_time))
89+
90+
if delay:
91+
time.sleep(int(delay) * 60)
92+
Thread(target=self.trainer.run, args=(config, config.get('path_to_finetrainers'), log_file), daemon=True).start()
93+
return "Training is running asynchronously, no result returned to gradio. Please see the log file for more details.", log_file
94+
else:
95+
result = self.trainer.run(config, config.get('path_to_finetrainers'), log_file)
96+
self.trainer.running = False
97+
if isinstance(result, str):
98+
return result, log_file
99+
if result.returncode == 0:
100+
return "Training finished. Please see the log file for more details.", log_file
101+
return "Training failed. Please see the log file for more details.", log_file
94102

95103
def stop_trainer(self):
96104
self.trainer.stop()

0 commit comments

Comments
 (0)