GPU implementation of birth-death sampling million times slower? #18461
Replies: 2 comments
-
I think the issue here is that your algorithm is ill-suited to run on GPU, no matter how you express it. GPUs in general are very good at running implicitly parallelized vector operations over arrays; modern CPUs are OK in this regime, but will not match the speed of GPUs for such problems. GPUs are very bad at sequential operations over single values stored in memory, which is exactly the regime where CPUs excel.

Your computation deals sequentially with individual array elements, with no possible parallelization, because the input state of each step depends explicitly on the output state of the previous step. It falls squarely within the latter regime, so I'd expect that no matter how you express it (even by writing a custom-tuned CUDA kernel) you'll never get this to run as fast on GPU as it does on CPU.

You may still find a use for GPUs in this problem if, say, you hope to run many such procedures at once. Then each step of the sequential procedure could be parallelized across the batch to take advantage of the GPU hardware, and you could likely do much better than the CPU, which would essentially have to run each of the many sequences individually (see the sketch below). Does that make sense?
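To make the "many procedures at once" idea concrete, here is a minimal sketch of batching independent chains with `jax.vmap` over a `jax.lax.scan`. The update function is a toy stand-in, not the actual birth-death step, and all names are illustrative:

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=2)
def run_chains(keys, x0s, n_steps):
    """Run a batch of independent sequential chains in lockstep."""
    def one_chain(key, x0):
        def step(carry, _):
            x, key = carry
            key, subkey = jax.random.split(key)
            # Toy sequential update standing in for one birth-death step:
            # each step depends on the previous state, so time stays serial.
            x = x + 0.1 * jax.random.normal(subkey, x.shape)
            return (x, key), None
        (x, _), _ = jax.lax.scan(step, (x0, key), None, length=n_steps)
        return x
    # vmap parallelizes every time step across all chains at once.
    return jax.vmap(one_chain)(keys, x0s)

keys = jax.random.split(jax.random.PRNGKey(0), 1024)  # one key per chain
x0s = jnp.zeros((1024, 100))                          # 1024 chains, 100 particles each
finals = run_chains(keys, x0s, 500)
```

The scan stays sequential in time, but each of its steps now operates on the full `(1024, 100)` batch, which is the vectorized workload GPUs are built for.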
-
Yes it does! Thank you so much for your fast response. One strategy would then be to identify the parallelizable computations and run them on GPU, while performing the sequential logic on CPU. Are there any good practices for mixing computations in this way? Or do you think it shouldn't matter?
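One common way to express this split in JAX (a sketch, not from this thread; the toy computation and variable names are illustrative) is to commit arrays to a device with `jax.device_put`, since JAX runs each operation on the device its inputs are committed to, and to pull results back to the host for the sequential part:

```python
import numpy as np
import jax
import jax.numpy as jnp

gpu = jax.devices("gpu")[0]   # assumes a GPU is available

# Parallel phase: commit the data to the GPU; subsequent operations
# on this array execute there.
x = jax.device_put(jnp.arange(1_000_000, dtype=jnp.float32), gpu)
weights = jnp.exp(-0.5 * x**2)   # runs on the GPU

# Sequential phase: transfer the result to the host once, then run the
# step-by-step logic in plain NumPy/Python, where per-step latency is low.
w = np.asarray(weights)          # single device-to-host transfer
acc = 0.0
for wi in w[:10]:                # stand-in for the sequential logic
    acc += float(wi)
```

The main cost to watch is the host-device transfer itself, so it generally pays to move data between phases in a few large copies rather than many small ones.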
-
Hi all,
I am new to JAX and was hoping to get some feedback on my implementation of a birth-death sampling algorithm. My JIT-compiled GPU version is a million times slower than a standard NumPy implementation, according to my benchmarks!
The logic is fairly simple:
Input: the particle locations and, for each particle $n$, a quantity $\Lambda_n$ measuring the local excess or deficit of mass.

If there is an excess of mass, then $\Lambda_n > 0$ and that particle will "teleport" to another randomly chosen particle. On the other hand, if a particle is in a "deficit" region, $\Lambda_n < 0$, then another randomly chosen particle will teleport to that location.
Note: The algorithm requires keeping track of which particles have already teleported.
In standard NumPy, the logic is a sequential loop over the particles, where I use a mask to keep track of which particles have already been teleported.
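A minimal sketch of that kind of loop (the exact original code is not shown here; the names `x`, `Lam`, and `alive` are illustrative):

```python
import numpy as np

def birth_death_step(x, Lam, rng):
    """One sequential pass over the particles.

    Lam[n] > 0 (excess): particle n teleports to a randomly chosen live particle.
    Lam[n] < 0 (deficit): a randomly chosen live particle teleports onto x[n].
    `alive` tracks which particles have not yet teleported this pass.
    """
    n_particles = len(x)
    alive = np.ones(n_particles, dtype=bool)
    for n in range(n_particles):
        if not alive[n]:
            continue
        # Candidate partners: live particles other than n.
        candidates = np.flatnonzero(alive)
        candidates = candidates[candidates != n]
        if candidates.size == 0:
            break
        m = rng.choice(candidates)
        if Lam[n] > 0:          # excess: n jumps to m's location
            x[n] = x[m]
            alive[n] = False
        elif Lam[n] < 0:        # deficit: m jumps to n's location
            x[m] = x[n]
            alive[m] = False
    return x
```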
This is my best attempt at writing a compilable version of the previous code. I keep track of which particles are alive/dead with a binary array of 1s and 0s, and use it as an unnormalized probability mass function in `jax.random.choice` to select particles. This is computationally inefficient compared to the first approach, but I am stuck on how else to perform this operation. Any thoughts or suggestions?
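A sketch of this kind of compiled loop (again, the exact original code is not shown; the update rule and names are illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def birth_death_step(key, x, lam):
    """Compilable analogue of the NumPy loop. `alive` is a 1/0 float mask
    turned into a probability mass function for jax.random.choice."""
    n = x.shape[0]

    def body(i, carry):
        key, x, alive = carry
        key, subkey = jax.random.split(key)
        # Exclude particle i itself, then normalize the mask into a pmf.
        p = alive.at[i].set(0.0)
        p = jnp.where(p.sum() > 0, p / p.sum(), jnp.ones(n) / n)
        m = jax.random.choice(subkey, n, p=p)
        active = alive[i] > 0
        excess = active & (lam[i] > 0)
        deficit = active & (lam[i] < 0)
        # excess: x[i] <- x[m]; deficit: x[m] <- x[i]; else unchanged.
        x = x.at[i].set(jnp.where(excess, x[m], x[i]))
        x = x.at[m].set(jnp.where(deficit, x[i], x[m]))
        alive = alive.at[i].set(jnp.where(excess, 0.0, alive[i]))
        alive = alive.at[m].set(jnp.where(deficit, 0.0, alive[m]))
        return key, x, alive

    _, x, _ = jax.lax.fori_loop(0, n, body, (key, x, jnp.ones(n)))
    return x
```

Because each `fori_loop` iteration depends on the previous one, the whole pass executes as a long chain of tiny sequential operations on the device.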