25
25
print_final_evaluation )
26
26
from neuralmonkey .model .sequence import EmbeddedFactorSequence
27
27
from neuralmonkey .runners .base_runner import ExecutionResult
28
+ from neuralmonkey .runners .dataset_runner import DatasetRunner
28
29
29
30
30
31
_TRAIN_ARGS = [
@@ -160,6 +161,8 @@ def register_inputs(self) -> None:
160
161
for feedable in feedables :
161
162
feedable .register_input ()
162
163
164
+ self .model .dataset_runner .register_input ()
165
+
163
166
def build_model (self ) -> None :
164
167
"""Build the configuration and the computational graph.
165
168
@@ -197,6 +200,9 @@ def build_model(self) -> None:
197
200
self ._model = self .config .model
198
201
self ._model_built = True
199
202
203
+ # prepare dataset runner
204
+ self .model .dataset_runner = DatasetRunner ()
205
+
200
206
# build dataset
201
207
self .register_inputs ()
202
208
@@ -257,27 +263,6 @@ def train(self) -> None:
257
263
log ("Saving final variables in {}" .format (final_variables ))
258
264
self .model .tf_manager .save (final_variables )
259
265
260
- if self .model .test_datasets :
261
- self .model .tf_manager .restore_best_vars ()
262
-
263
- for dataset in self .model .test_datasets :
264
- test_results , test_outputs = run_on_dataset (
265
- self .model .tf_manager ,
266
- self .model .runners ,
267
- dataset ,
268
- self .model .postprocess ,
269
- write_out = True ,
270
- batching_scheme = self .model .runners_batching_scheme )
271
-
272
- # ensure test outputs are iterable more than once
273
- test_outputs = {
274
- k : list (v ) for k , v in test_outputs .items ()}
275
- eval_result = evaluation (
276
- self .model .evaluation , dataset ,
277
- self .model .runners , test_results , test_outputs )
278
-
279
- print_final_evaluation (dataset .name , eval_result )
280
-
281
266
log ("Finished." )
282
267
self ._vars_loaded = True
283
268
@@ -321,8 +306,8 @@ def run_model(self,
321
306
dataset : Dataset ,
322
307
write_out : bool = False ,
323
308
batch_size : int = None ,
324
- log_progress : int = 0 ) -> Tuple [List [ ExecutionResult ],
325
- Dict [str , List [ Any ] ]]:
309
+ log_progress : int = 0 ) -> Tuple [
310
+ List [ ExecutionResult ], Dict [ str , List ], Dict [str , List ]]:
326
311
"""Run the model on a given dataset.
327
312
328
313
Args:
@@ -352,16 +337,21 @@ def run_model(self,
352
337
with self .graph .as_default ():
353
338
# TODO: check_dataset_and_coders(dataset, self.model.runners)
354
339
return run_on_dataset (
355
- self .model .tf_manager , self .model .runners , dataset ,
340
+ self .model .tf_manager ,
341
+ self .model .runners ,
342
+ self .model .dataset_runner ,
343
+ dataset ,
356
344
self .model .postprocess ,
357
- write_out = write_out , log_progress = log_progress ,
345
+ write_out = write_out ,
346
+ log_progress = log_progress ,
358
347
batching_scheme = batching_scheme )
359
348
360
349
def evaluate (self ,
361
350
dataset : Dataset ,
362
351
write_out : bool = False ,
363
352
batch_size : int = None ,
364
- log_progress : int = 0 ) -> Dict [str , Any ]:
353
+ log_progress : int = 0 ,
354
+ name : str = None ) -> Dict [str , Any ]:
365
355
"""Run the model on a given dataset and evaluate the outputs.
366
356
367
357
Args:
@@ -370,23 +360,24 @@ def evaluate(self,
370
360
defined in the dataset object.
371
361
batch_size: size of the minibatch
372
362
log_progress: log progress every X seconds
363
+ name: The name of the evaluated dataset
373
364
374
365
Returns:
375
366
Dictionary of evaluation names and their values which includes the
376
367
metrics applied on respective series loss and loss values from the
377
368
run.
378
369
"""
379
- execution_results , output_data = self .run_model (
370
+ execution_results , output_data , f_dataset = self .run_model (
380
371
dataset , write_out , batch_size , log_progress )
381
372
382
373
evaluators = [(e [0 ], e [0 ], e [1 ]) if len (e ) == 2 else e
383
374
for e in self .model .evaluation ]
384
375
with self .graph .as_default ():
385
376
eval_result = evaluation (
386
- evaluators , dataset , self .model .runners ,
377
+ evaluators , f_dataset , self .model .runners ,
387
378
execution_results , output_data )
388
379
if eval_result :
389
- print_final_evaluation (dataset . name , eval_result )
380
+ print_final_evaluation (eval_result , name )
390
381
391
382
return eval_result
392
383
0 commit comments