Thanks @jakevdp for editing. I built the latest TPU version from source and ran these commands, one on each host:

```
python3 nvidia_gpu_pjit.py --server_addr="10.128.0.25:1456" --num_hosts=2 --host_idx=0
python3 nvidia_gpu_pjit.py --server_addr="10.128.0.25:1456" --num_hosts=2 --host_idx=1
```

Here are the results. First host:

```
ubuntu@t1v-n-207dbaa4-w-0:~$ python3 nvidia_gpu_pjit.py --server_addr="10.128.0.25:1456" --num_hosts=2 --host_idx=0
['t1v-n-207dbaa4-w-0:2664546204052242694:10.128.0.25']
Traceback (most recent call last):
  File "nvidia_gpu_pjit.py", line 50, in <module>
    app.run(main)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "nvidia_gpu_pjit.py", line 25, in main
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/distributed.py", line 89, in initialize
    raise RuntimeError('Number of workers does not equal the number of '
RuntimeError: Number of workers does not equal the number of processes. Auto detecting process_id is not possible.Please pass process_id manually.
```

And the other host:

```
ubuntu@t1v-n-78ea0c76-w-0:~$ python3 nvidia_gpu_pjit.py --server_addr="10.128.0.25:1456" --num_hosts=2 --host_idx=1
['t1v-n-78ea0c76-w-0:8021122269074014007:10.128.0.26']
Traceback (most recent call last):
  File "nvidia_gpu_pjit.py", line 50, in <module>
    app.run(main)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "nvidia_gpu_pjit.py", line 25, in main
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/jax/_src/distributed.py", line 89, in initialize
    raise RuntimeError('Number of workers does not equal the number of '
RuntimeError: Number of workers does not equal the number of processes. Auto detecting process_id is not possible.Please pass process_id manually.
```
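The script itself is not shown in the thread. Assuming it simply forwards its three flags to `jax.distributed.initialize` positionally, as the traceback's line 25 suggests, here is a sketch that passes them by keyword so the mapping to `coordinator_address`, `num_processes` and `process_id` is explicit. `distributed_init_kwargs` and `init_distributed` are hypothetical helper names. This only makes the call clearer; it does not change how the TPU runtime sees the two VMs, so it is not presented as a fix for the error.

```python
import jax


def distributed_init_kwargs(server_addr: str, num_hosts: int, host_idx: int) -> dict:
    """Map the script's flags onto jax.distributed.initialize keyword arguments.

    Hypothetical helper: flag names are taken from the command lines in this
    thread; the keys are parameter names of jax.distributed.initialize.
    """
    if num_hosts < 1:
        raise ValueError(f"num_hosts must be >= 1, got {num_hosts}")
    if not 0 <= host_idx < num_hosts:
        raise ValueError(f"host_idx must be in [0, {num_hosts}), got {host_idx}")
    return {
        "coordinator_address": server_addr,  # "ip:port" reachable from every host
        "num_processes": num_hosts,          # total number of participating processes
        "process_id": host_idx,              # this process's rank, 0-based
    }


def init_distributed(server_addr: str, num_hosts: int, host_idx: int) -> None:
    # Every participating process makes this same call with its own host_idx.
    jax.distributed.initialize(**distributed_init_kwargs(server_addr, num_hosts, host_idx))
```

With the values from this thread, the first VM would call `init_distributed("10.128.0.25:1456", num_hosts=2, host_idx=0)` and the second the same with `host_idx=1`.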
I believe this is not possible. Multi-host initialisation only works for TPU pods, such as TPU v3-16, v4-16, etc.
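For comparison, here is a minimal sketch of the pod-slice setup this reply describes, assuming a recent JAX release in which `jax.distributed.initialize()` can be called with no arguments on Cloud TPU and picks up the coordinator and process id from the TPU environment. The same script is run on every host of the slice; `describe_topology` is a hypothetical helper name.

```python
import jax


def describe_topology() -> dict:
    # Summarise how this process sees the (possibly multi-host) system.
    return {
        "process_index": jax.process_index(),
        "process_count": jax.process_count(),
        "local_device_count": jax.local_device_count(),
        "device_count": jax.device_count(),
    }


def main() -> None:
    # Assumption: on a Cloud TPU pod slice, recent JAX versions fill in
    # coordinator_address, num_processes and process_id automatically.
    jax.distributed.initialize()
    info = describe_topology()
    print(
        f"process {info['process_index']} of {info['process_count']}: "
        f"{info['local_device_count']} local / {info['device_count']} global devices"
    )
```

Each host would call `main()`; one common way to launch the same command on all hosts of a slice is `gcloud compute tpus tpu-vm ssh` with `--worker=all`.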
Recently, I found that https://github.com/google/jax/blob/main/jax/_src/distributed.py#L35-L36 adds TPU support to `jax.distributed.initialize`. Does it currently support multiple TPU hosts? I also wonder how to install the JAX nightly build so I can keep up with these features.

Here are the details: I have launched two `tpu:v2-8` VMs. How can I use `distributed.initialize` to connect the two hosts?