Optionally aggregate metrics for custom step method models in Trainer.evaluate() #21314

Closed · wants to merge 2 commits
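
For context, a minimal usage sketch of the proposed flag. Everything here except the `aggregate` argument (the model, the custom `test_step`, and the data) is an illustrative assumption, not part of this PR, and the sketch assumes the PR is applied. With `aggregate=True`, per-batch logs returned by a custom step method are averaged over all evaluation steps; with the default `aggregate=False`, behavior is unchanged and only the last batch's logs are returned.

```python
import numpy as np
import keras

class MyModel(keras.Model):
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        # Return a raw per-batch value instead of updating stateful metrics.
        return {"batch_mse": keras.ops.mean((y - y_pred) ** 2)}

inputs = keras.Input(shape=(4,))
model = MyModel(inputs=inputs, outputs=keras.layers.Dense(1)(inputs))
model.compile()  # evaluate() asserts that compile() was called

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

# With this PR, "batch_mse" is the mean over all 4 steps rather than the
# value from the final batch only.
print(model.evaluate(x, y, batch_size=16, aggregate=True, return_dict=True))
```
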
31 changes: 29 additions & 2 deletions keras/src/backend/jax/trainer.py

@@ -523,6 +523,7 @@ def evaluate(
         steps=None,
         callbacks=None,
         return_dict=False,
+        aggregate=False,
         **kwargs,
     ):
         self._assert_compile_called("evaluate")
@@ -565,13 +566,31 @@ def evaluate(
         self.stop_evaluating = False
         callbacks.on_test_begin()
         logs = {}
+        total_steps = 0
         self.reset_metrics()
 
+        def _aggregate_fn(_logs, _step_logs):
+            if not _logs:
+                return _step_logs
+
+            return tree.map_structure(backend.numpy.add, _logs, _step_logs)
+
+        def _reduce_fn(_logs, _total_steps):
+            if _total_steps == 0:
+                return _logs
+
+            def _div(val):
+                return val / _total_steps
+
+            return tree.map_structure(_div, _logs)
+
         self._jax_state_synced = True
         with epoch_iterator.catch_stop_iteration():
             for step, iterator in epoch_iterator:
                 callbacks.on_test_batch_begin(step)
+
+                total_steps += 1
 
                 if self._jax_state_synced:
                     # The state may have been synced by a callback.
                     state = self._get_jax_state(
@@ -582,13 +601,18 @@
                 )
                 self._jax_state_synced = False
 
-                logs, state = self.test_function(state, iterator)
+                step_logs, state = self.test_function(state, iterator)
                 (
                     trainable_variables,
                     non_trainable_variables,
                     metrics_variables,
                 ) = state
 
+                if aggregate:
+                    logs = _aggregate_fn(logs, step_logs)
+                else:
+                    logs = step_logs
+
                 # Setting _jax_state enables callbacks to force a state sync
                 # if they need to.
                 self._jax_state = {
@@ -600,11 +624,14 @@
                 }
 
                 # Dispatch callbacks. This takes care of async dispatch.
-                callbacks.on_test_batch_end(step, logs)
+                callbacks.on_test_batch_end(step, step_logs)
 
                 if self.stop_evaluating:
                     break
 
+        if aggregate:
+            logs = _reduce_fn(logs, total_steps)
+
         # Reattach state back to model (if not already done by a callback).
         self.jax_state_sync()
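
As a sanity check on the JAX implementation above: `_aggregate_fn` keeps a running element-wise sum of the per-step log structures, and `_reduce_fn` divides by the step count, so the pair computes a plain arithmetic mean over steps. A standalone sketch, assuming the public `keras.tree` API and `numpy.add` in place of the internal `tree` and `backend.numpy` modules used in the diff:

```python
import numpy as np
from keras import tree

def aggregate_fn(logs, step_logs):
    if not logs:
        return step_logs  # first step: adopt the step logs as the running sum
    return tree.map_structure(np.add, logs, step_logs)

def reduce_fn(logs, total_steps):
    if total_steps == 0:
        return logs  # guard against an empty evaluation run
    return tree.map_structure(lambda v: v / total_steps, logs)

logs = {}
for step_logs in ({"loss": 1.0}, {"loss": 3.0}, {"loss": 5.0}):
    logs = aggregate_fn(logs, step_logs)
print(reduce_fn(logs, 3))  # {'loss': 3.0}
```
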
35 changes: 33 additions & 2 deletions keras/src/backend/tensorflow/trainer.py

@@ -5,6 +5,7 @@
 import tensorflow as tf
 from tensorflow.python.eager import context as tf_context
 
+from keras.src import backend
 from keras.src import callbacks as callbacks_module
 from keras.src import metrics as metrics_module
 from keras.src import optimizers as optimizers_module
@@ -441,6 +442,7 @@ def evaluate(
         steps=None,
         callbacks=None,
         return_dict=False,
+        aggregate=False,
         **kwargs,
     ):
         self._assert_compile_called("evaluate")
@@ -482,14 +484,43 @@
         self.stop_evaluating = False
         callbacks.on_test_begin()
         logs = {}
+        total_steps = 0
         self.reset_metrics()
 
+        def _aggregate_fn(_logs, _step_logs):
+            if not _logs:
+                return _step_logs
+
+            return tree.map_structure(backend.numpy.add, _logs, _step_logs)
+
+        def _reduce_fn(_logs, _total_steps):
+            if _total_steps == 0:
+                return _logs
+
+            def _div(val):
+                return val / _total_steps
+
+            return tree.map_structure(_div, _logs)
+
         with epoch_iterator.catch_stop_iteration():
             for step, iterator in epoch_iterator:
                 callbacks.on_test_batch_begin(step)
-                logs = self.test_function(iterator)
-                callbacks.on_test_batch_end(step, logs)
+                total_steps += 1
+
+                step_logs = self.test_function(iterator)
+
+                if aggregate:
+                    logs = _aggregate_fn(logs, step_logs)
+                else:
+                    logs = step_logs
+
+                callbacks.on_test_batch_end(step, step_logs)
                 if self.stop_evaluating:
                     break
 
+        if aggregate:
+            logs = _reduce_fn(logs, total_steps)
+
         logs = self._get_metrics_result_or_logs(logs)
         callbacks.on_test_end(logs)
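
A behavioral detail shared by all three backends: `on_test_batch_end` now receives `step_logs` rather than the running `logs`, so callbacks continue to see raw per-batch values even when `aggregate=True`; only the value returned by `evaluate()` is aggregated. An illustrative callback (not part of this PR) that relies on that:

```python
import keras

class BatchLogPrinter(keras.callbacks.Callback):
    def on_test_batch_end(self, batch, logs=None):
        # Per-batch values, even when evaluate(..., aggregate=True) is used.
        print(f"test batch {batch}: {logs}")

# Hypothetical usage with the model from the earlier sketch:
# model.evaluate(x, y, aggregate=True, callbacks=[BatchLogPrinter()])
```
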
33 changes: 31 additions & 2 deletions keras/src/backend/torch/trainer.py

@@ -331,6 +331,7 @@ def evaluate(
         steps=None,
         callbacks=None,
         return_dict=False,
+        aggregate=False,
         **kwargs,
     ):
         # TODO: respect compiled trainable state
@@ -373,13 +374,41 @@
         self.stop_evaluating = False
         callbacks.on_test_begin()
         logs = {}
+        total_steps = 0
         self.reset_metrics()
 
+        def _aggregate_fn(_logs, _step_logs):
+            if not _logs:
+                return _step_logs
+
+            return tree.map_structure(backend.numpy.add, _logs, _step_logs)
+
+        def _reduce_fn(_logs, _total_steps):
+            if _total_steps == 0:
+                return _logs
+
+            def _div(val):
+                return val / _total_steps
+
+            return tree.map_structure(_div, _logs)
+
         for step, data in epoch_iterator:
             callbacks.on_test_batch_begin(step)
-            logs = self.test_function(data)
-            callbacks.on_test_batch_end(step, logs)
+            total_steps += 1
+            step_logs = self.test_function(data)
+
+            if aggregate:
+                logs = _aggregate_fn(logs, step_logs)
+            else:
+                logs = step_logs
+
+            callbacks.on_test_batch_end(step, step_logs)
             if self.stop_evaluating:
                 break
 
+        if aggregate:
+            logs = _reduce_fn(logs, total_steps)
+
         logs = self._get_metrics_result_or_logs(logs)
         callbacks.on_test_end(logs)
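
One semantic note on the reduction: `_reduce_fn` divides by the number of steps, so the result is an unweighted mean over batches. When the final batch is smaller than the rest, this can differ from a per-sample mean. A toy illustration of the gap:

```python
# Illustration only: step mean vs. per-sample mean with a ragged final batch.
batch_means = [1.0, 1.0, 3.0]  # per-batch metric values
batch_sizes = [32, 32, 8]      # the last batch is smaller

step_mean = sum(batch_means) / len(batch_means)  # what _reduce_fn yields: ~1.67
sample_mean = sum(
    m * n for m, n in zip(batch_means, batch_sizes)
) / sum(batch_sizes)  # ~1.22
print(step_mean, sample_mean)
```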