1
1
2
2
import datetime
3
+ import time
3
4
import os
5
+ from threading import Thread
4
6
import gradio as gr
5
7
from typing import OrderedDict
6
8
from config import Config
@@ -48,10 +50,11 @@ def __init__(self, title, config_file_path, allow_load=False):
48
50
with gr .Row (equal_height = True ):
49
51
run_button = gr .Button ("Start Training" , key = 'run_trainer' )
50
52
stop_button = gr .Button ("Stop" , key = 'stop_trainer' )
53
+ delay_box = gr .Number (value = 0 , label = "Delay start (minutes)" , minimum = 0 )
51
54
52
55
log_output = gr .File (label = "Log File" , interactive = False )
53
56
run_button .click (self .run_trainer ,
54
- inputs = [* properties .values ()],
57
+ inputs = [delay_box , * properties .values ()],
55
58
outputs = [self .output_box , log_output ]
56
59
)
57
60
@@ -60,8 +63,8 @@ def __init__(self, title, config_file_path, allow_load=False):
60
63
def get_properties (self ) -> OrderedDict :
61
64
return properties
62
65
63
- def run_trainer (self , * args ):
64
- time = datetime .datetime .now ()
66
+ def run_trainer (self , delay , * args ):
67
+ current_time = datetime .datetime .now ()
65
68
properties_values = list (args )
66
69
keys_list = list (properties .keys ())
67
70
@@ -80,17 +83,22 @@ def run_trainer(self, *args):
80
83
81
84
output_path = os .path .join (properties ['output_dir' ].value , "config" )
82
85
os .makedirs (output_path , exist_ok = True )
83
- self .save_edits (os .path .join (output_path , "config_{}.yaml" .format (time )), * properties_values )
84
-
85
- log_file = os .path .join (output_path , "log_{}.txt" .format (time ))
86
-
87
- result = self .trainer .run (config , config .get ('path_to_finetrainers' ), log_file )
88
- self .trainer .running = False
89
- if isinstance (result , str ):
90
- return result , log_file
91
- if result .returncode == 0 :
92
- return "Training finished. Please see the log file for more details." , log_file
93
- return "Training failed. Please see the log file for more details." , log_file
86
+ self .save_edits (os .path .join (output_path , "config_{}.yaml" .format (current_time )), * properties_values )
87
+
88
+ log_file = os .path .join (output_path , "log_{}.txt" .format (current_time ))
89
+
90
+ if delay :
91
+ time .sleep (int (delay ) * 60 )
92
+ Thread (target = self .trainer .run , args = (config , config .get ('path_to_finetrainers' ), log_file ), daemon = True ).start ()
93
+ return "Training is running asynchronously, no result returned to gradio. Please see the log file for more details." , log_file
94
+ else :
95
+ result = self .trainer .run (config , config .get ('path_to_finetrainers' ), log_file )
96
+ self .trainer .running = False
97
+ if isinstance (result , str ):
98
+ return result , log_file
99
+ if result .returncode == 0 :
100
+ return "Training finished. Please see the log file for more details." , log_file
101
+ return "Training failed. Please see the log file for more details." , log_file
94
102
95
103
def stop_trainer (self ):
96
104
self .trainer .stop ()
0 commit comments