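import jax
from functools import partial
from jax._src.api import _std_basis  # private API, may change without notice

def loss_and_grad(model, y, key):
    treedef = jax.tree_util.tree_structure(model)
    # Bind the function and the primals; vmap supplies the tangents below.
    fj = partial(jax.jvp, _loss, (model, y, key))
    # Push every standard-basis tangent of the model through the JVP in one batch.
    fg = jax.vmap(fj)((_std_basis(model), None, None))
    # fg[0]: loss value per basis vector; fg[1]: matching directional derivatives.
    return fg[0].mean(), jax.tree_util.tree_unflatten(treedef, fg[1])

This feels clunky, not just because I am importing from a private API, but because it's a lot of code for something that in reverse mode is just one function. Additionally, most of this code is just in the …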
Replies: 1 comment 1 reply
See #762, where there has been discussion in the past about adding various `value_and_*` functions. Looking at #762 (comment), it seems like the resolution was that you could define such functions yourself if needed (the result looks really close to the recipe you came up with).

In your case, though, I think it's a bit more specialized because you're computing the element-wise gradient, but returning the mean of the value, so I'm not sure you'd be able to do better than the custom implementation you have.
What do you think?
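
For reference, here is a minimal sketch of what such a self-defined helper might look like using only public APIs (`jax.jvp`, `jax.vmap`, and `jax.flatten_util.ravel_pytree`) rather than the private `_std_basis`. This is not the recipe from #762: the name `value_and_grad_fwd`, the toy `loss`, and the example parameters are all hypothetical, and the sketch assumes a scalar-valued function whose parameter leaves share one floating-point dtype. It returns the value once instead of taking the mean of identical copies.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def value_and_grad_fwd(fun, params, *args):
    """Forward-mode value and gradient of a scalar `fun` w.r.t. `params`.

    Hypothetical helper for illustration; not part of JAX.
    """
    # Flatten the parameter pytree into one 1-D vector, keeping the inverse map.
    flat, unravel = ravel_pytree(params)

    def flat_fun(p):
        return fun(unravel(p), *args)

    def push(tangent):
        # One JVP gives the value and a single directional derivative.
        return jax.jvp(flat_fun, (flat,), (tangent,))

    # Rows of the identity matrix are the standard-basis tangents.
    basis = jnp.eye(flat.size, dtype=flat.dtype)
    values, derivs = jax.vmap(push)(basis)

    # Every entry of `values` is the same primal output, so take the first.
    return values[0], unravel(derivs)


# Toy example (hypothetical loss and parameters).
def loss(params, x):
    return jnp.sum(jnp.tanh(params["w"] * x + params["b"]) ** 2)


params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}
x = jnp.array([0.1, 0.2, 0.3])
value, grads = value_and_grad_fwd(loss, params, x)
```

Note that this evaluates one JVP per scalar parameter, so the cost grows with the number of parameters, whereas `jax.value_and_grad` needs a single reverse pass for a scalar output.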