Skip to content

Commit 3b876eb

Browse files
committed
Encapsulate time-related operations into training profiler
1 parent b48e287 commit 3b876eb

File tree

2 files changed

+101
-25
lines changed

2 files changed

+101
-25
lines changed

neuralmonkey/learning_utils.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
from termcolor import colored
1313
from typeguard import check_argument_types
1414

15-
from neuralmonkey.logging import log, log_print, warn, notice
15+
from neuralmonkey.logging import log, log_print, warn
1616
from neuralmonkey.dataset import Dataset, BatchingScheme
1717
from neuralmonkey.tf_manager import TensorFlowManager
1818
from neuralmonkey.runners.base_runner import (
1919
BaseRunner, ExecutionResult, reduce_execution_results)
2020
from neuralmonkey.trainers.generic_trainer import GenericTrainer
2121
from neuralmonkey.trainers.multitask_trainer import MultitaskTrainer
2222
from neuralmonkey.trainers.delayed_update_trainer import DelayedUpdateTrainer
23+
from neuralmonkey.training_profiler import TrainingProfiler
2324

2425
# pylint: disable=invalid-name
2526
Evaluation = Dict[str, float]
@@ -133,14 +134,12 @@ def training_loop(tf_manager: TensorFlowManager,
133134
log("TensorBoard writer initialized.")
134135

135136
log("Starting training")
136-
last_log_time = time.process_time()
137-
last_val_time = time.process_time()
137+
profiler = TrainingProfiler()
138+
profiler.training_start()
139+
138140
interrupt = None
139141
try:
140142
for epoch_n in range(1, epochs + 1):
141-
log_print("")
142-
log("Epoch {} begins".format(epoch_n), color="red")
143-
144143
train_batches = train_dataset.batches(batching_scheme)
145144

146145
if epoch_n == 1 and train_start_offset:
@@ -150,11 +149,15 @@ def training_loop(tf_manager: TensorFlowManager,
150149
else:
151150
_skip_lines(train_start_offset, train_batches)
152151

152+
log_print("")
153+
log("Epoch {} begins".format(epoch_n), color="red")
154+
profiler.epoch_start()
155+
153156
for batch_n, batch in enumerate(train_batches):
154157
step += 1
155158
seen_instances += len(batch)
156159

157-
if log_timer(step, last_log_time):
160+
if log_timer(step, profiler.last_log_time):
158161
trainer_result = tf_manager.execute(
159162
batch, feedables, trainers, train=True, summaries=True)
160163
train_results, train_outputs = run_on_dataset(
@@ -172,14 +175,18 @@ def training_loop(tf_manager: TensorFlowManager,
172175
tb_writer, main_metric, train_evaluation,
173176
seen_instances, epoch_n, epochs, trainer_result,
174177
train=True)
175-
last_log_time = time.process_time()
178+
179+
profiler.log_done()
180+
176181
else:
177182
tf_manager.execute(batch, feedables, trainers, train=True,
178183
summaries=False)
179184

180-
if val_timer(step, last_val_time):
185+
if val_timer(step, profiler.last_val_time):
186+
181187
log_print("")
182-
val_duration_start = time.process_time()
188+
profiler.validation_start()
189+
183190
val_examples = 0
184191
for val_id, valset in enumerate(val_datasets):
185192
val_examples += len(valset)
@@ -243,24 +250,12 @@ def training_loop(tf_manager: TensorFlowManager,
243250
seen_instances, epoch_n, epochs, val_results,
244251
train=False, dataset_name=v_name)
245252

246-
# how long was the training between validations
247-
training_duration = val_duration_start - last_val_time
248-
val_duration = time.process_time() - val_duration_start
249-
250-
# the training should take at least twice the time of val.
251-
steptime = (training_duration
252-
/ (seen_instances - last_seen_instances))
253-
valtime = val_duration / val_examples
253+
profiler.validation_done()
254+
profiler.log_after_validation(
255+
val_examples, seen_instances - last_seen_instances)
254256
last_seen_instances = seen_instances
255-
log("Validation time: {:.2f}s, inter-validation: {:.2f}s, "
256-
"per-instance (train): {:.2f}s, per-instance (val): "
257-
"{:.2f}s".format(val_duration, training_duration,
258-
steptime, valtime), color="blue")
259-
if training_duration < 2 * val_duration:
260-
notice("Validation period setting is inefficient.")
261257

262258
log_print("")
263-
last_val_time = time.process_time()
264259

265260
except KeyboardInterrupt as ex:
266261
interrupt = ex

neuralmonkey/training_profiler.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# pylint: disable=unused-import
2+
from typing import List, Optional
3+
# pylint: enable=unused-import
4+
import time
5+
6+
from neuralmonkey.logging import log, notice
7+
8+
9+
class TrainingProfiler:
    """Keep track of the timing of the training loop.

    The profiler stores the time at which the training and every epoch
    started, the time of the last finished log and validation, the duration
    of every validation and the duration of every stretch of training
    between two validations.

    All timestamps are read from ``self.time``, which is set to
    ``time.process_time`` in the constructor, so the recorded durations are
    differences of that clock's readings.
    """

    def __init__(self) -> None:
        """Create a profiler with no recorded events."""
        self._start_time = None  # type: Optional[float]
        self._epoch_starts = []  # type: List[float]

        self._last_val_time = None  # type: Optional[float]
        self._last_log_time = None  # type: Optional[float]
        # Set only while a validation is running (between
        # `validation_start` and `validation_done`).
        self._current_validation_start = None  # type: Optional[float]

        # One item per validation: the time between the end of the previous
        # validation (or the training start) and the start of this one.
        self.inter_val_times = []  # type: List[float]
        # One item per finished validation: how long it took.
        self.validation_times = []  # type: List[float]

        # Clock used for all measurements.
        self.time = time.process_time

    @property
    def start_time(self) -> float:
        """Return the time recorded by `training_start`.

        Raises:
            RuntimeError: If `training_start` has not been called yet.
        """
        if self._start_time is None:
            raise RuntimeError("Training did not start yet")
        return self._start_time

    @property
    def last_log_time(self) -> float:
        """Return the time of the last `log_done` call.

        Falls back to the training start time when nothing has been logged
        yet.
        """
        if self._last_log_time is None:
            return self.start_time
        return self._last_log_time

    @property
    def last_val_time(self) -> float:
        """Return the time at which the last validation finished.

        Falls back to the training start time when no validation has
        finished yet.
        """
        if self._last_val_time is None:
            return self.start_time
        return self._last_val_time

    def training_start(self) -> None:
        """Record the current time as the start of the training."""
        self._start_time = self.time()

    def epoch_start(self) -> None:
        """Record the current time as the start of a new epoch."""
        self._epoch_starts.append(self.time())

    def log_done(self) -> None:
        """Record the current time as the time of the last log."""
        self._last_log_time = self.time()

    def validation_start(self) -> None:
        """Mark the beginning of a validation.

        Appends the time elapsed since the previous validation (or since
        the training start) to `inter_val_times`.
        """
        assert self._current_validation_start is None
        self._current_validation_start = self.time()
        self.inter_val_times.append(
            self._current_validation_start - self.last_val_time)

    def validation_done(self) -> None:
        """Mark the end of the validation that is currently running.

        Appends the duration of the validation to `validation_times`.
        """
        assert self._current_validation_start is not None
        self._last_val_time = self.time()

        self.validation_times.append(
            self.last_val_time - self._current_validation_start)

        self._current_validation_start = None

    @staticmethod
    def _speed(examples: int, duration: float) -> float:
        """Return `examples / duration`, or NaN when it cannot be computed.

        Two consecutive clock readings can be equal, which makes the
        measured duration zero. The speed is then undefined and NaN is
        returned instead of raising `ZeroDivisionError`.
        """
        if duration <= 0:
            return float("nan")
        return examples / duration

    def log_after_validation(
            self, val_examples: int, train_examples: int) -> None:
        """Log the duration and speed of the last validation and training.

        Uses the most recent items of `validation_times` and
        `inter_val_times`, so it is meant to be called after
        `validation_done`.

        Arguments:
            val_examples: Number of instances processed in the validation.
            train_examples: Number of instances processed in training since
                the previous validation.
        """
        train_duration = self.inter_val_times[-1]
        val_duration = self.validation_times[-1]

        # A zero duration must not abort the training with a division
        # error, the speed is just reported as "nan".
        train_speed = self._speed(train_examples, train_duration)
        val_speed = self._speed(val_examples, val_duration)

        log("Validation time: {:.2f}s ({:.1f} instances/sec), "
            "inter-validation: {:.2f}s, ({:.1f} instances/sec)"
            .format(val_duration, val_speed, train_duration, train_speed),
            color="blue")

        # The training should take at least twice the time of validation.
        if train_duration < 2 * val_duration:
            notice("Validation period setting is inefficient.")

0 commit comments

Comments
 (0)