
Commit c871d03

Adding dataset runner, which will return input data in the future

Plus, faking this "future" output in run_on_dataset and refactoring all subsequent calls to evaluation-related functions to work with the returned dataset, which is now not a Dataset instance but merely a dict of series.

1 parent f2204cd
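
The gist of the refactor: evaluation-related functions no longer receive a Dataset instance, only a mapping from series names to lists of data. Until the DatasetRunner really fetches the input data inside the TensorFlow session, run_on_dataset fakes this output by materializing the Dataset directly. A minimal sketch of that idea; the helper name is illustrative, not part of the commit:

    from typing import Dict, List

    def fake_input_series(dataset) -> Dict[str, List]:
        # Materialize every input series into a plain list, yielding the
        # dict of series that the evaluation functions now consume.
        # `series` and `get_series` are the Dataset members visible in
        # the dataset.py diff below.
        return {name: list(dataset.get_series(name))
                for name in dataset.series}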

8 files changed: +148 −98 lines changed

neuralmonkey/checking.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def check_dataset_and_coders(dataset: Dataset,
     missing = []
 
     for (serie, coder) in data_list:
-        if not dataset.has_series(serie):
+        if serie not in dataset:
             log("dataset {} does not have serie {}".format(
                 dataset.name, serie))
             missing.append((coder, serie))
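
With Dataset.__contains__ in place (see the dataset.py hunk below), the membership test reads as plain Python. A hedged mini-example of the pattern used in check_dataset_and_coders, assuming dataset, encoder, and decoder objects from an experiment configuration; the (serie, coder) pairs are illustrative:

    data_list = [("source", encoder), ("target", decoder)]

    # Collect the coders whose required series is absent, as the loop above does.
    missing = [(coder, serie) for (serie, coder) in data_list
               if serie not in dataset]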

neuralmonkey/dataset.py

Lines changed: 5 additions & 5 deletions

@@ -467,11 +467,7 @@ def __len__(self) -> int:
         assert self.length is not None
         return self.length
 
-    @property
-    def series(self) -> List[str]:
-        return list(sorted(self.iterators.keys()))
-
-    def has_series(self, name: str) -> bool:
+    def __contains__(self, name: str) -> bool:
         """Check if the dataset contains a series of a given name.
 
         Arguments:
@@ -482,6 +478,10 @@ def has_series(self, name: str) -> bool:
         """
         return name in self.iterators
 
+    @property
+    def series(self) -> List[str]:
+        return list(sorted(self.iterators.keys()))
+
     def get_series(self, name: str) -> Iterator:
         """Get the data series with a given name.
 
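
A quick usage sketch of the reshuffled Dataset API, assuming a dataset built with series named "source" and "target":

    assert "source" in dataset          # new __contains__, replacing has_series
    assert dataset.series == ["source", "target"]   # property kept, still sorted
    src_data = dataset.get_series("source")         # accessor unchanged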

neuralmonkey/experiment.py

Lines changed: 20 additions & 29 deletions

@@ -25,6 +25,7 @@
     print_final_evaluation)
 from neuralmonkey.model.sequence import EmbeddedFactorSequence
 from neuralmonkey.runners.base_runner import ExecutionResult
+from neuralmonkey.runners.dataset_runner import DatasetRunner
 
 
 _TRAIN_ARGS = [
@@ -160,6 +161,8 @@ def register_inputs(self) -> None:
         for feedable in feedables:
             feedable.register_input()
 
+        self.model.dataset_runner.register_input()
+
     def build_model(self) -> None:
         """Build the configuration and the computational graph.
 
@@ -197,6 +200,9 @@ def build_model(self) -> None:
         self._model = self.config.model
         self._model_built = True
 
+        # prepare dataset runner
+        self.model.dataset_runner = DatasetRunner()
+
         # build dataset
         self.register_inputs()
 
@@ -257,27 +263,6 @@ def train(self) -> None:
         log("Saving final variables in {}".format(final_variables))
         self.model.tf_manager.save(final_variables)
 
-        if self.model.test_datasets:
-            self.model.tf_manager.restore_best_vars()
-
-            for dataset in self.model.test_datasets:
-                test_results, test_outputs = run_on_dataset(
-                    self.model.tf_manager,
-                    self.model.runners,
-                    dataset,
-                    self.model.postprocess,
-                    write_out=True,
-                    batching_scheme=self.model.runners_batching_scheme)
-
-                # ensure test outputs are iterable more than once
-                test_outputs = {
-                    k: list(v) for k, v in test_outputs.items()}
-                eval_result = evaluation(
-                    self.model.evaluation, dataset,
-                    self.model.runners, test_results, test_outputs)
-
-                print_final_evaluation(dataset.name, eval_result)
-
         log("Finished.")
         self._vars_loaded = True
 
@@ -321,8 +306,8 @@ def run_model(self,
                   dataset: Dataset,
                   write_out: bool = False,
                   batch_size: int = None,
-                  log_progress: int = 0) -> Tuple[List[ExecutionResult],
-                                                  Dict[str, List[Any]]]:
+                  log_progress: int = 0) -> Tuple[
+                      List[ExecutionResult], Dict[str, List], Dict[str, List]]:
         """Run the model on a given dataset.
 
         Args:
@@ -352,16 +337,21 @@ def run_model(self,
         with self.graph.as_default():
             # TODO: check_dataset_and_coders(dataset, self.model.runners)
             return run_on_dataset(
-                self.model.tf_manager, self.model.runners, dataset,
+                self.model.tf_manager,
+                self.model.runners,
+                self.model.dataset_runner,
+                dataset,
                 self.model.postprocess,
-                write_out=write_out, log_progress=log_progress,
+                write_out=write_out,
+                log_progress=log_progress,
                 batching_scheme=batching_scheme)
 
     def evaluate(self,
                  dataset: Dataset,
                  write_out: bool = False,
                  batch_size: int = None,
-                 log_progress: int = 0) -> Dict[str, Any]:
+                 log_progress: int = 0,
+                 name: str = None) -> Dict[str, Any]:
         """Run the model on a given dataset and evaluate the outputs.
 
         Args:
@@ -370,23 +360,24 @@
                 defined in the dataset object.
             batch_size: size of the minibatch
             log_progress: log progress every X seconds
+            name: The name of the evaluated dataset
 
         Returns:
             Dictionary of evaluation names and their values which includes the
             metrics applied on respective series loss and loss values from the
             run.
         """
-        execution_results, output_data = self.run_model(
+        execution_results, output_data, f_dataset = self.run_model(
             dataset, write_out, batch_size, log_progress)
 
         evaluators = [(e[0], e[0], e[1]) if len(e) == 2 else e
                       for e in self.model.evaluation]
         with self.graph.as_default():
             eval_result = evaluation(
-                evaluators, dataset, self.model.runners,
+                evaluators, f_dataset, self.model.runners,
                 execution_results, output_data)
         if eval_result:
-            print_final_evaluation(dataset.name, eval_result)
+            print_final_evaluation(eval_result, name)
 
         return eval_result
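
Callers of run_model now unpack a triple, and evaluate takes the dataset name explicitly, since the evaluation functions only see a dict of series and print_final_evaluation can no longer read dataset.name. A hedged sketch of the new calling convention, where exp stands for an assumed Experiment instance:

    # run_model returns the execution results, the model outputs, and the
    # (for now faked) input data as a dict of series.
    results, outputs, f_dataset = exp.run_model(dataset, write_out=False)

    # evaluate forwards the dict of series to evaluation(); `name` replaces
    # the former use of dataset.name in the final report.
    scores = exp.evaluate(dataset, name="test")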
