I have a function I want to minimize which takes a vector input x and computes a scalar output s. The function needs an incredible amount of support data, all of which is static. I save all the support data in a dictionary of arrays.

So, my function has the form f(x, y), where x is a vector holding the free parameters and y is the support data stored in a dict of arrays, e.g. y = {'v1': np.array([...]), ... 'vn': np.array([...])}

I want to be able to JIT-compile the function with y treated as a static argument.

It seems like this would be a common pattern for optimization problems. My thought is that this is already solved and I am doing something wrong.
Replies: 2 comments 6 replies
Static arguments to JIT must be hashable, so dicts and arrays cannot be passed as static arguments. But if you want to pass arrays that will be treated as constants to your function, you can do so by closing over them. For example:

    from functools import partial

    import jax

    compiled_f = jax.jit(partial(f, y=y))
    result = compiled_f(x)

Does that work for your use-case?
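To make the closure pattern concrete, here is a minimal sketch of how it might look end to end. The objective `f`, the keys `"A"` and `"b"`, and the array values are all made up for illustration; they are not from the original question. The only part taken from the answer above is `jax.jit(partial(f, y=y))`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical objective: sum of squared residuals of A @ x - b,
    # where A and b live in the support-data dict y.
    residual = y["A"] @ x - y["b"]
    return jnp.sum(residual ** 2)


# Hypothetical support data: a dict of constant arrays.
y = {
    "A": jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    "b": jnp.array([1.0, 1.0]),
}

# Close over y so the compiled function only takes x.
compiled_f = jax.jit(partial(f, y=y))

# Gradient with respect to x, which is what a minimizer would typically need.
compiled_grad = jax.grad(compiled_f)

x = jnp.array([0.5, -0.5])
value = compiled_f(x)
gradient = compiled_grad(x)
```

Since `compiled_f` takes only `x`, it can be handed to an optimizer that expects a function of the free parameters alone.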
According to one of the maintainers (@jakevdp), anecdotally, making these constant arrays static does not seem to yield any benefit. So closing over them, as in the first response, is likely the best approach.
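For comparison, a dict of arrays is a pytree, so it can also be passed to a jitted function as an ordinary (non-static) argument. The sketch below assumes a made-up objective `f` with hypothetical keys `"w"` and `"c"`, and simply checks that the closure version and the plain-argument version give the same result. It does not measure performance.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical objective using two arrays from the support-data dict.
    return jnp.sum(x * y["w"]) + jnp.sum(y["c"])


# Hypothetical support data.
y = {"w": jnp.array([1.0, 2.0, 3.0]), "c": jnp.array([0.5, 0.5])}
x = jnp.array([1.0, 1.0, 1.0])

# Option 1: pass y as a regular traced argument (no static marking needed).
traced_f = jax.jit(f)

# Option 2: close over y, as suggested in the first response.
closed_f = jax.jit(lambda x: f(x, y))

s_traced = traced_f(x, y)
s_closed = closed_f(x)
```

Both calls return the same scalar. The closure version has the simpler call signature, which is convenient when the function is passed to a minimizer.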