Dear Community,

I'm writing an astrophysical magnetohydrodynamics code (https://github.com/leo1200/jf1uids/) and am currently working on scaling it to multiple GPUs. As a test bench, I've written a simpler fluid code (https://github.com/leo1200/tinyfluids).

Communication is necessary for spatial finite differencing as well as for interface calculations, which require data from the cells to the left and right of each interface. To get the left and right states along the different spatial dimensions, I use slicing.
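A minimal sketch of the pattern I mean (simplified and illustrative; `left_right_states` is a hypothetical helper, not the actual tinyfluids function):

```python
from jax import lax

def left_right_states(q, axis):
    # q holds the cell-centred states; the interface between cells i and
    # i + 1 takes cell i as its left state and cell i + 1 as its right state
    n = q.shape[axis]
    left = lax.slice_in_dim(q, 0, n - 1, axis=axis)
    right = lax.slice_in_dim(q, 1, n, axis=axis)
    return left, right
```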
(See https://github.com/leo1200/tinyfluids/blob/main/tinyfluids/jax_tinyfluids/fluid.py#L385 for the actual implementation; for finite differencing alone, one might also use an appropriate convolution.)

Sharding the input onto multiple GPUs and relying on jit alone, however, yields very poor scaling, especially compared to my custom solution based on shard_map with halo exchange at the shard interfaces, as demonstrated below (running on four H100s); "speedup" refers to using just jit with sharded data, "shard mapped" to using a shard_map with halo exchange. From tests on other hardware as well, it seems that sharding the data + jit incurs a pretty significant memory-exchange overhead in this case, which limits scaling.

My question now is: is there a way to obtain better scaling without using shard_map or other custom solutions? Upstreaming shard mapping and halo exchange to jf1uids would make our MHD code more complex and less user-friendly, and at the end of the day our goal is to enable a broad community of astrophysicists to contribute to the physics.

I'm very much looking forward to suggestions, and it would be super cool if someone is interested in playing around with the codes a bit. Animating those (astro-)fluid flows can be quite satisfying (they're super beautiful), especially if the scaling is nice and the simulation itself is fast :-)

Best wishes
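PS: For context, the halo-exchange pattern I mean looks roughly like this (a simplified sketch for 1D sharding with periodic wrap-around, not the actual tinyfluids implementation):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

n_shards = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def _exchange_halos(q):
    # q is the local shard, sharded along axis 1; fetch one ghost cell
    # from each neighbouring shard (periodic) and pad the shard with it
    from_left = jax.lax.ppermute(
        q[:, -1:], axis_name="x",
        perm=[(i, (i + 1) % n_shards) for i in range(n_shards)])
    from_right = jax.lax.ppermute(
        q[:, :1], axis_name="x",
        perm=[(i, (i - 1) % n_shards) for i in range(n_shards)])
    return jnp.concatenate([from_left, q, from_right], axis=1)

exchange_halos = shard_map(_exchange_halos, mesh=mesh,
                           in_specs=P(None, "x"), out_specs=P(None, "x"))
```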
Replies: 1 comment
It looks like using jnp.roll instead of the slicing gets me to pretty much perfect scaling without any custom code.
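In case it's useful to others, the change is essentially the following (an illustrative sketch, not the exact tinyfluids code; the wrapped values at the domain boundary still need the usual boundary treatment):

```python
import jax.numpy as jnp

def left_right_states(q, axis):
    # cell i is the left state of interface i + 1/2; jnp.roll shifts
    # cell i + 1 into place as the right state (with periodic wrap-around)
    left = q
    right = jnp.roll(q, shift=-1, axis=axis)
    return left, right
```

Presumably XLA can lower the roll on sharded data to a small neighbour-to-neighbour transfer of the edge cells, rather than the larger data movement the sliced version triggered, which would explain the near-perfect scaling.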