Saving intermediate values in a file #16990

Answered by JiaYaobo
simlaharma asked this question in Q&A

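Have you tried jax.experimental.io_callback? It calls a Python function on the host from inside a jitted function, so that function can write its argument to disk: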

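import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def save_array(array):
    # Runs on the host with a concrete array, outside the traced computation
    np.save('array.npy', array)
    print('saved array')

@jax.jit
def f(x):
    # result_shape_dtypes=None because save_array returns nothing
    io_callback(save_array, None, x)
    return jnp.sin(x)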
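
The snippet above writes every value to the same array.npy, so only the most recent one survives. To keep several intermediates from a single function, one option is to give each value its own file. This is a minimal sketch, assuming a recent JAX where io_callback accepts ordered=True and jax.effects_barrier() is available. OUT_DIR and save_named are hypothetical names chosen for illustration. ordered=True asks JAX to run the callbacks in program order, and jax.effects_barrier() waits for pending callbacks before the files are read back.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Hypothetical output directory for this sketch (a fresh temp dir).
OUT_DIR = tempfile.mkdtemp()


def save_named(name):
    """Hypothetical helper: build a host callback that saves its input to OUT_DIR/<name>.npy."""
    def _save(array):
        np.save(os.path.join(OUT_DIR, f"{name}.npy"), np.asarray(array))
        # Returning None matches result_shape_dtypes=None below.
    return _save


@jax.jit
def g(x):
    y = jnp.sin(x)
    # Save the first intermediate; ordered=True keeps callbacks in program order.
    io_callback(save_named("y"), None, y, ordered=True)
    z = y * 2.0
    # Save the second intermediate under a different file name.
    io_callback(save_named("z"), None, z, ordered=True)
    return z


x = jnp.arange(4.0)
out = g(x)
jax.effects_barrier()  # block until all pending host callbacks have finished

y_saved = np.load(os.path.join(OUT_DIR, "y.npy"))
z_saved = np.load(os.path.join(OUT_DIR, "z.npy"))
```

Because the callbacks execute every time g runs (not only while tracing), repeated calls overwrite y.npy and z.npy. If you need every call's values, you could pass a step counter into the callback and include it in the file name.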

Answer selected by jakevdp