jax.grad for values originating from the same function #10418

Answered by YouJiacheng
pseudo-rnd-thoughts asked this question in General
I'm trying to implement Grad-CAM. In TensorFlow it looks like this:

```python
# Create a graph that outputs the target convolution layer and the model output
grad_model = tf.keras.models.Model([model.inputs], [model.get_layer(LAYER_NAME).output, model.output])

# Get the score for the target class
with tf.GradientTape() as tape:
    conv_outputs, predictions = grad_model(np.array([img]))
    loss = predictions[:, TARGET_CLASS]

# Extract filters and gradients
grads = tape.gradient(loss, conv_outputs)[0]
```

Source: https://www.sicara.fr/blog/2019-08-28-interpretability-deep-learning-tensorflow

However, for JAX I am unsure how to write a function that takes the network state as an argument so that I can differentiate the selected output:

```python
def gradcam(network_def, network_params, obs):
    predictions, conv_output = network_def.apply(network_params, obs, mutable='intermediates')
    return conv_output, predictions

grad, (output, _) = jax.value_and_grad(gradcam, argnums=???)(network_def, network_params, obs)
```
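As background for the `argnums=???` question, a minimal sketch (toy parameters and shapes, not the actual network) of how `jax.value_and_grad` handles `argnums` and auxiliary outputs — `has_aux=True` lets the differentiated function return extra values alongside the scalar loss:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Toy "network": a single linear layer; returns (scalar_loss, aux)
    logits = x @ params["w"] + params["b"]
    loss = logits[0]      # scalar score for a target class
    return loss, logits   # aux output is carried through untouched

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
x = jnp.ones((3,))

# argnums=0 differentiates w.r.t. params; has_aux=True also returns logits
(loss, logits), grads = jax.value_and_grad(loss_fn, argnums=0, has_aux=True)(params, x)
print(loss)              # 3.0
print(grads["w"].shape)  # (3, 2), same shape as params["w"]
```

Note that `value_and_grad` returns the value first and the gradient second, and the value must be a scalar (here achieved by indexing a single class score).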
Answered by YouJiacheng, Apr 22, 2022
IIUC, you need to split your network into an encoder (`obs -> conv_output`) and a classifier (`conv_output -> predictions`). This is because `jax` can only transform a function into its grad and doesn't have a tape:

```python
conv_output = encoder_def.apply(encoder_params, obs)
output, grad = jax.value_and_grad(lambda z: classifier_def.apply(classifier_params, z))(conv_output)
```
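Putting the suggestion together, a self-contained sketch of the encoder/classifier split. The `encoder`/`classifier` functions here are toy stand-ins for the actual Flax modules, and the target class is indexed to produce the scalar that `jax.value_and_grad` requires (mirroring `predictions[:, TARGET_CLASS]` in the TF snippet):

```python
import jax
import jax.numpy as jnp

TARGET_CLASS = 1

def encoder(params, obs):
    # Toy stand-in for the conv trunk: obs -> "feature map"
    return jnp.tanh(obs @ params["w_enc"])

def classifier(params, conv_output):
    # Toy stand-in for the head: feature map -> class scores
    return conv_output @ params["w_cls"]

params = {
    "w_enc": jnp.ones((4, 8)) * 0.1,
    "w_cls": jnp.ones((8, 3)) * 0.1,
}
obs = jnp.ones((4,))

# 1) Run the encoder once to get the intermediate activations
conv_output = encoder(params, obs)

# 2) Differentiate the target-class score w.r.t. those activations;
#    jax transforms the *function*, so no tape is needed
score_fn = lambda z: classifier(params, z)[TARGET_CLASS]
score, grads = jax.value_and_grad(score_fn)(conv_output)

print(grads.shape)  # (8,), same shape as conv_output
```

The gradients `grads` play the role of `tape.gradient(loss, conv_outputs)` in the TF version and can be fed into the usual Grad-CAM weighting step.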
Answer selected by pseudo-rnd-thoughts
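To illustrate the "no tape" point from the answer: `jax.grad` is a function-to-function transform, so the quantity you want gradients for must be the output of a function you can hand to it — there is no recorder watching intermediate values the way `tf.GradientTape` does:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2)
df = jax.grad(f)            # a new function computing df/dx
print(df(jnp.arange(3.0)))  # [0. 2. 4.]
```

This is why intermediate activations like `conv_output` must become an explicit input of some function (here, the classifier) before they can be differentiated against.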