@@ -239,29 +239,34 @@ def train(self) -> None:
239
239
with self .graph .as_default ():
240
240
self .model .tf_manager .init_saving (self .get_path ("variables.data" ))
241
241
242
- training_loop (
243
- tf_manager = self .model .tf_manager ,
244
- epochs = self .model .epochs ,
245
- trainers = self .model .trainers ,
246
- batching_scheme = self .model .batching_scheme ,
247
- runners_batching_scheme = self .model .runners_batching_scheme ,
248
- log_directory = self .model .output ,
249
- evaluators = self .model .evaluation ,
250
- main_metric = self .model .main_metric ,
251
- runners = self .model .runners ,
252
- train_dataset = self .model .train_dataset ,
253
- val_datasets = self .model .val_datasets ,
254
- test_datasets = self .model .test_datasets ,
255
- log_timer = self .model .log_timer ,
256
- val_timer = self .model .val_timer ,
257
- val_preview_input_series = self .model .val_preview_input_series ,
258
- val_preview_output_series = self .model .val_preview_output_series ,
259
- val_preview_num_examples = self .model .val_preview_num_examples ,
260
- postprocess = self .model .postprocess ,
261
- train_start_offset = self .model .train_start_offset ,
262
- initial_variables = self .model .initial_variables ,
263
- final_variables = self .get_path ("variables.data.final" ))
264
-
242
+ training_loop (cfg = self .model )
243
+
244
+ final_variables = self .get_path ("variables.data.final" )
245
+ log ("Saving final variables in {}" .format (final_variables ))
246
+ self .model .tf_manager .save (final_variables )
247
+
248
+ if self .model .test_datasets :
249
+ self .model .tf_manager .restore_best_vars ()
250
+
251
+ for dataset in self .model .test_datasets :
252
+ test_results , test_outputs = run_on_dataset (
253
+ self .model .tf_manager ,
254
+ self .model .runners ,
255
+ dataset ,
256
+ self .model .postprocess ,
257
+ write_out = True ,
258
+ batching_scheme = self .model .runners_batching_scheme )
259
+
260
+ # ensure test outputs are iterable more than once
261
+ test_outputs = {
262
+ k : list (v ) for k , v in test_outputs .items ()}
263
+ eval_result = evaluation (
264
+ self .model .evaluation , dataset ,
265
+ self .model .runners , test_results , test_outputs )
266
+
267
+ print_final_evaluation (dataset .name , eval_result )
268
+
269
+ log ("Finished." )
265
270
self ._vars_loaded = True
266
271
267
272
def load_variables (self , variable_files : List [str ] = None ) -> None :
0 commit comments