Performance regression while scaling up GPUs when computing determinants. #28675
-
Hi everyone, I am currently trying to parallelize some code I have written over multiple GPUs. I followed the "Distributed arrays and automatic parallelization" guide in the documentation, and for the example network provided there I get a significant speedup when working with up to four GPUs. However, when I tried to do the same thing with my own code the speedup was much less significant, and it actually slowed down when going from 2 GPUs to 4 GPUs. From testing with a small example, the issue seems to be with parallelizing over determinants. Is there any known issue with parallelizing over determinants? Thanks in advance for any help. I have attached the script I was testing with, along with its output, below.
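
For context, a minimal sketch of the kind of setup involved is below. This is illustrative only and not the attached script: the mesh axis name, batch size, matrix size, and function name are all made up.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over however many GPUs are visible (1, 2, or 4 in the tests described above).
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# A batch of random square matrices, sharded along the batch axis.
batch, n = 8192, 64  # illustrative sizes only
x = jax.random.normal(jax.random.PRNGKey(0), (batch, n, n))
x = jax.device_put(x, NamedSharding(mesh, P("batch", None, None)))

@jax.jit
def batched_slogdet(m):
    # Each determinant is independent, so the batch dimension should,
    # in principle, scale across devices.
    return jnp.linalg.slogdet(m)

signs, logdets = batched_slogdet(x)
logdets.block_until_ready()
```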
-
Good question! It looks like `slogdet` is backed by either a QR or LU decomposition. These decompositions are themselves backed by calls to cuSOLVER on NVIDIA GPUs. Until recently (JAX v0.5.3, I think), these library calls didn't support sharding like this out of the box. If you try your experiment with the latest version of JAX, I predict that you will see the scaling you expect. For older versions of JAX like the one you're using, the recommendation would be to use `shard_map`.
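
In case it helps, a rough sketch of what the `shard_map` route could look like is below. It is only a sketch: it assumes the `jax.experimental.shard_map` import path used by older JAX releases, and the mesh axis name, batch size, and matrix size are placeholders.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=P("batch", None, None),  # split the batch of matrices across devices
         out_specs=P("batch"))             # sign and logdet both come back batch-sharded
def sharded_slogdet(m):
    # Each device calls the cuSOLVER-backed slogdet on its local slice of the
    # batch; no cross-device communication is needed.
    return jnp.linalg.slogdet(m)

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 64, 64))
signs, logdets = sharded_slogdet(x)
```

Since each matrix's determinant is computed independently, every device just runs its local slice of the batch through the usual cuSOLVER path, which sidesteps the lack of sharded linear-algebra support in older releases.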