
Optionally aggregate metrics for custom step method models in Trainer.evaluate() #21314


Closed
wants to merge 2 commits

Conversation

@LarsKue (Contributor) commented May 21, 2025

Currently, models that implement a custom test_step do not get metric aggregation in evaluate(); instead, evaluate() simply returns the metrics from the last step:

# Verify that train / test step logs passed and metric logs have
# matching keys. It could be different when using custom step functions,
# in which case we return the logs from the last step.

This can lead to user confusion, e.g. in bayesflow-org/bayesflow#481.

This PR introduces a very primitive way to aggregate metrics, controllable by a boolean flag on the evaluate method: aggregate, which defaults to False to preserve backward compatibility. The current idea is simply to sum all metrics and divide by the total number of steps taken, which works as long as every step returns a PyTree of numerical (tensor or primitive) metrics whose structure is consistent across steps.

I understand that this code might still be too simple, might not extend to multi-device synchronization, or might be inconsistent with other parts of the library. I am open to modifying the PR based on feedback.
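
To make the intent concrete, here is a rough sketch of the sum-and-divide scheme described above (not the actual diff; `aggregate_logs` and its signature are made up for illustration):

```python
import keras


def aggregate_logs(per_step_logs):
    """Average a list of structurally identical metric PyTrees.

    `per_step_logs` holds one PyTree (e.g. a dict of scalar/tensor metrics)
    per evaluation step.
    """
    num_steps = len(per_step_logs)
    total = per_step_logs[0]
    for logs in per_step_logs[1:]:
        # Sum element-wise across steps, preserving the tree structure.
        total = keras.tree.map_structure(lambda a, b: a + b, total, logs)
    # Divide by the number of steps to obtain the per-metric mean.
    return keras.tree.map_structure(lambda x: x / num_steps, total)
```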

@fchollet (Collaborator) commented

Thanks for the PR. Is there an alternative way to solve this problem? Other backends don't have this issue.

@codecov-commenter commented May 21, 2025

Codecov Report

Attention: Patch coverage is 37.93103% with 36 lines in your changes missing coverage. Please review.

Project coverage is 82.55%. Comparing base (785c9b0) to head (f85f0be).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| keras/src/backend/jax/trainer.py | 36.84% | 10 Missing and 2 partials ⚠️ |
| keras/src/backend/tensorflow/trainer.py | 40.00% | 10 Missing and 2 partials ⚠️ |
| keras/src/backend/torch/trainer.py | 36.84% | 10 Missing and 2 partials ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21314      +/-   ##
==========================================
- Coverage   82.60%   82.55%   -0.05%     
==========================================
  Files         565      565              
  Lines       54773    54825      +52     
  Branches     8508     8520      +12     
==========================================
+ Hits        45244    45260      +16     
- Misses       7439     7469      +30     
- Partials     2090     2096       +6     
| Flag | Coverage Δ |
| --- | --- |
| keras | 82.36% <37.93%> (-0.05%) ⬇️ |
| keras-jax | 63.51% <13.79%> (-0.05%) ⬇️ |
| keras-numpy | 58.65% <1.72%> (-0.06%) ⬇️ |
| keras-openvino | 33.10% <1.72%> (-0.03%) ⬇️ |
| keras-tensorflow | 63.91% <13.79%> (-0.05%) ⬇️ |
| keras-torch | 63.57% <13.79%> (-0.05%) ⬇️ |

Flags with carried forward coverage won't be shown.


@LarsKue (Contributor, Author) commented May 21, 2025

@fchollet Could you elaborate? I think the issue exists independently of the backend.

In terms of options, I believe we could also do something along the lines of tree_concatenate followed by tree_mean, or let the user supply an aggregation function. This requires slightly more memory, but would also allow different steps to return different sets of metrics.
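
Roughly something like this (`tree_concatenate` and `tree_mean` are hypothetical names; the sketch approximates the idea with `keras.tree.map_structure` and `keras.ops`):

```python
import keras


def aggregate_logs(per_step_logs):
    # Stack each metric across steps, then average over the step axis.
    # Keeping all per-step values costs extra memory, but also allows other
    # reductions (median, std, or a user-supplied aggregation function).
    return keras.tree.map_structure(
        lambda *values: keras.ops.mean(keras.ops.stack(values), axis=0),
        *per_step_logs,
    )
```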

@fchollet (Collaborator) commented

Aggregation should be handled by stateful Metric instances -- it's not something that you need extra logic for.

The idea being that at each test step, you update the Metric instances with the metric value for the last batch, and at the end, you query the Metric instances to get the aggregated (average) value.
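
In code, the pattern looks roughly like this (a minimal sketch; `compute_my_loss` is a placeholder for whatever the custom step computes, and the TensorFlow/PyTorch-style `test_step` signature is shown, since the JAX backend uses a stateless variant):

```python
import keras


class MyModel(keras.Model):
    """Sketch of a model with a custom, unsupervised test_step."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Stateful tracker that aggregates per-batch values across steps.
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing trackers here lets Keras reset them automatically at the
        # start of each epoch and each evaluate() call.
        return [self.loss_tracker]

    def test_step(self, data):
        # Placeholder for the model's custom per-batch computation.
        loss = self.compute_my_loss(data)
        # Update the stateful metric with this batch's value ...
        self.loss_tracker.update_state(loss)
        # ... and return the running aggregates, which evaluate() reports.
        return {m.name: m.result() for m in self.metrics}
```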

@LarsKue (Contributor, Author) commented May 21, 2025

@fchollet I agree that this could and should be handled by the metrics themselves. However, the current keras.Metric only supports a supervised interface using y_pred, y_true. Thus, models that circumvent the mostly supervised-learning oriented interface by implementing a custom train_step, test_step will encounter this issue.

In that case, should we instead open a conversation about extending support for arbitrary gradient-based learning in Keras, i.e. loss = f(data) rather than loss = loss_fn(f(data), label)?

@fchollet (Collaborator) commented

> However, the current keras.Metric only supports a supervised interface using y_pred, y_true. Thus, models that circumvent the mostly supervised-learning oriented interface by implementing a custom train_step, test_step will encounter this issue.

You can just use a Mean metric which will aggregate any tensor (it will even work with non-scalars).

Metric reduction across an epoch (or across one call to evaluate()) should be done by metric objects in any case.
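
For example (a small illustrative snippet, not taken from the Keras docs):

```python
import keras

# Mean keeps a running aggregate across calls to update_state().
m = keras.metrics.Mean(name="my_metric")
m.update_state(0.3)          # e.g. the metric value from batch 1
m.update_state([0.5, 0.7])   # non-scalar values are reduced as well
print(float(m.result()))     # running mean over everything seen: 0.5
m.reset_state()              # start fresh for the next evaluate() call
```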

@LarsKue (Contributor, Author) commented May 22, 2025

I did some more digging and concluded that the issue only arises when the train_step or test_step returns multiple metrics, at least one of which has no keras.metrics.Mean tracker (one is automatically constructed for the loss). In that case, the keys of metric_results and logs differ in the code block I linked above.

Closing this since there is already a workaround, but I think it would be nice if Keras removed the requirement for users to explicitly define an aggregation scheme for each metric and defaulted to using the mean.
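
For anyone landing here with the same confusion, the workaround amounts to giving every metric returned from the custom step its own stateful tracker, e.g. (a sketch with illustrative names; `compute_step_outputs` is a placeholder):

```python
import keras


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One Mean tracker per metric returned from the custom step, so the
        # keys of metric_results and logs line up.
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.extra_tracker = keras.metrics.Mean(name="my_metric")

    @property
    def metrics(self):
        return [self.loss_tracker, self.extra_tracker]

    def test_step(self, data):
        # Placeholder for the custom per-batch computation.
        loss, my_metric = self.compute_step_outputs(data)
        self.loss_tracker.update_state(loss)
        self.extra_tracker.update_state(my_metric)
        return {m.name: m.result() for m in self.metrics}
```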

@LarsKue closed this on May 22, 2025