-
Hi, thanks for the great work. We are running the sample script:

We run the above script with gcloud:

When printing out the batch on each host/process, we get the following logs:

This seems to imply that we are getting the global batch (8 samples) on each host, even though we expected each host to see only its own shard of the batch.
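For reference, the way we print the batch is roughly this (a simplified sketch, not our exact script; `train_device_loader` stands in for the loader built by the sample script):

```python
import torch_xla.runtime as xr

for step, batch in enumerate(train_device_loader):
    # Log each host's view of the first batch
    # (assuming an input_ids-style dict batch).
    print(f"host={xr.process_index()} step={step} "
          f"input_ids.shape={tuple(batch['input_ids'].shape)}")
    break
```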
-
Hi @weirayao, thanks for checking out our code base. I think what you're seeing is consistent with the distributed execution model of torchprime, although the behavior is not intuitive. Here's what's happening:

- `DataLoader` outputs local tensors stored on the `cpu`.
- `MpDeviceLoader` outputs sharded global tensors stored on the TPU. If we print its output, that would give the appearance of duplication.
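To make that concrete, here's a minimal standalone sketch of the two loaders (illustrative only, not torchprime's actual training code; the dataset and mesh here are made up):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as pl
from torch.utils.data import DataLoader, TensorDataset

xr.use_spmd()  # opt in to the SPMD execution model

dataset = TensorDataset(torch.randn(64, 16))  # made-up dataset
loader = DataLoader(dataset, batch_size=8)    # plain PyTorch DataLoader

# 1D "data" mesh spanning every TPU chip in the job (all hosts).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

device = xm.xla_device()
device_loader = pl.MpDeviceLoader(
    loader,
    device,
    # Shard dim 0 (the batch dim) of every input across the 'data' axis.
    input_sharding=xs.ShardingSpec(mesh, ('data', None)),
)

(cpu_batch,) = next(iter(loader))         # local tensor, device == cpu
(tpu_batch,) = next(iter(device_loader))  # sharded global tensor on the TPU
print(cpu_batch.device, tuple(cpu_batch.shape))
print(tpu_batch.device, tuple(tpu_batch.shape))  # global shape, even though
                                                 # each chip holds only a shard
```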
We can get a better understanding of this process by instrumenting the data loaders. I've uploaded an example at `torchprime/torch_xla_models/train.py` (lines 178 to 219 at `21633ae`). With that branch checked out, when I run
to train on two 4-chip TPU VMs, I see the following logs:
This shows that the `DataLoader` on each host yields its own local CPU tensor, while the `MpDeviceLoader` presents the same global tensor on every host. When you print that tensor, you see the global shape, which looks like duplication. A lot of this is hidden by the GSPMD execution model (https://docs.pytorch.org/xla/master/perf/spmd_basic.html). In this model, a tensor's data may be backed by several TPU chips including those from other hosts, but a layer of abstraction presents you the global shape as if the tensor lives locally. This is the same reason why the loss values printed by different hosts are the same: when we print the loss, all the hosts will transfer the same global scalar from the device and log an identical value. LMK if this answer makes sense.
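P.S. If you want to verify the sharding yourself, here is a rough sketch of how you might inspect a batch and the loss on every host (the `describe` helper is hypothetical and is not what's in `train.py`; `tpu_batch` and `loss` stand in for whatever your training loop produces):

```python
import torch_xla
import torch_xla.runtime as xr

def describe(tag, tensor):
    # Hypothetical helper: each host logs its own view of a tensor.
    host = xr.process_index()
    print(f"[host {host}] {tag}: shape={tuple(tensor.shape)} device={tensor.device}")
    if tensor.device.type == "xla":
        # Internal torch_xla call returning the HLO sharding annotation,
        # e.g. "{devices=[8,1]0,1,2,3,4,5,6,7}" for a batch sharded 8 ways.
        print(f"[host {host}] sharding={torch_xla._XLAC._get_xla_sharding_spec(tensor)}")

# `tpu_batch` and `loss` come from your training loop.
describe("batch", tpu_batch)  # global shape on every host, sharded across chips
# Printing the loss transfers the same global scalar to every host,
# which is why all hosts log identical numbers.
print(f"[host {xr.process_index()}] loss = {loss.item():.4f}")
```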
-
Hi @weirayao, I turned this question into a discussion thread. Feel free to ask more questions here and/or mark my answer as accepted.