Replies: 2 comments 1 reply
Also tagging @shoyer and @jekbradbury, who might be interested in this topic.
This is a very good question, and in fact it has been raised multiple times in the past. While I am somewhat sympathetic to exposing this, I am also a bit afraid of the unintended consequences it would have on xmap. For example, what do you do when …
xmap provides a really nice API for distributed computing with automatic partitioning, while also letting you express computations in a way that is agnostic to the sharding semantics. xmap together with XLA SPMD already feels like a superpower for distributed computation!
However, in some use cases it would be useful to have more control over the individual chunks that live on each device. It would still be great to use xmap for the actual computations, and to take advantage of the partitioning logic and XLA optimizations behind the scenes.
The primary use case I have in mind is a sequence of (very cheap) convolution and slicing operations, say about 10 of them. The required halo exchange widths are known to me in advance, so I would like to perform a single halo exchange up front, manually, to improve efficiency.
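To make the halo use case concrete, here is a minimal sketch of the "exchange once, then compute locally" pattern. It is not xmap code: the device axis is emulated on a single device by the leading axis of an array, `jnp.roll` stands in for the neighbor-to-neighbor transfer, and the boundary is assumed to be periodic. The names `exchange_halo`, `stencil_valid`, `run_local`, and `run_global` are hypothetical helpers made up for this illustration.

```python
import jax.numpy as jnp


def exchange_halo(chunks, halo):
    """Pad each chunk with `halo` elements from both neighbors (periodic).

    chunks: (P, n) array, one row per emulated device.
    Assumes 1 <= halo <= n, so only immediate neighbors are needed.
    """
    # Right edge of the left neighbor: after roll by +1, row p holds row p-1.
    left = jnp.roll(chunks[:, -halo:], 1, axis=0)
    # Left edge of the right neighbor: after roll by -1, row p holds row p+1.
    right = jnp.roll(chunks[:, :halo], -1, axis=0)
    return jnp.concatenate([left, chunks, right], axis=1)


def stencil_valid(x, w):
    """3-tap 'valid' stencil along the last axis; output is 2 elements shorter."""
    return w[0] * x[..., :-2] + w[1] * x[..., 1:-1] + w[2] * x[..., 2:]


def run_local(chunks, w, n_steps):
    """One halo exchange of width n_steps, then n_steps purely local stencils."""
    padded = exchange_halo(chunks, n_steps)
    for _ in range(n_steps):
        padded = stencil_valid(padded, w)  # consumes one halo cell per side
    return padded


def run_global(x, w, n_steps):
    """Reference: the same periodic stencil applied to the unsharded array."""
    for _ in range(n_steps):
        x = w[0] * jnp.roll(x, 1) + w[1] * x + w[2] * jnp.roll(x, -1)
    return x


x = jnp.sin(jnp.arange(32.0))
w = jnp.array([0.25, 0.5, 0.25])
local = run_local(x.reshape(4, 8), w, n_steps=3)  # 4 emulated devices
```

Each stencil application shrinks the padded chunk by one element per side, so a halo of width `n_steps` is exactly used up after `n_steps` applications and no further communication is needed. On real devices the two `jnp.roll` calls would correspond to a collective permute such as `jax.lax.ppermute`.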
Another use case is the one-shuffle scheme for distributed matrix multiplication described in Lu et al., Large-Scale Discrete Fourier Transform on TPUs (2021). Briefly, the one-shuffle scheme cycles the row chunks of the RHS matrix across devices and performs the matmul one chunk at a time to produce each output row chunk. This approach avoids an all-gather for the matmul.
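The following is a rough single-device sketch of that cycling pattern as I read the description above; it is not taken from the paper and is not xmap code. The device axis is emulated by the leading axis of each array, `jnp.roll` stands in for the shuffle between devices, and `one_shuffle_matmul` is a hypothetical name. Both matrices are assumed to be sharded by rows across the same P devices.

```python
import jax
import jax.numpy as jnp


def one_shuffle_matmul(a_chunks, b_chunks):
    """Row-sharded A @ B without gathering B on any device.

    a_chunks: (P, m, K) row chunks of A, one per emulated device.
    b_chunks: (P, k, N) row chunks of B, with K == P * k.
    Returns the (P, m, N) row chunks of the product.
    """
    P, m, K = a_chunks.shape
    k, N = b_chunks.shape[1], b_chunks.shape[2]
    # Split each local A chunk into P column blocks: a_blocks[p, q] is (m, k).
    a_blocks = a_chunks.reshape(P, m, P, k).transpose(0, 2, 1, 3)
    dev = jnp.arange(P)
    held = b_chunks  # the B chunk each device currently holds
    out = jnp.zeros((P, m, N), a_chunks.dtype)
    for step in range(P):
        # At this step, device p holds B chunk number (p + step) % P.
        src = (dev + step) % P
        a_sel = a_blocks[dev, src]  # matching column block of A, (P, m, k)
        out = out + jnp.einsum("pmk,pkn->pmn", a_sel, held)
        # Shuffle: every device receives the chunk held by its right neighbor.
        held = jnp.roll(held, -1, axis=0)
    return out


ka, kb = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(ka, (8, 12))
B = jax.random.normal(kb, (12, 5))
C_chunks = one_shuffle_matmul(A.reshape(4, 2, 12), B.reshape(4, 3, 5))
```

At every step each device only ever stores one chunk of B, which is the point of the scheme. Expressing this requires addressing "the chunk I currently hold" and "my device index", which is exactly the kind of per-device view discussed here.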
To express such computations, one possible approach would be to reinterpret a Sharded Device Array, such as one returned by xmap, in a way that is local to each core. This reinterpretation should not change the actual contents of the buffers on each device. Once the per-chunk xmap computations are done, we can switch back to the global view.
Here is the rough idea of what I have in mind. Of course the details need to be worked through and the new API should probably be more ergonomic and consistent.
I think it is possible to achieve at least some of what this API would provide by rebuilding the Sharded Device Array from the partitioning information that is already part of its bookkeeping.
But it would be great to have an official, fleshed-out version of this API. Does it make sense to add support for such local and global reinterpretations to JAX?