jnp.dot vs explicit dot product followed by jnp.sum results in drastically different memory uses #17924
-
```python
import jax
import jax.numpy as jnp
import numpy as np

n = 100000
d = 50000

pos = np.random.random((n, 3)).astype(np.float32) - 0.5
HKL = np.random.random((d, 3)).astype(np.float32) - 0.5
pos = jnp.array(pos)
HKL = jnp.array(HKL)

def explicit(pos, HKL):
    phase = HKL[:, 0] * pos[0] + HKL[:, 1] * pos[1] + HKL[:, 2] * pos[2]
    return jnp.cos(phase)

def dot(pos, HKL):
    phase = jnp.dot(HKL, pos)
    return jnp.cos(phase)

def f1(pos, HKL):
    return jnp.sum(jax.vmap(explicit, in_axes=(0, None))(pos, HKL), axis=0)

def f2(pos, HKL):
    return jnp.sum(jax.vmap(dot, in_axes=(0, None))(pos, HKL), axis=0)

print(jax.jit(f1)(pos, HKL))
print(jax.jit(f2)(pos, HKL))
```

Why does the `dot` version use so much more memory than the `explicit` one?
It seems that in the second case the full result of the vmap is materialized before the sum is taken, instead of the sum being accumulated as the vmap proceeds.
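One way to check this hypothesis is to inspect the jaxpr that JAX traces for each function. The sketch below (with smaller array sizes than the original, chosen arbitrarily here for speed) prints the trace of the `dot`-based version; the `dot_general` primitive it contains produces the full n-by-d intermediate before the reduction:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Smaller sizes than the original example, just for inspection.
pos = jnp.array(np.random.random((100, 3)).astype(np.float32) - 0.5)
HKL = jnp.array(np.random.random((50, 3)).astype(np.float32) - 0.5)

def f2(pos, HKL):
    return jnp.sum(
        jax.vmap(lambda p, H: jnp.cos(jnp.dot(H, p)), in_axes=(0, None))(pos, HKL),
        axis=0,
    )

# The printed jaxpr contains a dot_general whose output covers the full
# batch, i.e. the n x d intermediate exists before the sum reduces it.
print(jax.make_jaxpr(f2)(pos, HKL))
```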
-
Hi - thanks for the question! I think in some sense this is expected, in that different ways of expressing computations will in general lead to different usage of computational and memory resources. In a perfect world the XLA compiler would be able to recognize that these two sequences of operations are equivalent and choose the best approach given the resources available, but no perfect compiler exists. I'll address one of your points directly:
As far as I know, the XLA computational model will never "do the sum as the vmap proceeds". XLA is designed with vectorized computation in mind, where by "vectorized" I mean that the entire array is stored in memory, and a kernel is called that computes and stores the results. The compiler can fuse operations in many cases, but definitely not all cases. This computational model built on vectorized kernels is well-suited to accelerators like GPU and TPU, where the chip architecture allows such computations to complete very quickly.

If you instead want a computation to be done serially or in batches, it's up to you to express your computation in that manner. Your "explicit" function is essentially that. Does that answer your question?
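If you do want the sum accumulated as the loop proceeds, one way to express that explicitly is `jax.lax.scan`, which carries a running accumulator so that only a d-sized intermediate is live at each step. This is a minimal sketch (with smaller sizes than your example, and assuming the quantity you want is the same summed cosines that `f1`/`f2` compute):

```python
import jax
import jax.numpy as jnp
import numpy as np

n, d = 1000, 500  # smaller than the original, just for the sketch
pos = jnp.array(np.random.random((n, 3)).astype(np.float32) - 0.5)
HKL = jnp.array(np.random.random((d, 3)).astype(np.float32) - 0.5)

def f_scan(pos, HKL):
    # Accumulate the sum one row of `pos` at a time: the carry is the
    # running (d,)-shaped total, so the (n, d) intermediate never exists.
    def body(acc, p):
        return acc + jnp.cos(jnp.dot(HKL, p)), None
    total, _ = jax.lax.scan(body, jnp.zeros(HKL.shape[0]), pos)
    return total

result = jax.jit(f_scan)(pos, HKL)
```

The trade-off is serialization: on an accelerator the fully vectorized version will usually be faster when it fits in memory, and scanning over modest-sized chunks of `pos` (rather than single rows) is a common middle ground.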