Multi-host parallelization with GPUs #18659
Unanswered
IrishWhiskey asked this question in Q&A
Replies: 2 comments 1 reply
-
You need the EFA drivers and certain magic bits for the stack to work properly on AWS. The JAX container from jax-toolbox should work on AWS. cc: @yhtang
-
I have the same question.
-
Hello,
I'm trying out JAX multi-host parallelization on AWS GPU instances, but I can't get it to work. What am I doing wrong? The process I followed is described below.
I created two EC2 instances (p3.2xlarge) and configured the network so that they can communicate with each other. After installing the required dependencies, I ran the following script on both instances:

The script is not exactly the same on both: on the second instance I set process_id to 1 (the coordinator address refers to the first instance).

From the terminal output I can see that everything works fine up to print(xs), but the last line causes an error. Below are the outputs I got from both instances.

First instance:

Second instance:
I'm using Python 3.10.12, jax 0.4.20 and jaxlib 0.4.20+cuda11.cudnn86.