Refine TensorFlow model in JAX? #17477
Unanswered
deeplearningrobotics asked this question in Q&A
Replies: 0 comments
Hello all,
I have a pre-trained TensorFlow model (`tf.keras.Model`) that I would like to refine in JAX with a loss written in JAX. With `jax2tf.call_tf` I can call the model and potentially also get the gradient of the loss. But how do I then update the parameters of the TensorFlow model based on this gradient?

I would also like to continue training the model in TensorFlow after the refinement in JAX, so it would be nice if the model stayed in TensorFlow.
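One possible workaround for keeping the model in TensorFlow is to split the chain rule across the two frameworks instead of using `jax2tf.call_tf`: run the forward pass under a `tf.GradientTape`, let JAX compute the loss and its gradient with respect to the model outputs, then pass that gradient to `tape.gradient(..., output_gradients=...)` so TensorFlow back-propagates it into the Keras weights and a regular Keras optimizer applies the update. This is a minimal sketch under assumptions, not the `call_tf` route: the MSE loss `jax_loss`, the small `Dense` model standing in for the pre-trained network, the SGD learning rate, and the helper name `refine_step` are all hypothetical placeholders.

```python
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp


def jax_loss(preds, targets):
    # Hypothetical loss written in JAX (mean squared error as a placeholder).
    return jnp.mean((preds - targets) ** 2)


# Hypothetical stand-in for the pre-trained tf.keras.Model.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# JAX computes both the loss value and dLoss/dPreds.
loss_and_grad = jax.jit(jax.value_and_grad(jax_loss))


def refine_step(x, y):
    # TF: forward pass, recorded so we can back-propagate later.
    with tf.GradientTape() as tape:
        preds = model(x, training=True)
    # JAX: loss and its gradient w.r.t. the model outputs (on NumPy copies).
    loss, dpreds = loss_and_grad(preds.numpy(), y)
    # TF: chain rule, seeding the backward pass with dLoss/dPreds from JAX.
    grads = tape.gradient(
        preds,
        model.trainable_variables,
        output_gradients=tf.convert_to_tensor(np.asarray(dpreds)),
    )
    # The weights are updated in place, so the model stays a tf.keras.Model.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss)
```

Because the variables never leave TensorFlow, the same `model` can afterwards be trained further with `model.fit` or saved as usual. The trade-off is that the JAX loss only sees materialized outputs, so the TF forward pass and the JAX loss are compiled separately rather than as one XLA program.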
https://github.com/deepmind/tf2jax might be another option I could try if `jax2tf.call_tf` does not support what I am trying to do.

Thank you so much!