Replies: 2 comments
-
I'm not sure what exactly you're looking for here. Can you put together a short example of what you want to compute?
-
I suspect you're looking for
-
Hello,
I have a question related to my research.
I need to compute gradients of individual layers, not the whole model.
I saw that jax.grad can only be applied to functions with scalar output (i.e. a whole model with a loss function at the end).
However, computing that gradient must internally create all the objects needed to build gradients for functions with tensor outputs (the individual layers).
How difficult would it be to extract them and build, say, the gradient of a convolution?
Maybe there are internal functions that build what I need, not exposed in the public API, which I could use?
Any help would be highly appreciated.
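To make the scalar-output constraint concrete, here is a minimal sketch (the toy loss and shapes below are made up for illustration):

```python
import jax
import jax.numpy as jnp

# jax.grad requires the differentiated function to return a scalar,
# so a loss reduction has to sit at the end of the computation.
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

w = jnp.ones((3,))
x = jnp.ones((2, 3))

grad_w = jax.grad(loss)(w, x)  # gradient w.r.t. w, same shape as w
```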
Edit: would jax.jacfwd or jax.jacrev be what I'm looking for?
Edit 2: I just tried jax.jacrev on a (small) convolution and it fails. It seems to allocate the whole Jacobian matrix, which can't work even for small convolutions. What I really need is the function that computes gradients on the inputs from gradients on the outputs.
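What Edit 2 describes (mapping gradients on the outputs to gradients on the inputs) is a vector-Jacobian product, and jax.vjp provides exactly that without materializing the Jacobian. A minimal sketch with a toy 1-D convolution (the layer and shapes are illustrative, not the actual model):

```python
import jax
import jax.numpy as jnp

# A small stand-in "layer": 1-D convolution with tensor (non-scalar) output.
def conv(x, w):
    return jnp.convolve(x, w, mode="same")

x = jnp.arange(8.0)
w = jnp.array([1.0, 2.0, 1.0])

# jax.vjp returns the primal output plus a function that maps a
# cotangent on the output to gradients on each input.
y, vjp_fn = jax.vjp(conv, x, w)
g_out = jnp.ones_like(y)      # gradient arriving from downstream layers
g_x, g_w = vjp_fn(g_out)      # gradients on the input and the kernel
```

jax.linearize is the forward-mode counterpart (a Jacobian-vector product) if you ever need to push tangents forward instead.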
Best regards,
Dumitru