12
12
from termcolor import colored
13
13
from typeguard import check_argument_types
14
14
15
- from neuralmonkey .logging import log , log_print , warn , notice
15
+ from neuralmonkey .logging import log , log_print , warn
16
16
from neuralmonkey .dataset import Dataset , BatchingScheme
17
17
from neuralmonkey .tf_manager import TensorFlowManager
18
18
from neuralmonkey .runners .base_runner import (
19
19
BaseRunner , ExecutionResult , reduce_execution_results )
20
20
from neuralmonkey .trainers .generic_trainer import GenericTrainer
21
21
from neuralmonkey .trainers .multitask_trainer import MultitaskTrainer
22
22
from neuralmonkey .trainers .delayed_update_trainer import DelayedUpdateTrainer
23
+ from neuralmonkey .training_profiler import TrainingProfiler
23
24
24
25
# pylint: disable=invalid-name
25
26
Evaluation = Dict [str , float ]
@@ -133,14 +134,12 @@ def training_loop(tf_manager: TensorFlowManager,
133
134
log ("TensorBoard writer initialized." )
134
135
135
136
log ("Starting training" )
136
- last_log_time = time .process_time ()
137
- last_val_time = time .process_time ()
137
+ profiler = TrainingProfiler ()
138
+ profiler .training_start ()
139
+
138
140
interrupt = None
139
141
try :
140
142
for epoch_n in range (1 , epochs + 1 ):
141
- log_print ("" )
142
- log ("Epoch {} begins" .format (epoch_n ), color = "red" )
143
-
144
143
train_batches = train_dataset .batches (batching_scheme )
145
144
146
145
if epoch_n == 1 and train_start_offset :
@@ -150,11 +149,15 @@ def training_loop(tf_manager: TensorFlowManager,
150
149
else :
151
150
_skip_lines (train_start_offset , train_batches )
152
151
152
+ log_print ("" )
153
+ log ("Epoch {} begins" .format (epoch_n ), color = "red" )
154
+ profiler .epoch_start ()
155
+
153
156
for batch_n , batch in enumerate (train_batches ):
154
157
step += 1
155
158
seen_instances += len (batch )
156
159
157
- if log_timer (step , last_log_time ):
160
+ if log_timer (step , profiler . last_log_time ):
158
161
trainer_result = tf_manager .execute (
159
162
batch , feedables , trainers , train = True , summaries = True )
160
163
train_results , train_outputs = run_on_dataset (
@@ -172,14 +175,18 @@ def training_loop(tf_manager: TensorFlowManager,
172
175
tb_writer , main_metric , train_evaluation ,
173
176
seen_instances , epoch_n , epochs , trainer_result ,
174
177
train = True )
175
- last_log_time = time .process_time ()
178
+
179
+ profiler .log_done ()
180
+
176
181
else :
177
182
tf_manager .execute (batch , feedables , trainers , train = True ,
178
183
summaries = False )
179
184
180
- if val_timer (step , last_val_time ):
185
+ if val_timer (step , profiler .last_val_time ):
186
+
181
187
log_print ("" )
182
- val_duration_start = time .process_time ()
188
+ profiler .validation_start ()
189
+
183
190
val_examples = 0
184
191
for val_id , valset in enumerate (val_datasets ):
185
192
val_examples += len (valset )
@@ -243,24 +250,12 @@ def training_loop(tf_manager: TensorFlowManager,
243
250
seen_instances , epoch_n , epochs , val_results ,
244
251
train = False , dataset_name = v_name )
245
252
246
- # how long was the training between validations
247
- training_duration = val_duration_start - last_val_time
248
- val_duration = time .process_time () - val_duration_start
249
-
250
- # the training should take at least twice the time of val.
251
- steptime = (training_duration
252
- / (seen_instances - last_seen_instances ))
253
- valtime = val_duration / val_examples
253
+ profiler .validation_done ()
254
+ profiler .log_after_validation (
255
+ val_examples , seen_instances - last_seen_instances )
254
256
last_seen_instances = seen_instances
255
- log ("Validation time: {:.2f}s, inter-validation: {:.2f}s, "
256
- "per-instance (train): {:.2f}s, per-instance (val): "
257
- "{:.2f}s" .format (val_duration , training_duration ,
258
- steptime , valtime ), color = "blue" )
259
- if training_duration < 2 * val_duration :
260
- notice ("Validation period setting is inefficient." )
261
257
262
258
log_print ("" )
263
- last_val_time = time .process_time ()
264
259
265
260
except KeyboardInterrupt as ex :
266
261
interrupt = ex
0 commit comments