Array of structs vs struct of arrays #7544
Replies: 2 comments 4 replies
-
I used this colab to compare alphafold-style 3D matrix-vector products against einsum or the @ operator, and I found the alphafold-style (struct-of-arrays) version to be about 2x faster on CPU and 20% faster on GPU (TPU was having a flaky day, as it often does for me, and I didn't get meaningful numbers). I'm somewhat surprised it matters more on CPU than on GPU; and this is but one rather synthetic benchmark, but those are nontrivial differences nonetheless. I also tried to work out whether the memory layout of my device arrays matters at all, but it appears that JAX/XLA currently coerces all device arrays to C-contiguous anyway (as you typically would when programming GPUs, because that's what pretty much all library calls demand), since I cannot seem to make any difference by feeding in either C- or Fortran-ordered arrays, despite the fact that this should matter a lot, at least on the GPU.
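Roughly, the comparison looks like the following minimal sketch (my own reconstruction, not the actual colab; the names, shapes, and `N` are assumptions): it applies N 3x3 matrices to N 3-vectors, once as an array-of-structs einsum, and once written out component by component, struct-of-arrays style.

```python
import timeit

import jax
import jax.numpy as jnp

N = 1_000_000
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
mats = jax.random.normal(k1, (N, 3, 3))  # N 3x3 matrices
vecs = jax.random.normal(k2, (N, 3))     # N 3-vectors

@jax.jit
def matvec_aos(mats, vecs):
    # Array-of-structs: the length-3 component axis is the innermost axis.
    return jnp.einsum('nij,nj->ni', mats, vecs)

# Struct-of-arrays, alphafold-style: one length-N array per component.
vx, vy, vz = vecs[:, 0], vecs[:, 1], vecs[:, 2]
m = [[mats[:, i, j] for j in range(3)] for i in range(3)]

@jax.jit
def matvec_soa(m, vx, vy, vz):
    # The 3x3 product written out by hand, component by component.
    return (m[0][0] * vx + m[0][1] * vy + m[0][2] * vz,
            m[1][0] * vx + m[1][1] * vy + m[1][2] * vz,
            m[2][0] * vx + m[2][1] * vy + m[2][2] * vz)

def run_aos():
    matvec_aos(mats, vecs).block_until_ready()

def run_soa():
    # All three outputs come from the same jitted executable, so blocking
    # on one of them is enough for timing purposes.
    matvec_soa(m, vx, vy, vz)[0].block_until_ready()

run_aos(); run_soa()  # warm up first, so compile time is excluded
print('aos:', timeit.timeit(run_aos, number=100))
print('soa:', timeit.timeit(run_soa, number=100))
```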
-
Is there any method to check the strides of a device array, by the way? None that I can find so far, and Google is pretty quiet on the topic as well. It would be nice to check some assumptions about what is going on under the hood.
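For what it's worth, one way to check some of those assumptions, assuming a JAX version recent enough to expose the `jit(...).lower(...).compile()` ahead-of-time API, is to dump the compiled HLO: XLA annotates each shape with its layout (minor-to-major dimension order), e.g. `f32[8,3]{1,0}`, which is the closest thing to strides I know of.

```python
import jax
import jax.numpy as jnp

def f(mats, vecs):
    return jnp.einsum('nij,nj->ni', mats, vecs)

mats = jnp.ones((8, 3, 3))
vecs = jnp.ones((8, 3))

# Dump the post-compilation HLO; the exact accessor (.as_text() here) may
# vary across JAX versions. Look for layout suffixes like {2,1,0} on the
# shapes: they give the minor-to-major order XLA picked for each buffer.
compiled = jax.jit(f).lower(mats, vecs).compile()
print(compiled.as_text())
```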
-
I've been writing code in JAX that involves length-3 vectors, quaternions, and the like. I've been looking at jax/brax and the recently released alphafold for inspiration on best practices, and I noticed they take different approaches towards the implementation of, for instance, quaternion algebra operations. brax does what amounts to an array-of-structs, but here is the rationale behind alphafold's approach.
That is, they define a vector as

```python
Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z'])
```

rather than as a simple JAX array of size 3, and write code involving such objects accordingly; that is, if we want to sum over components, we need to write that out and cannot call a sum method over an axis (see the sketch below).

What we all want, of course, is to be able to write high-level, expressive code that will compile to something close to optimal on as many backends as possible. It's known that GPUs, at least, prefer a struct-of-arrays memory layout, as that allows the threads in a warp to perform the same operation while their memory accesses coalesce.
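Here is a sketch contrasting the two styles on a dot product (the `Vecs` namedtuple matches the alphafold snippet above; `dot_soa` and `dot_aos` are my own illustrative names). In the struct-of-arrays style the sum over components must be spelled out; in the array-of-structs style it is just a reduction over an axis.

```python
import collections

import jax.numpy as jnp

Vecs = collections.namedtuple('Vecs', ['x', 'y', 'z'])

def dot_soa(a: Vecs, b: Vecs):
    # Struct-of-arrays: write the component sum out by hand.
    return a.x * b.x + a.y * b.y + a.z * b.z

def dot_aos(a, b):
    # Array-of-structs: a and b have shape (..., 3); reduce over the last axis.
    return jnp.sum(a * b, axis=-1)

a = jnp.ones((4, 3))
b = jnp.ones((4, 3))
print(dot_aos(a, b))
print(dot_soa(Vecs(a[:, 0], a[:, 1], a[:, 2]),
              Vecs(b[:, 0], b[:, 1], b[:, 2])))
```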
Ideally, I'd say a system like JAX could provide both expressiveness and performance: I could have arrays with many dimensions, including small dimensions of size 3 or 4 or whatever, and during jitting JAX would decide for itself how to lay out those arrays and what the strides for each axis should be. If my GPU prefers a struct-of-arrays-type layout, so be it; and if my CPU prefers the fields of each object to sit together in memory, then we can compile it that way as well.
How close is current JAX to that ideal in practice, though? Does it reason about the order of strides at all, or does it simply use C order everywhere? Are there plans for getting the compiler involved in being smart about this? Is there any JAX style guide with best practices on this front?