Is it reasonable to create an empty jax.numpy array first and then fill it with elements? #18090
-
It depends on what you're trying to achieve. If you intend to do everything within JAX, including the calculation of the elements A[i], then you should use the … On the other hand, if you only want to use the result of the complex calculation in a part of your code that is written in JAX, then …
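Here is a minimal sketch of the second case. It assumes the element computation is ordinary Python/NumPy code; `expensive_element` is a hypothetical stand-in. You fill a NumPy array in place as usual, then convert the finished array to JAX once with `jnp.asarray`. Because the filling happens outside JAX, JAX cannot trace, `jit`, or differentiate through it.

```python
import numpy as np
import jax.numpy as jnp

def expensive_element(i):
    # Hypothetical stand-in for a complicated per-element calculation.
    return np.sin(i) * i ** 2

N = 5
A_np = np.empty(N)                  # ordinary mutable NumPy buffer
for i in range(N):
    A_np[i] = expensive_element(i)  # in-place assignment is fine in NumPy

A = jnp.asarray(A_np)               # convert to a JAX array once, at the end
print(A)
```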
-
In JAX, the iterative array-filling approach will be slow. It is better to use JAX transformations to express the operation more natively. That is, if you have iterative code that looks like this:

```python
import jax
import jax.numpy as jnp

A = jnp.empty(N)
for i in range(N):
    A = A.at[i].set(some_function(i))
```

it will generally be much more efficient to write something like this:

```python
A = jax.vmap(some_function)(jnp.arange(N))
```

This presupposes that …
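Below is a runnable sketch of both versions. `some_function` is hypothetical and uses only `jax.numpy` operations. That is an assumption, because `vmap` can only batch functions that JAX can trace. The loop builds a new array on every iteration. `vmap` evaluates the function over all indices in one vectorized call. Both produce the same values.

```python
import jax
import jax.numpy as jnp

def some_function(i):
    # Hypothetical element computation written with jax.numpy ops only;
    # accepts a Python int or a traced integer index.
    x = jnp.asarray(i, dtype=jnp.float32)
    return jnp.sin(x) * x ** 2

N = 8

# Iterative filling: each .at[i].set(...) returns a new array.
A_loop = jnp.empty(N)
for i in range(N):
    A_loop = A_loop.at[i].set(some_function(i))

# Vectorized version: map some_function over all indices at once.
A_vmap = jax.vmap(some_function)(jnp.arange(N))

print(jnp.allclose(A_loop, A_vmap))  # True
```

If needed, the vmapped call can also be wrapped in `jax.jit` so that XLA compiles it. The Python loop, by contrast, dispatches separate operations for every index.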
-
I need to port my NumPy code to jax.numpy, but I ran into a problem. For example:
In NumPy, I can create an array first, compute each element through a series of complex operations, and then write it into the array. (The element computation in the example code is simplified; in practice it is far more complicated.)
However, this seems much harder in jax.numpy, because modifying a value in place is not as simple as NumPy's `A[i] = i`.
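The following short sketch uses a small zero-initialized array to show the difference. NumPy-style item assignment raises a `TypeError` on a JAX array, because JAX arrays are immutable. In contrast, `.at[i].set(value)` returns an updated copy and leaves the original unchanged.

```python
import jax.numpy as jnp

A = jnp.zeros(3)

try:
    A[0] = 1.0                # NumPy-style in-place assignment
except TypeError as err:
    print("TypeError:", err)  # JAX arrays do not support item assignment

B = A.at[0].set(1.0)          # functional update: returns a new array
print(A)  # [0. 0. 0.]  (unchanged)
print(B)  # [1. 0. 0.]
```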
I thought of two solutions:
First, as the documentation describes, update values with `x = x.at[idx].set(y)`.
Second, first create a NumPy array, and once all of its elements have been computed, convert the whole array to a jax.numpy array. The example is as follows:
My question is this: what should I do in general when I run into this situation, where each element is complicated to compute and I need to create the array first and then compute each element?
Apart from these two methods, is there a safer and more elegant approach? If not, which of the two is better, and does either one affect subsequent automatic differentiation? Thanks.
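On the automatic-differentiation question, here is a minimal sketch. It uses a hypothetical parameter `theta` and a hypothetical element formula `element`. Building the array with `.at[i].set(...)` inside `jax.lax.fori_loop` keeps the loop in JAX, and `jax.grad` can differentiate through it. This is one possible option, and it requires a traceable loop body. The `vmap` form is differentiable as well. The NumPy-then-convert route is not, because JAX raises an error when a traced value has to be converted into a NumPy array.

```python
import jax
import jax.numpy as jnp

N = 4

def element(theta, i):
    # Hypothetical per-element formula depending on a parameter theta.
    return theta * jnp.sin(i.astype(jnp.float32))

def fill_with_fori_loop(theta):
    def body(i, A):
        # Each step returns a new array with entry i replaced.
        return A.at[i].set(element(theta, i))
    # Static (Python int) bounds, so the loop is reverse-mode differentiable.
    return jax.lax.fori_loop(0, N, body, jnp.zeros(N))

def fill_with_vmap(theta):
    return jax.vmap(lambda i: element(theta, i))(jnp.arange(N))

def loss_loop(theta):
    return jnp.sum(fill_with_fori_loop(theta))

def loss_vmap(theta):
    return jnp.sum(fill_with_vmap(theta))

g_loop = jax.grad(loss_loop)(2.0)
g_vmap = jax.grad(loss_vmap)(2.0)
# d/dtheta of sum_i theta*sin(i) is sum_i sin(i), for i = 0..N-1.
print(g_loop, g_vmap)
```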