Closed
Description
I don't have a minimal reproducible example to provide, but there has been a significant performance regression on TPU.
`git diff v3.10.0 pr-21187 -- keras/src/backend/jax/`
Training step time across many models has increased 2-3x. Please review and fix.
GPUs are fine; only TPUs with the JAX backend exhibit the 2-3x slowdown.
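To put numbers on the slowdown, you could run a small timing harness like the one below under both v3.10.0 and pr-21187. This is a minimal sketch and not code from Keras. `median_step_time`, `slowdown`, and `step_fn` are hypothetical names, and `step_fn` is assumed to run one training step. JAX dispatches work asynchronously, so `step_fn` should block on its outputs (for example with `jax.block_until_ready`). Otherwise the timer may measure dispatch rather than device execution.

```python
import statistics
import time
from typing import Callable


def median_step_time(step_fn: Callable[[], object], warmup: int = 5, iters: int = 20) -> float:
    """Return the median wall-clock seconds per call of step_fn.

    Assumption: step_fn runs one training step and blocks until its results
    are ready (with JAX, e.g. by calling jax.block_until_ready on its outputs).
    Warmup calls are excluded so one-off costs such as JIT compilation do not
    inflate the measurement.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    # Untimed warmup calls (compilation, caches, allocator growth).
    for _ in range(warmup):
        step_fn()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        step_fn()
        times.append(time.perf_counter() - start)
    # The median is less sensitive than the mean to occasional outlier steps.
    return statistics.median(times)


def slowdown(baseline_s: float, candidate_s: float) -> float:
    """Ratio of candidate to baseline step time (2.0 means twice as slow)."""
    return candidate_s / baseline_s
```

Run the same model under each version and pass the two medians to `slowdown`. A ratio between 2 and 3 on TPU, and close to 1 on GPU, would match the behavior reported above.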