@@ -123,23 +123,67 @@ Our github contains many useful docs on working with different aspects of PyTorc
 
## Getting Started
 
+ Below are guides for two modes:
+
+ - Single process: one Python interpreter controlling a single GPU/TPU at a time
+ - Multi process: N Python interpreters are launched, corresponding to the N GPU/TPUs
+   found on the system
+
+ Another mode is SPMD, where one Python interpreter controls all N GPU/TPUs found on
+ the system. Multi processing is more complex and is not compatible with SPMD. This
+ tutorial does not dive into SPMD. For more on that, check our
+ [SPMD guide](https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_basic.md).
+
+ ### Simple single process
+
+ To update your existing training loop, make the following changes:
+
+ ```diff
+ + import torch_xla
+
+ def train(model, training_data, ...):
+   ...
+   for inputs, labels in train_loader:
+ +     with torch_xla.step():
+ +       inputs, labels = inputs.to('xla'), labels.to('xla')
+       optimizer.zero_grad()
+       outputs = model(inputs)
+       loss = loss_fn(outputs, labels)
+       loss.backward()
+       optimizer.step()
+
+ +   torch_xla.sync()
+   ...
+
+ if __name__ == '__main__':
+   ...
+ +   # Move the model parameters to your XLA device
+ +   model.to('xla')
+   train(model, training_data, ...)
+   ...
+ ```
+
+ The changes above should get your model to train on the TPU.
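+
+ For reference, here is the full loop with those changes applied, as a runnable
+ sketch. The toy model, optimizer, and dataset below are hypothetical placeholders,
+ not part of the original example:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+
+ import torch_xla
+
+ # Hypothetical toy model and data, just to make the sketch runnable.
+ model = nn.Linear(10, 2)
+
+ # Move the model parameters to the XLA device before creating the optimizer.
+ model.to('xla')
+
+ dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
+ train_loader = DataLoader(dataset, batch_size=8)
+
+ loss_fn = nn.CrossEntropyLoss()
+ optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+ for inputs, labels in train_loader:
+     # step() delimits one training step; the pending lazy operations are
+     # materialized when the context exits.
+     with torch_xla.step():
+         inputs, labels = inputs.to('xla'), labels.to('xla')
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = loss_fn(outputs, labels)
+         loss.backward()
+         optimizer.step()
+
+ torch_xla.sync()
+ ```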
+
+ ### Multi processing
+
To update your existing training loop, make the following changes:
 
```diff
- import torch.multiprocessing as mp
- + import torch_xla as xla
+ + import torch_xla
+ import torch_xla.core.xla_model as xm
 
def _mp_fn(index):
  ...
 
+   # Move the model parameters to your XLA device
- +   model.to(xla.device())
+ +   model.to(torch_xla.device())
 
  for inputs, labels in train_loader:
- +     with xla.step():
+ +     with torch_xla.step():
+       # Transfer data to the XLA device. This happens asynchronously.
- +       inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
+ +       inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
      optimizer.zero_grad()
      outputs = model(inputs)
      loss = loss_fn(outputs, labels)
@@ -150,8 +194,8 @@ To update your existing training loop, make the following changes:
 
if __name__ == '__main__':
-   mp.spawn(_mp_fn, args=(), nprocs=world_size)
- +   # xla.launch automatically selects the correct world size
- +   xla.launch(_mp_fn, args=())
+ +   # torch_xla.launch automatically selects the correct world size
+ +   torch_xla.launch(_mp_fn, args=())
```
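+
+ Applied end to end, the multi process version looks like the following runnable
+ sketch. The toy model and dataset are hypothetical placeholders, not part of the
+ original example:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+
+ import torch_xla
+
+
+ def _mp_fn(index):
+     # Each of the N launched processes runs this function with its own index.
+     device = torch_xla.device()
+
+     # Hypothetical toy model and data, just to make the sketch runnable.
+     model = nn.Linear(10, 2).to(device)
+     dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
+     train_loader = DataLoader(dataset, batch_size=8)
+
+     loss_fn = nn.CrossEntropyLoss()
+     optimizer = optim.SGD(model.parameters(), lr=0.01)
+
+     for inputs, labels in train_loader:
+         with torch_xla.step():
+             # Transfer data to the XLA device. This happens asynchronously.
+             inputs, labels = inputs.to(device), labels.to(device)
+             optimizer.zero_grad()
+             outputs = model(inputs)
+             loss = loss_fn(outputs, labels)
+             loss.backward()
+             optimizer.step()
+
+
+ if __name__ == '__main__':
+     # torch_xla.launch automatically selects the correct world size.
+     torch_xla.launch(_mp_fn, args=())
+ ```
+
+ As in the diff above, each process optimizes its copy of the model independently;
+ see the `DistributedDataParallel` changes below to synchronize gradients across
+ processes.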
 
If you're using `DistributedDataParallel`, make the following changes: