Saving intermediate values in a file #16990
Asked by simlaharma in Q&A. Answered by JiaYaobo.
Hello, I have a JAX function that I use while training a DNN model. I want to debug it by checking the values before and after the function is applied. How can I save these intermediate values to a file?
Answer from JiaYaobo (Aug 6, 2023):
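```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def save_array(array):
    np.save('array.npy', array)
    print('saved array')

@jax.jit
def f(x):
    # result_shape_dtypes is None because save_array returns nothing.
    io_callback(save_array, None, x)
    return jnp.sin(x)
```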
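The question asks for the values both before and after the function, while the snippet above saves only the input and always to the same file. A minimal sketch that records both, assuming a hypothetical helper `save_to` and a temporary output directory (`OUT_DIR`); `ordered=True` keeps the two writes in program order, and `jax.effects_barrier()` waits for pending callbacks before the files are read back.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical output location for this sketch; substitute your own directory.
OUT_DIR = tempfile.mkdtemp()


def save_to(name):
    """Build a host callback (hypothetical helper) that writes its argument to OUT_DIR/<name>.npy."""
    def _save(array):
        np.save(os.path.join(OUT_DIR, f"{name}.npy"), np.asarray(array))
    return _save


@jax.jit
def f(x):
    # Value entering the computation under inspection.
    io_callback(save_to("before"), None, x, ordered=True)
    y = jnp.sin(x)
    # Value leaving it; ordered=True keeps the two writes in program order.
    io_callback(save_to("after"), None, y, ordered=True)
    return y


result = f(jnp.arange(3.0))
# Wait for all pending callbacks before reading the files back.
jax.effects_barrier()

before = np.load(os.path.join(OUT_DIR, "before.npy"))
after = np.load(os.path.join(OUT_DIR, "after.npy"))
print(before, after)
```

Each call overwrites `before.npy` and `after.npy`; to keep a history across training steps, one option is to put a step index in the file name.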
Answer selected by jakevdp.
jax.experimental.io_callback?
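Because the question mentions training, one caveat: `io_callback` is documented as not supporting automatic differentiation, so calling it inside a loss that is differentiated with `jax.grad` is expected to raise an error. A minimal sketch using `jax.debug.callback` instead, which JAX's external-callbacks documentation describes as compatible with `jax.grad`. The toy model (`tanh(x @ w)`), the host-side `recorded` list and the file name `activations.npz` are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Host-side buffer for recorded values (hypothetical, for illustration only).
recorded = []


def record(value):
    # Copy into a NumPy array so the buffer holds plain host data.
    recorded.append(np.array(value))


def loss(w, x):
    h = jnp.tanh(x @ w)
    # Unlike io_callback, jax.debug.callback can sit inside a function passed to jax.grad.
    jax.debug.callback(record, h)
    return jnp.mean(h ** 2)


grad_fn = jax.jit(jax.grad(loss))

w = jnp.ones((2, 2))
x = jnp.eye(2)
for step in range(3):
    w = w - 0.1 * grad_fn(w, x)

# Make sure every callback has run, then write all recorded values at once.
jax.effects_barrier()
path = os.path.join(tempfile.mkdtemp(), "activations.npz")
np.savez(path, *recorded)
```

Collecting values on the host and writing them once at the end avoids overwriting a single file on every step; `np.savez` stores them as `arr_0`, `arr_1`, and so on.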