I was creating a new package with JAX (and its ecosystem) and was writing tests to check that the gradients are correct. However, this results in an error. Can anyone tell me what is going wrong with this?
check_grads(dummy_loss, (init_params,), order=2, eps=1e-6)
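
For context, here is a self-contained sketch of the suggested call; `dummy_loss` and `init_params` are hypothetical stand-ins, since the original example is not shown:

```python
import jax.numpy as jnp
from jax.test_util import check_grads

def dummy_loss(params):
    # Hypothetical smooth scalar loss over a parameter array.
    return jnp.sum(jnp.sin(params) ** 2)

init_params = jnp.linspace(0.0, 1.0, 5)

# Compare analytic (autodiff) gradients against finite-difference gradients
# up to second order; eps is the finite-difference step size.
check_grads(dummy_loss, (init_params,), order=2, eps=1e-6)
```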
If I run your example in 64-bit precision with `eps=1E-6`, I find that the gradients match. In float32 precision, if I use a smaller `eps`, then the numerical result diverges. This makes me think that your function has a fast-varying second derivative, which makes the numerical gradient inaccurate, but nevertheless the analytic gradient is probably producing the correct value.