Replies: 1 comment
In mpi4jax/mpi4jax#275, it has been pointed out to me that I can use …
Hello,
First of all, thank you for the great tool!
I am trying to parallelize our optimization code, which uses JAX for automatic differentiation and jit. My primary aim was to use multiple devices to avoid the memory problems we hit during certain calculations (usually the gradient of multi-objective cost functions). I think I more or less achieved that. Unfortunately, in doing so I lost the ability to jit the main function that calculates the Jacobian (because the computations occur on multiple devices). A simplified version of my current implementation looks like this:
I skipped the parts where I declare these classes as proper pytrees, but I use `register_pytree_node()` to be able to jit the methods. We use classes like these as the inputs of our optimizer, which packages everything using `ObjectiveFunctionParallel`. I distribute all the data held by the objectives to the specified devices. Re-assigning the optimizable of each objective with the original one that lives on the default device is necessary for the rest of our optimizer.
With these design considerations in mind, I would like to parallelize this specific portion:

Note that in the general case, `self.objectives` are completely different classes with different functions and output shapes, so I cannot simply use `pmap`. Since the data is already transferred to the corresponding device, all I need is a 'spark' that will ignite `obj.jac_error` on each GPU. MPI (with mpi4jax) seems like a good option; I previously used OpenMP's `#pragma omp parallel for` in C++, but I didn't see a `jax`-compatible way of doing this. If there is a native `jax` way to do it, I would prefer that, to reduce the complexity of our package's installation instructions. I know `jax` parallelism is primarily for SPMD, but I wanted to ask anyway. Sorry I couldn't make a shorter explanation of the problem!
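One native-JAX pattern I have been considering (a sketch under my own assumptions, with hypothetical `jac_a`/`jac_b` standing in for the heterogeneous `obj.jac_error` methods): JAX dispatch is asynchronous, so if each objective's inputs are already committed to its own device, calling every jitted Jacobian before blocking on any result lets the devices compute concurrently:

```python
import jax
import jax.numpy as jnp

# Hypothetical heterogeneous objectives: different functions, different
# output shapes, so pmap does not apply.
def jac_a(x):
    return jax.jacfwd(lambda v: jnp.sin(v))(x)        # shape (n, n)

def jac_b(x):
    return jax.jacfwd(lambda v: jnp.sum(v ** 2))(x)   # shape (n,)

jitted_jacs = [jax.jit(jac_a), jax.jit(jac_b)]

def all_jacobians(per_device_inputs):
    # Dispatch every computation first: each call returns immediately,
    # and work placed on different devices (via jax.device_put) overlaps.
    results = [f(x) for f, x in zip(jitted_jacs, per_device_inputs)]
    # Only then block until every device has finished.
    return [r.block_until_ready() for r in results]
```

On a multi-GPU machine the inputs would be `device_put` onto distinct devices beforehand; on a single device this degenerates to sequential execution but still runs.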
I also described what I am trying to do here: mpi4jax/mpi4jax#275.
Best regards,