@@ -28,11 +28,11 @@ Key concepts:
28
28
* Each process has a
29
29
distinct set of * local* devices it can address. The * global* devices are the set
30
30
of all devices across all processes.
31
- * Use standard JAX parallelism APIs like {func}` ~jax.pmap ` and
32
- {func}` ~jax.experimental.maps.xmap ` . Each process “sees” * local* input and
31
+ * Use standard JAX parallelism APIs like {func}` ~jax.pmap ` and
32
+ {func}` ~jax.experimental.maps.xmap ` . Each process “sees” * local* input and
33
33
output to parallelized functions, but communication inside the computations
34
34
is * global* .
35
- * Make sure all processes run the same parallel computations in the same
35
+ * Make sure all processes run the same parallel computations in the same
36
36
order.
37
37
38
38
### Launching JAX processes
@@ -106,13 +106,13 @@ only launch computations on the 8 TPU cores attached directly to that host (see
106
106
the
107
107
[ Cloud TPU System Architecture] ( https://cloud.google.com/tpu/docs/system-architecture )
108
108
documentation for more details). You can see a process’s local devices via
109
- {func}` jax.local_devices() ` .
109
+ {func}` jax.local_devices() ` .
110
110
111
111
** The * global* devices are the devices across all processes.** A computation can
112
112
span devices across processes and perform collective operations via the direct
113
113
communication links between devices, as long as each process launches the
114
114
computation on its local devices. You can see all available global devices via
115
- {func}` jax.devices() ` . A process’s local devices are always a subset of the
115
+ {func}` jax.devices() ` . A process’s local devices are always a subset of the
116
116
global devices.
117
117
118
118
### Running multi-process computations
0 commit comments