jax.grad for values originating from the same function #10418

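IIUC, you need to split your network into an encoder (obs -> conv_output) and a classifier (conv_output -> predictions). This is because JAX differentiates by transforming functions: jax.grad turns a function into its gradient function. It has no tape that records intermediate values, so you can only take gradients with respect to a function's inputs. To get the gradient with respect to conv_output, conv_output has to be the input of the function you differentiate.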
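
To make the "no tape" point concrete, here is a minimal sketch in plain JAX. The toy `encoder` and `classifier` functions and the weights `w1`, `w2` are hypothetical stand-ins for the Flax modules in the question. It shows the split (the classifier is differentiated with respect to its input `z`). It also shows that nothing is lost for the encoder: pulling the resulting `dz` back through the encoder with `jax.vjp` gives the same parameter gradient as differentiating the composed function directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stages standing in for encoder_def / classifier_def.
def encoder(w1, obs):
    return jnp.tanh(w1 @ obs)  # plays the role of conv_output

def classifier(w2, z):
    return jnp.sum(w2 @ z)  # scalar output, so jax.grad / value_and_grad apply

w1 = jnp.arange(6.0).reshape(3, 2) / 10.0
w2 = jnp.ones((2, 3))
obs = jnp.array([1.0, -1.0])

# There is no tape to query for d(out)/d(conv_output) after the fact, so run
# the encoder and make its output the *argument* of the function we differentiate.
conv_output = encoder(w1, obs)
out, dz = jax.value_and_grad(lambda z: classifier(w2, z))(conv_output)

# Chain rule by hand: pull dz back through the encoder to get d(out)/d(w1).
_, encoder_vjp = jax.vjp(lambda w: encoder(w, obs), w1)
(dw1,) = encoder_vjp(dz)

# Same gradient computed on the composed function, for comparison.
dw1_full = jax.grad(lambda w: classifier(w2, encoder(w, obs)))(w1)
print(out, dz)
```

Because the composed function is still available, splitting it does not force you to give up end-to-end training of the encoder.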

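import jax
conv_output = encoder_def.apply(encoder_params, obs)
# value_and_grad returns (value, grad) and needs a scalar output (e.g. a loss on the predictions)
output, grad = jax.value_and_grad(lambda z: classifier_def.apply(classifier_params, z))(conv_output)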
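
jax.value_and_grad only accepts functions with a scalar output. If the classifier returns a vector of predictions rather than a loss, one option is `jax.vjp` (gradient of one output component, via a one-hot cotangent) or `jax.jacobian` (all components at once). The sketch below uses hypothetical toy `encoder`/`classifier` functions in plain JAX in place of the Flax modules.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stages; the classifier now returns a vector of predictions.
def encoder(w1, obs):
    return jnp.tanh(w1 @ obs)

def classifier(w2, z):
    return w2 @ z  # shape (2,): one score per class

w1 = jnp.arange(6.0).reshape(3, 2) / 10.0
w2 = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
obs = jnp.array([1.0, -1.0])

conv_output = encoder(w1, obs)

# vjp returns the predictions plus a function that pulls a cotangent back to z.
preds, classifier_vjp = jax.vjp(lambda z: classifier(w2, z), conv_output)

# Gradient of preds[0] w.r.t. conv_output: use a one-hot cotangent.
(dz_0,) = classifier_vjp(jnp.array([1.0, 0.0]))

# Full Jacobian d(preds)/d(conv_output), shape (2, 3).
jac = jax.jacobian(lambda z: classifier(w2, z))(conv_output)
```

For this linear toy classifier the Jacobian is just `w2`. In a real network, `jax.vjp` is usually cheaper when you need only a few output components.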

Answer selected by pseudo-rnd-thoughts