Jaxley parallelization/speedup questions #626
Replies: 1 comment 3 replies
Hi Elena, thanks for reaching out! The increase in runtime
Hi Jaxley group, I've been working with Jaxley to optimize single-neuron models and have noticed that simulation time increases substantially with any increase in simulation length or number of compartments, as well as with batch size in vmapped simulations. Unexpectedly, vmapped simulations on CPU have also been faster than the same simulations on a GPU. I assume the faster CPU runtime is because my batch size is below what would be needed to see an improvement on a GPU, and I had assumed that the runtime increases I've been seeing on the CPU were due to resource limitations. However, when I look at my overall CPU usage during training, it consistently sits below 10%. I have a 20-core system, and during a simulation a small subset of cores shows variable usage between 1% and 30%, while a single core sits much higher, near 100%. So I am hoping that improving parallelization across cores would speed up the runtime, especially when simulating multiple stimuli simultaneously but independently.
Are there strategies compatible with the Jaxley framework (especially with the Cell object class) beyond jit and vmap that you would recommend I try to improve simulation speed and/or multi-core recruitment? Some things I looked at suggested that using jax.pmap (instead of vmap) would help, but it doesn't seem to be compatible with the Jaxley Cell class.
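For context, here is a minimal sketch of the generic JAX mechanism behind the pmap suggestion. It does not use Jaxley at all: `simulate` is a hypothetical toy leaky integrator standing in for one single-stimulus simulation, and whether the same pattern composes with a Jaxley Cell is exactly the open question above. By default JAX exposes the CPU as a single device, so the sketch uses the XLA flag `--xla_force_host_platform_device_count` to split the host into several devices. The device count of 4 is an arbitrary choice.

```python
import os

# The flag must be set before JAX is imported, otherwise it is ignored.
# 4 host devices is an arbitrary choice for illustration.
_flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (_flags + " --xla_force_host_platform_device_count=4").strip()
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import jax
import jax.numpy as jnp

DT = 0.1       # step size in ms (matches the 0.1 ms quoted in this post)
N_STEPS = 500  # 50 ms of simulated time


def simulate(stim_amp):
    """Hypothetical stand-in for one simulation: a toy leaky integrator."""

    def step(v, _):
        v_new = v + DT * (-v + stim_amp) / 10.0
        return v_new, v_new

    _, trace = jax.lax.scan(step, jnp.zeros(()), None, length=N_STEPS)
    return trace


n_dev = jax.local_device_count()  # however many devices JAX actually sees
per_dev = 2                       # stimuli handled by vmap on each device

# Leading axis is split across devices by pmap, second axis is batched by vmap.
amps = jnp.linspace(0.1, 1.0, n_dev * per_dev).reshape(n_dev, per_dev)

parallel_sim = jax.pmap(jax.vmap(simulate))
traces = parallel_sim(amps)  # shape: (n_dev, per_dev, N_STEPS)

print("devices:", n_dev, "output shape:", traces.shape)
```

The leading axis of the input has to match the number of devices, which is why the batch is reshaped to `(n_dev, per_dev)` rather than passed flat. This only shows how independent stimuli can be spread over host devices in plain JAX. It is not a claim about what limits the runtime reported here.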
Thanks!
Some numbers:
On CPU, running a vmapped simulation with batch size 4 of a cell with 165 compartments and 16 trainable parameters takes about 0.05 s for a 50 ms simulation at a 0.1 ms sampling interval. The same simulation with a duration of 1.4 s takes about 6.5 s. Running these simulations on CPU is roughly 3x faster than running them on an NVIDIA GPU (with vmapping).
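To put these numbers in per-step terms, here is a short back-of-the-envelope calculation. It assumes one solver step per 0.1 ms sample, which is not stated above, and it uses only the figures quoted in this post. Under that assumption the longer run has 28 times more steps but takes about 130 times longer, so the cost per step is not constant between the two runs.

```python
# Figures quoted in the post; DT_MS assumes one solver step per 0.1 ms sample.
DT_MS = 0.1

# name -> (simulated time in ms, measured wall-clock time in s)
runs = {
    "short": (50.0, 0.05),
    "long": (1400.0, 6.5),
}

n_steps = {}
ms_per_step = {}
for name, (t_ms, wall_s) in runs.items():
    n_steps[name] = round(t_ms / DT_MS)
    ms_per_step[name] = 1e3 * wall_s / n_steps[name]
    print(f"{name}: {n_steps[name]} steps, {ms_per_step[name]:.3f} ms wall time per step")

step_ratio = n_steps["long"] / n_steps["short"]
time_ratio = runs["long"][1] / runs["short"][1]
print(f"steps grow {step_ratio:.0f}x, wall time grows {time_ratio:.0f}x")
```

The calculation cannot say whether the quoted timings include JIT compilation. It only makes the superlinear growth explicit.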
Simulation code: