Skip to content

Commit df58ec9

Browse files
authored
Progress bar now handles steps_per_execution. (#21422)
Previously, the progress bar always reported the starting batch index + 1 at the end of each batch. It now takes `steps_per_execution` into account, so the last batch of each execution is reported correctly. Fixes #20861
1 parent 3a11132 commit df58ec9

File tree

10 files changed

+80
-78
lines changed

10 files changed

+80
-78
lines changed

keras/src/backend/jax/trainer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -408,9 +408,9 @@ def fit(
408408

409409
self._jax_state_synced = True
410410
with epoch_iterator.catch_stop_iteration():
411-
for step, iterator in epoch_iterator:
411+
for begin_step, end_step, iterator in epoch_iterator:
412412
# Callbacks
413-
callbacks.on_train_batch_begin(step)
413+
callbacks.on_train_batch_begin(begin_step)
414414

415415
# Train step
416416
if self._jax_state_synced:
@@ -441,7 +441,7 @@ def fit(
441441
"metrics_variables": metrics_variables,
442442
}
443443
# Dispatch callbacks. This takes care of async dispatch.
444-
callbacks.on_train_batch_end(step, logs)
444+
callbacks.on_train_batch_end(end_step, logs)
445445

446446
if self.stop_training:
447447
# Stop training if a callback has set
@@ -569,8 +569,8 @@ def evaluate(
569569

570570
self._jax_state_synced = True
571571
with epoch_iterator.catch_stop_iteration():
572-
for step, iterator in epoch_iterator:
573-
callbacks.on_test_batch_begin(step)
572+
for begin_step, end_step, iterator in epoch_iterator:
573+
callbacks.on_test_batch_begin(begin_step)
574574

575575
if self._jax_state_synced:
576576
# The state may have been synced by a callback.
@@ -600,7 +600,7 @@ def evaluate(
600600
}
601601

602602
# Dispatch callbacks. This takes care of async dispatch.
603-
callbacks.on_test_batch_end(step, logs)
603+
callbacks.on_test_batch_end(end_step, logs)
604604

605605
if self.stop_evaluating:
606606
break
@@ -633,7 +633,7 @@ def predict(
633633

634634
if not all(layer.built for layer in self._flatten_layers()):
635635
# Build the model on one batch of data.
636-
for _, iterator in epoch_iterator:
636+
for _, _, iterator in epoch_iterator:
637637
# Build model
638638
x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(
639639
next(iterator)
@@ -677,8 +677,8 @@ def append_to_outputs(batch_outputs, outputs):
677677
outputs = None
678678
non_trainable_variables = None
679679
with epoch_iterator.catch_stop_iteration():
680-
for step, iterator in epoch_iterator:
681-
callbacks.on_predict_batch_begin(step)
680+
for begin_step, end_step, iterator in epoch_iterator:
681+
callbacks.on_predict_batch_begin(begin_step)
682682
if self._jax_state_synced:
683683
# The state may have been synced by a callback.
684684
state = self._get_jax_state(
@@ -701,7 +701,9 @@ def append_to_outputs(batch_outputs, outputs):
701701
outputs = append_to_outputs(batch_outputs, outputs)
702702

703703
# Dispatch callbacks. This takes care of async dispatch.
704-
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
704+
callbacks.on_predict_batch_end(
705+
end_step, {"outputs": batch_outputs}
706+
)
705707

706708
if self.stop_predicting:
707709
break

keras/src/backend/numpy/trainer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,11 @@ def append_to_outputs(batch_outputs, outputs):
211211
self.stop_predicting = False
212212
callbacks.on_predict_begin()
213213
outputs = None
214-
for step, data in epoch_iterator:
215-
callbacks.on_predict_batch_begin(step)
214+
for begin_step, end_step, data in epoch_iterator:
215+
callbacks.on_predict_batch_begin(begin_step)
216216
batch_outputs = self.predict_function(data)
217217
outputs = append_to_outputs(batch_outputs, outputs)
218-
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
218+
callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
219219
if self.stop_predicting:
220220
break
221221
callbacks.on_predict_end()
@@ -255,7 +255,7 @@ def evaluate(
255255

256256
if not all(layer.built for layer in self._flatten_layers()):
257257
# Build the model on one batch of data.
258-
for _, data in epoch_iterator:
258+
for _, _, data in epoch_iterator:
259259
data_batch = data[0]
260260
self._symbolic_build(data_batch)
261261
break
@@ -276,10 +276,10 @@ def evaluate(
276276
callbacks.on_test_begin()
277277
logs = {}
278278
self.reset_metrics()
279-
for step, data in epoch_iterator:
280-
callbacks.on_test_batch_begin(step)
279+
for begin_step, end_step, data in epoch_iterator:
280+
callbacks.on_test_batch_begin(begin_step)
281281
logs = self.test_function(data)
282-
callbacks.on_test_batch_end(step, logs)
282+
callbacks.on_test_batch_end(end_step, logs)
283283
if self.stop_evaluating:
284284
break
285285
logs = self._get_metrics_result_or_logs(logs)

keras/src/backend/openvino/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,11 @@ def append_to_outputs(batch_outputs, outputs):
213213
self.stop_predicting = False
214214
callbacks.on_predict_begin()
215215
outputs = None
216-
for step, data in epoch_iterator.enumerate_epoch():
217-
callbacks.on_predict_batch_begin(step)
216+
for begin_step, end_step, data in epoch_iterator.enumerate_epoch():
217+
callbacks.on_predict_batch_begin(begin_step)
218218
batch_outputs = self.predict_function(data)
219219
outputs = append_to_outputs(batch_outputs, outputs)
220-
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
220+
callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
221221
if self.stop_predicting:
222222
break
223223
callbacks.on_predict_end()

keras/src/backend/tensorflow/distribute_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_epoch_iterator(self):
104104
distribute_strategy=strategy,
105105
)
106106
steps_seen = []
107-
for step, data_iterator in epoch_iterator:
107+
for step, _, data_iterator in epoch_iterator:
108108
steps_seen.append(step)
109109
batch = next(data_iterator)
110110
self.assertEqual(len(batch), 3)

keras/src/backend/tensorflow/trainer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,10 @@ def fit(
372372
self.reset_metrics()
373373
callbacks.on_epoch_begin(epoch)
374374
with epoch_iterator.catch_stop_iteration():
375-
for step, iterator in epoch_iterator:
376-
callbacks.on_train_batch_begin(step)
375+
for begin_step, end_step, iterator in epoch_iterator:
376+
callbacks.on_train_batch_begin(begin_step)
377377
logs = self.train_function(iterator)
378-
callbacks.on_train_batch_end(step, logs)
378+
callbacks.on_train_batch_end(end_step, logs)
379379
if self.stop_training:
380380
break
381381

@@ -484,10 +484,10 @@ def evaluate(
484484
logs = {}
485485
self.reset_metrics()
486486
with epoch_iterator.catch_stop_iteration():
487-
for step, iterator in epoch_iterator:
488-
callbacks.on_test_batch_begin(step)
487+
for begin_step, end_step, iterator in epoch_iterator:
488+
callbacks.on_test_batch_begin(begin_step)
489489
logs = self.test_function(iterator)
490-
callbacks.on_test_batch_end(step, logs)
490+
callbacks.on_test_batch_end(end_step, logs)
491491
if self.stop_evaluating:
492492
break
493493
logs = self._get_metrics_result_or_logs(logs)
@@ -560,12 +560,14 @@ def get_data(iterator):
560560
callbacks.on_predict_begin()
561561
outputs = None
562562
with epoch_iterator.catch_stop_iteration():
563-
for step, iterator in epoch_iterator:
564-
callbacks.on_predict_batch_begin(step)
563+
for begin_step, end_step, iterator in epoch_iterator:
564+
callbacks.on_predict_batch_begin(begin_step)
565565
data = get_data(iterator)
566566
batch_outputs = self.predict_function(data)
567567
outputs = append_to_outputs(batch_outputs, outputs)
568-
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
568+
callbacks.on_predict_batch_end(
569+
end_step, {"outputs": batch_outputs}
570+
)
569571
if self.stop_predicting:
570572
break
571573
callbacks.on_predict_end()
@@ -696,7 +698,7 @@ def _maybe_symbolic_build(self, iterator=None, data_batch=None):
696698
# Unlike jax/torch iterator, tf iterator returns an iterator instead
697699
# of data batch in `iterator`.
698700
if iterator is not None:
699-
for _, it in iterator:
701+
for _, _, it in iterator:
700702
maybe_distributed_data_batch = next(it)
701703
has_distributed_values = tree.map_structure(
702704
lambda x: isinstance(x, tf.distribute.DistributedValues),

keras/src/backend/torch/trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,14 @@ def fit(
256256
self.train()
257257

258258
logs = {}
259-
for step, data in epoch_iterator:
259+
for begin_step, end_step, data in epoch_iterator:
260260
# Callbacks
261-
callbacks.on_train_batch_begin(step)
261+
callbacks.on_train_batch_begin(begin_step)
262262

263263
logs = self.train_function(data)
264264

265265
# Callbacks
266-
callbacks.on_train_batch_end(step, logs)
266+
callbacks.on_train_batch_end(end_step, logs)
267267
if self.stop_training:
268268
break
269269

@@ -374,10 +374,10 @@ def evaluate(
374374
callbacks.on_test_begin()
375375
logs = {}
376376
self.reset_metrics()
377-
for step, data in epoch_iterator:
378-
callbacks.on_test_batch_begin(step)
377+
for begin_step, end_step, data in epoch_iterator:
378+
callbacks.on_test_batch_begin(begin_step)
379379
logs = self.test_function(data)
380-
callbacks.on_test_batch_end(step, logs)
380+
callbacks.on_test_batch_end(end_step, logs)
381381
if self.stop_evaluating:
382382
break
383383
logs = self._get_metrics_result_or_logs(logs)
@@ -433,11 +433,11 @@ def append_to_outputs(batch_outputs, outputs):
433433
self.stop_predicting = False
434434
callbacks.on_predict_begin()
435435
outputs = None
436-
for step, data in epoch_iterator:
437-
callbacks.on_predict_batch_begin(step)
436+
for begin_step, end_step, data in epoch_iterator:
437+
callbacks.on_predict_batch_begin(begin_step)
438438
batch_outputs = self.predict_function(data)
439439
outputs = append_to_outputs(batch_outputs, outputs)
440-
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
440+
callbacks.on_predict_batch_end(end_step, {"outputs": batch_outputs})
441441
if self.stop_predicting:
442442
break
443443
callbacks.on_predict_end()

keras/src/trainers/epoch_iterator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,11 @@ def _enumerate_iterator(self):
116116
self._interrupted_warning()
117117
break
118118
self._steps_seen += self.steps_per_execution
119-
yield step, self._current_iterator
119+
yield (
120+
step,
121+
step + self.steps_per_execution - 1,
122+
self._current_iterator,
123+
)
120124
if self._num_batches and self._steps_seen >= self._num_batches:
121125
self._current_iterator = iter(self._get_iterator())
122126
self._steps_seen = 0
@@ -126,7 +130,7 @@ def _enumerate_iterator(self):
126130
while True:
127131
step += self.steps_per_execution
128132
self._steps_seen = step + self.steps_per_execution
129-
yield step, iterator
133+
yield step, step + self.steps_per_execution - 1, iterator
130134
self.data_adapter.on_epoch_end()
131135

132136
def __iter__(self):
@@ -135,19 +139,19 @@ def __iter__(self):
135139

136140
def __next__(self):
137141
buffer = []
138-
step, iterator = next(self._epoch_iterator)
142+
begin_step, end_step, iterator = next(self._epoch_iterator)
139143
with self.catch_stop_iteration():
140144
for _ in range(self.steps_per_execution):
141145
data = next(iterator)
142146
buffer.append(data)
143-
return step, buffer
147+
return begin_step, end_step, buffer
144148
if buffer:
145-
return step, buffer
149+
return begin_step, end_step, buffer
146150
raise StopIteration
147151

148152
def enumerate_epoch(self):
149-
for step, data in self:
150-
yield step, data
153+
for begin_step, end_step, data in self:
154+
yield begin_step, end_step, data
151155

152156
@contextlib.contextmanager
153157
def catch_stop_iteration(self):

keras/src/trainers/epoch_iterator_test.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def test_basic_flow(self, call_type):
3131
generator = iterator
3232
else:
3333
generator = iterator.enumerate_epoch()
34-
for step, batch in generator:
34+
for begin_step, end_step, batch in generator:
3535
batch = batch[0]
36-
steps_seen.append(step)
36+
steps_seen.append(begin_step)
37+
self.assertEqual(begin_step, end_step)
3738
self.assertEqual(len(batch), 3)
3839
self.assertIsInstance(batch[0], np.ndarray)
3940
self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])
@@ -52,7 +53,7 @@ def test_insufficient_data(self):
5253
)
5354
steps_seen = []
5455
with pytest.warns(match="Your input ran out of data"):
55-
for step, _ in iterator:
56+
for step, _, _ in iterator:
5657
steps_seen.append(step)
5758
self.assertLen(steps_seen, steps_per_epoch - 2)
5859

@@ -96,7 +97,7 @@ def __getitem__(self, idx):
9697
torch_dataset, batch_size=8, shuffle=True
9798
)
9899
iterator = epoch_iterator.EpochIterator(torch_dataloader)
99-
for _, batch in iterator:
100+
for _, _, batch in iterator:
100101
batch = batch[0]
101102
self.assertEqual(batch[0].shape, (8, 2))
102103
self.assertEqual(batch[1].shape, (8, 1))
@@ -226,7 +227,7 @@ def on_epoch_end(self):
226227

227228
num_epochs = 5
228229
for epoch in range(num_epochs):
229-
for step, batch in epoch_iter:
230+
for _, _, _ in epoch_iter:
230231
pass
231232

232233
self.assertAllEqual(ds.tracker, [1, 2] * num_epochs)

keras/src/trainers/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ def to_symbolic_input(v):
10721072
)
10731073

10741074
if data_batch is None:
1075-
for _, data_or_iterator in iterator:
1075+
for _, _, data_or_iterator in iterator:
10761076
if isinstance(data_or_iterator, (list, tuple)):
10771077
data_batch = data_or_iterator[0]
10781078
else:

0 commit comments

Comments
 (0)