Skip to content

Commit e3f0f68

Browse files
committed
Fixing temporary bug with inputs and bucketing
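When series are bucketed (i.e., batches() does not return data in the same order as get_series()), the inputs were returned in the wrong order. The inputs are now accumulated batch by batch while the dataset is processed, so they follow the same order as the batches.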
1 parent 1f58c92 commit e3f0f68

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

neuralmonkey/learning_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,8 @@ def run_on_dataset(tf_manager: TensorFlowManager,
317317
feedables = set.union(*[runner.feedables for runner in runners])
318318
feedables |= dataset_runner.feedables
319319

320+
fetched_input = {s: [] for s in dataset.series} # type: Dict[str, List]
321+
320322
processed_examples = 0
321323
for batch in dataset.batches():
322324
if 0 < log_progress < time.process_time() - last_log_time:
@@ -335,6 +337,9 @@ def run_on_dataset(tf_manager: TensorFlowManager,
335337
for script_list, ex_result in zip(batch_results, execution_results):
336338
script_list.append(ex_result)
337339

340+
for s_id in batch.series:
341+
fetched_input[s_id].extend(batch.get_series(s_id))
342+
338343
# Transpose runner interim results.
339344
all_results = [join_execution_results(res) for res in batch_results[:-1]]
340345

@@ -343,7 +348,6 @@ def run_on_dataset(tf_manager: TensorFlowManager,
343348
# fetched_input = {
344349
# k: [dic[k] for dic in input_transposed] for k in input_transposed[0]}
345350

346-
fetched_input = {s: list(dataset.get_series(s)) for s in dataset.series}
347351
fetched_input_lengths = {s: len(fetched_input[s]) for s in dataset.series}
348352

349353
if len(set(fetched_input_lengths.values())) != 1:

0 commit comments

Comments
 (0)