1 file changed: +5 −1 lines changed

@@ -317,6 +317,8 @@ def run_on_dataset(tf_manager: TensorFlowManager,
     feedables = set.union(*[runner.feedables for runner in runners])
     feedables |= dataset_runner.feedables
 
+    fetched_input = {s: [] for s in dataset.series}  # type: Dict[str, List]
+
     processed_examples = 0
     for batch in dataset.batches():
         if 0 < log_progress < time.process_time() - last_log_time:
@@ -335,6 +337,9 @@ def run_on_dataset(tf_manager: TensorFlowManager,
         for script_list, ex_result in zip(batch_results, execution_results):
             script_list.append(ex_result)
 
+        for s_id in batch.series:
+            fetched_input[s_id].extend(batch.get_series(s_id))
+
     # Transpose runner interim results.
     all_results = [join_execution_results(res) for res in batch_results[:-1]]
 
@@ -343,7 +348,6 @@ def run_on_dataset(tf_manager: TensorFlowManager,
     # fetched_input = {
     #     k: [dic[k] for dic in input_transposed] for k in input_transposed[0]}
 
-    fetched_input = {s: list(dataset.get_series(s)) for s in dataset.series}
     fetched_input_lengths = {s: len(fetched_input[s]) for s in dataset.series}
 
     if len(set(fetched_input_lengths.values())) != 1:
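In effect, the change builds fetched_input incrementally while iterating over the batches instead of re-reading every series from the dataset after the loop. Below is a minimal, self-contained sketch of that accumulation pattern; the series names and the plain dicts standing in for batches are hypothetical, not the project's Dataset API:

    from typing import Dict, List

    series = ["source", "target"]
    batches = [
        {"source": ["a", "b"], "target": ["A", "B"]},
        {"source": ["c"], "target": ["C"]},
    ]

    # One empty list per series, extended batch by batch, so the input is
    # collected in a single pass over the data.
    fetched_input = {s: [] for s in series}  # type: Dict[str, List]
    for batch in batches:
        for s_id in series:
            fetched_input[s_id].extend(batch[s_id])

    # All series should end up with the same length, mirroring the check
    # that run_on_dataset performs on fetched_input_lengths.
    fetched_input_lengths = {s: len(fetched_input[s]) for s in series}
    assert len(set(fetched_input_lengths.values())) == 1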