Skip to content

Commit d7e2044

Browse files
committed
Refactoring learning utils a bit..
1 parent d33f03e commit d7e2044

File tree

2 files changed

+149
-215
lines changed

2 files changed

+149
-215
lines changed

neuralmonkey/experiment.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -239,29 +239,34 @@ def train(self) -> None:
239239
with self.graph.as_default():
240240
self.model.tf_manager.init_saving(self.get_path("variables.data"))
241241

242-
training_loop(
243-
tf_manager=self.model.tf_manager,
244-
epochs=self.model.epochs,
245-
trainers=self.model.trainers,
246-
batching_scheme=self.model.batching_scheme,
247-
runners_batching_scheme=self.model.runners_batching_scheme,
248-
log_directory=self.model.output,
249-
evaluators=self.model.evaluation,
250-
main_metric=self.model.main_metric,
251-
runners=self.model.runners,
252-
train_dataset=self.model.train_dataset,
253-
val_datasets=self.model.val_datasets,
254-
test_datasets=self.model.test_datasets,
255-
log_timer=self.model.log_timer,
256-
val_timer=self.model.val_timer,
257-
val_preview_input_series=self.model.val_preview_input_series,
258-
val_preview_output_series=self.model.val_preview_output_series,
259-
val_preview_num_examples=self.model.val_preview_num_examples,
260-
postprocess=self.model.postprocess,
261-
train_start_offset=self.model.train_start_offset,
262-
initial_variables=self.model.initial_variables,
263-
final_variables=self.get_path("variables.data.final"))
264-
242+
training_loop(cfg=self.model)
243+
244+
final_variables = self.get_path("variables.data.final")
245+
log("Saving final variables in {}".format(final_variables))
246+
self.model.tf_manager.save(final_variables)
247+
248+
if self.model.test_datasets:
249+
self.model.tf_manager.restore_best_vars()
250+
251+
for dataset in self.model.test_datasets:
252+
test_results, test_outputs = run_on_dataset(
253+
self.model.tf_manager,
254+
self.model.runners,
255+
dataset,
256+
self.model.postprocess,
257+
write_out=True,
258+
batching_scheme=self.model.runners_batching_scheme)
259+
260+
# ensure test outputs are iterable more than once
261+
test_outputs = {
262+
k: list(v) for k, v in test_outputs.items()}
263+
eval_result = evaluation(
264+
self.model.evaluation, dataset,
265+
self.model.runners, test_results, test_outputs)
266+
267+
print_final_evaluation(dataset.name, eval_result)
268+
269+
log("Finished.")
265270
self._vars_loaded = True
266271

267272
def load_variables(self, variable_files: List[str] = None) -> None:

0 commit comments

Comments
 (0)