Multi-Host training checkpointing #21290
Replies: 2 comments
-
I managed to solve it.
-
For those googling this error: I also ran into it when I had a threading issue. Specifically, I was inadvertently launching TPU JAX kernels from a background data-loading thread. Make sure you always launch the same TPU kernel on all workers at the same time; both the original error and mine come down to that.
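A minimal sketch of the difference, in case it helps; the data-loading details below are placeholders, not a real pipeline:

```python
import queue
import threading

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def preprocess(batch):
    # Any jitted TPU computation; in a multi-host program every process
    # must issue the same programs in the same order.
    return batch / jnp.max(batch)


def host_batches():
    # Placeholder for real host-side (CPU) data loading.
    while True:
        yield np.ones((8, 128), dtype=np.float32)


def bad_loader(out_q):
    # Anti-pattern: the loader thread calls the jitted function, so this
    # process can enqueue TPU programs that the other hosts never enqueue,
    # which desynchronises the pod.
    for batch in host_batches():
        out_q.put(preprocess(batch))


def good_loader(out_q):
    # Safer: the loader thread only hands over host (NumPy) arrays.
    for batch in host_batches():
        out_q.put(batch)


def train(num_steps=10):
    q = queue.Queue(maxsize=2)
    threading.Thread(target=good_loader, args=(q,), daemon=True).start()
    for _ in range(num_steps):
        batch = preprocess(q.get())  # all TPU work stays on the main thread
        # ... training step would go here ...
```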
-
Hi,
I am trying to do multi-host training on a TPU pod. I managed to run back-propagation, but I got stuck on saving checkpoints, specifically on saving the distributed flax train state.
I distributed the initialised state using the following:
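For reference, one common way to replicate an initialised flax `TrainState` onto every device of a pod is `jax.device_put` with a fully replicated `NamedSharding`. This is only a sketch, assuming a recent JAX version; the toy model, optimiser, and shapes are placeholders:

```python
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax.training import train_state
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

model = nn.Dense(features=128)   # placeholder model
tx = optax.adam(1e-3)            # placeholder optimiser

variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 128)))
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables["params"], tx=tx
)

# A 1-D mesh over every device in the pod; P() means "fully replicated",
# which works here because every process initialises identical params
# from the same PRNG key.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
replicated = NamedSharding(mesh, P())
state = jax.device_put(state, replicated)
```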
After a few updates, I tried to save the state but failed to gather the parameters onto one process. The things I tried:
Do you have any advice on how to get the parameters on one host and save them on disk with orbax?
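For anyone else hitting this, a hedged sketch of two possible routes, assuming a recent `orbax.checkpoint` (the `ocp.args` API) and a `state` like the one sketched above; `ckpt_dir` is a placeholder and on a pod must point at storage every host can reach (e.g. a GCS bucket):

```python
import pickle

import jax
import orbax.checkpoint as ocp
from jax.experimental import multihost_utils

ckpt_dir = "/tmp/ckpts"   # placeholder; on a pod use shared storage, e.g. gs://...
step = int(state.step)    # `state` as in the sketch above

# Route 1: save the distributed state directly with orbax. Every process
# must make the same save() call; orbax coordinates the multi-host write.
mngr = ocp.CheckpointManager(ckpt_dir)
mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()

# Route 2: gather a host-local NumPy copy of the params (a collective, so
# all processes must call it), then write it from process 0 with any serialiser.
params_np = multihost_utils.process_allgather(state.params)
if jax.process_index() == 0:
    with open(f"params_{step}.pkl", "wb") as f:
        pickle.dump(params_np, f)
```

Route 1 is usually preferable because orbax handles the multi-host write itself; route 2 only makes sense when you specifically want a single host-local copy on one process.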
I also tried saving directly with the legacy API:
but the error that I get is:
Regards,
Rares