How to tell the jit all my supporting data is static in a clean manner (e.g. subclassing, alternate data structure, etc.)? #13460

Answered by isaacgerg
isaacgerg asked this question in General
According to one of the maintainers, making these constant arrays static does not, anecdotally, seem to yield any benefit. Using the closure from the first reply is therefore likely the best approach.

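From @jakevdp:

import jax
from functools import partial

# Bind y up front so the jitted function only takes x; y is captured, not passed per call.
compiled_f = jax.jit(partial(f, y=y))
result = compiled_f(x)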
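A self-contained sketch of the suggestion above. It assumes a hypothetical `f(x, y)` where `x` changes on every call and `y` is fixed supporting data. The names `f`, `make_compiled`, and the array values are illustrative only and do not come from the thread. The sketch compares binding `y` with `functools.partial`, an equivalent explicit closure, and the baseline of passing `y` as an ordinary traced argument:

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical function: x varies per call, y is fixed supporting data.
    return jnp.sum(x * y) + jnp.max(y)


# Hypothetical constant supporting data that does not change between calls.
y = jnp.arange(5.0)

# Option 1 (from the thread): bind y with functools.partial, then jit.
# The jitted function takes only x; y is captured when the function is built.
compiled_f = jax.jit(partial(f, y=y))


# Option 2: an explicit closure over y behaves the same way.
def make_compiled(y_const):
    @jax.jit
    def g(x):
        return f(x, y_const)

    return g


compiled_g = make_compiled(y)

# Baseline: pass y as a normal (traced) argument on every call.
baseline = jax.jit(f)

x = jnp.ones(5)
r1 = compiled_f(x)
r2 = compiled_g(x)
r3 = baseline(x, y)
print(r1, r2, r3)  # prints: 14.0 14.0 14.0
```

One trade-off worth keeping in mind: because `y` is captured when the jitted function is built, using a different `y` means building a new compiled function (and retracing), whereas the baseline accepts any `y` with the same shape and dtype without recompiling.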
