
Commit e341ff0

8740 add single processing to getting started instructions (#8875)
1 parent 7a3c051 commit e341ff0


README.md

Lines changed: 50 additions & 6 deletions
@@ -123,23 +123,67 @@ Our github contains many useful docs on working with different aspects of PyTorch XLA
## Getting Started

The following are guides for two modes:

- Single process: one Python interpreter controls a single GPU/TPU at a time
- Multi process: N Python interpreters are launched, corresponding to the N GPUs/TPUs found on the system

Another mode is SPMD, where one Python interpreter controls all N GPUs/TPUs found on the system. Multi processing is more complex and is not compatible with SPMD. This tutorial does not dive into SPMD; for more on that, check our [SPMD guide](https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_basic.md).
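
As a quick sanity check before choosing a mode, you can inspect which XLA devices the current process can see. The snippet below is a minimal sketch; `torch_xla.devices()` is assumed from recent releases and is not taken from this commit (only `torch_xla.device()` appears in the README changes).

```python
# Minimal sketch: list the XLA devices visible to this process.
# torch_xla.devices() is assumed from recent torch_xla releases.
import torch_xla

if __name__ == '__main__':
    print(torch_xla.device())   # default XLA device for this process, e.g. xla:0
    print(torch_xla.devices())  # all XLA devices addressable by this process
```
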
### Simple single process

To update your existing training loop, make the following changes:

```diff
+import torch_xla

 def train(model, training_data, ...):
   ...
   for inputs, labels in train_loader:
+    with torch_xla.step():
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

+  torch_xla.sync()
   ...

 if __name__ == '__main__':
   ...
+  # Move the model parameters to your XLA device
+  model.to('xla')
   train(model, training_data, ...)
   ...
```

The changes above should get your model to train on the TPU.

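As a rough end-to-end sketch of the pattern above (the two-layer model, synthetic data, and hyperparameters are placeholders, not recommendations from this commit):

```python
# Minimal single-process sketch based on the diff above.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch_xla


def train(model, train_loader, loss_fn, optimizer, epochs=2):
    for epoch in range(epochs):
        for inputs, labels in train_loader:
            with torch_xla.step():
                # Transfer the batch to the XLA device (asynchronous).
                inputs, labels = inputs.to('xla'), labels.to('xla')
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
        # Dispatch any work still pending at the end of the epoch.
        torch_xla.sync()


if __name__ == '__main__':
    # Synthetic data: 256 samples, 16 features, 4 classes.
    inputs = torch.randn(256, 16)
    labels = torch.randint(0, 4, (256,))
    train_loader = DataLoader(TensorDataset(inputs, labels), batch_size=32)

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    # Move the model parameters to the XLA device before building the optimizer.
    model = model.to('xla')

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    train(model, train_loader, loss_fn=nn.CrossEntropyLoss(), optimizer=optimizer)
```
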
### Multi processing

To update your existing training loop, make the following changes:

```diff
-import torch.multiprocessing as mp
+import torch_xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(torch_xla.device())

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
```
@@ -150,8 +194,8 @@ To update your existing training loop, make the following changes:
```diff
 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # torch_xla.launch automatically selects the correct world size
+  torch_xla.launch(_mp_fn, args=())
```
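
Put together, a minimal multi-process script following this pattern could look like the sketch below. `torch_xla.launch` spawns one process per device and calls `_mp_fn` in each; the model, synthetic data, and hyperparameters are placeholders, not taken from this commit.

```python
# Minimal multi-process sketch based on the diff above.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import torch_xla
import torch_xla.core.xla_model as xm


def _mp_fn(index):
    device = torch_xla.device()

    # Synthetic per-process data: 256 samples, 16 features, 4 classes.
    inputs = torch.randn(256, 16)
    labels = torch.randint(0, 4, (256,))
    train_loader = DataLoader(TensorDataset(inputs, labels), batch_size=32)

    # Move the model parameters to this process's XLA device.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    for inputs, labels in train_loader:
        with torch_xla.step():
            # Transfer data to the XLA device. This happens asynchronously.
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

    xm.master_print('done')


if __name__ == '__main__':
    # torch_xla.launch automatically selects the correct world size.
    torch_xla.launch(_mp_fn, args=())
```

Note that this sketch trains an independent model replica in every process; for true data parallelism you would also shard the data across processes and synchronize gradients, for example via the `DistributedDataParallel` changes described next.
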

If you're using `DistributedDataParallel`, make the following changes:
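
The DDP-specific changes themselves fall outside the hunks shown in this commit. As a rough, hedged sketch of one common DDP-on-XLA setup (the `xla` process-group backend, the `xla://` init method, and the `gradient_as_bucket_view` flag are assumptions from common usage, not taken from this commit):

```python
# Hedged sketch of DistributedDataParallel on XLA devices.
# Backend name ('xla'), init_method ('xla://'), and gradient_as_bucket_view
# are assumptions; consult the DDP guide for the authoritative instructions.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' backend


def _mp_fn(index):
    # Join the per-device process group.
    dist.init_process_group('xla', init_method='xla://')

    # Placeholder model; gradients are all-reduced by DDP during backward.
    model = torch.nn.Linear(16, 4).to(torch_xla.device())
    ddp_model = DDP(model, gradient_as_bucket_view=True)
    # ... train ddp_model with your usual loop ...


if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
```
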
