
Commit fe44afc

Committed by jax authors
Merge pull request #20161 from rajasekharporeddy:doc_typos
PiperOrigin-RevId: 614813533
2 parents 0bd7070 + 61c64c1

13 files changed, +21 -21 lines changed

docs/jax-101/06-parallelism.ipynb

Lines changed: 1 addition & 1 deletion
@@ -410,7 +410,7 @@
 "## Communication between devices\n",
 "\n",
 "The above is enough to perform simple parallel operations, e.g. batching a simple MLP forward pass across several devices. However, sometimes we need to pass information between the devices. For example, perhaps we are interested in normalizing the output of each device so they sum to 1.\n",
-"For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through `axis_name` argument, and then refer to it when calling the op. Here's how to do that:"
+"For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through the `axis_name` argument, and then refer to it when calling the op. Here's how to do that:"
 ]
 },
 {
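
For context, the `axis_name` + collective-op pattern described in the hunk above looks roughly like the following sketch; the array shape and the axis name `'devices'` are illustrative, not taken from the notebook.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
# One row of data per device (illustrative shape).
x = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)

def normalize(row):
    # psum sums across every device participating in the named axis,
    # so each device ends up holding row / (sum of rows over devices).
    return row / jax.lax.psum(row, axis_name='devices')

out = jax.pmap(normalize, axis_name='devices')(x)
print(out.sum(axis=0))  # each position sums to 1 across devices
```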

docs/jax-101/06-parallelism.md

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ Keep in mind that when calling the transformed function, the size of the specifi
 ## Communication between devices
 
 The above is enough to perform simple parallel operations, e.g. batching a simple MLP forward pass across several devices. However, sometimes we need to pass information between the devices. For example, perhaps we are interested in normalizing the output of each device so they sum to 1.
-For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through `axis_name` argument, and then refer to it when calling the op. Here's how to do that:
+For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through the `axis_name` argument, and then refer to it when calling the op. Here's how to do that:
 
 ```{code-cell} ipython3
 :id: 0nCxGwqmtd3w

docs/jep/18137-numpy-scipy-scope.md

Lines changed: 2 additions & 2 deletions
@@ -140,7 +140,7 @@ incompatible with JAX’s computation model. We instead focus on {mod}`jax.rando
 which offers similar functionality using a counter-based PRNG.
 
 #### `numpy.ma` & `numpy.polynomial`
-The {mod}`numpy.ma` andd {mod}`numpy.polynomial` submodules are mostly concerned with
+The {mod}`numpy.ma` and {mod}`numpy.polynomial` submodules are mostly concerned with
 providing object-oriented interfaces to computations that can be expressed via other
 functional means (Axis 5); for this reason, we deem them out-of-scope for JAX.
 
@@ -187,7 +187,7 @@ evaluations. {func}`jax.experimental.ode.odeint` is related, but rather limited
 under any active development.
 
 JAX does currently include {func}`jax.scipy.integrate.trapezoid`, but this is only because
-{func}`numpy.trapz` was recently deprecated in favor of this. For any particular inputs,
+{func}`numpy.trapz` was recently deprecated in favor of this. For any particular input,
 its implementation could be replaced with one line of {mod}`jax.numpy` expressions, so
 it’s not a particularly useful API to provide.
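
As a concrete illustration of the "one line of {mod}`jax.numpy` expressions" mentioned in the second hunk, a trapezoidal-rule sketch with an illustrative integrand:

```python
import jax.numpy as jnp

x = jnp.linspace(0.0, jnp.pi, 101)
y = jnp.sin(x)

# Trapezoidal rule written directly with jax.numpy operations.
integral = jnp.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1]))
print(integral)  # ~2.0, the same value jax.scipy.integrate.trapezoid(y, x) returns
```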

docs/pallas/quickstart.ipynb

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@
 "metadata": {},
 "source": [
 "With `grid` and `program_id` in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels.\n",
-"To build intution, let's try to implement a matrix multiplication.\n",
+"To build intuition, let's try to implement a matrix multiplication.\n",
 "\n",
 "A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.\n",
 "\n",

docs/pallas/quickstart.md

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ On TPUs, programs are executed in a combination of parallel and sequential (depe
 +++
 
 With `grid` and `program_id` in mind, Pallas provides an abstraction that takes care of some common indexing patterns seen in a lot of kernels.
-To build intution, let's try to implement a matrix multiplication.
+To build intuition, let's try to implement a matrix multiplication.
 
 A simple strategy for implementing a matrix multiplication in Pallas is to implement it recursively. We know our underlying hardware has support for small matrix multiplications (using GPU and TPU tensorcores), so we just express a big matrix multiplication in terms of smaller ones.
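
To make the blocked-matmul strategy described above concrete, here is a rough sketch against the quickstart-era Pallas API; the `pl.BlockSpec(index_map, block_shape)` argument order, the 2x2 grid, and the shapes are assumptions for illustration, not part of this commit.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, o_ref):
    # Each grid step multiplies one block-row of x by one block-column of y.
    o_ref[...] = x_ref[...] @ y_ref[...]

def matmul(x: jax.Array, y: jax.Array) -> jax.Array:
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        grid=(2, 2),  # a 2x2 grid of output blocks
        in_specs=[
            pl.BlockSpec(lambda i, j: (i, 0), (m // 2, k)),
            pl.BlockSpec(lambda i, j: (0, j), (k, n // 2)),
        ],
        out_specs=pl.BlockSpec(lambda i, j: (i, j), (m // 2, n // 2)),
    )(x, y)

x = jnp.ones((1024, 1024), jnp.float32)
y = jnp.ones((1024, 1024), jnp.float32)
print(jnp.allclose(matmul(x, y), x @ y))  # True, up to float tolerance
```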

docs/pallas/tpu/details.rst

Lines changed: 2 additions & 2 deletions
@@ -266,7 +266,7 @@ Elementwise operations
 ^^^^^^^^^^^^^^^^^^^^^^
 
 Many elementwise operations are supported. It is worth noting that the hardware
-generally only supports elementwise compute using 32-bit types. When loading
+generally only supports elementwise computation using 32-bit types. When loading
 operands that use lower-precision types, they should generally be upcast to a
 32-bit type before applying elementwise ops.
 
@@ -344,5 +344,5 @@ However, loop primitives get fully unrolled during the compilation at the
 moment, so try to keep the loop trip count reasonably small.
 
 Overusing control flow can lead to significant regressions in low-level code
-generation, and it is recommended to try to squeeze as many computationaly
+generation, and it is recommended to try to squeeze as many computationally
 expensive operations into a single basic block as possible.
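
To illustrate the upcasting guidance in the first hunk, a hypothetical Pallas TPU kernel sketch; the bfloat16 input, the function names, and the whole-array (no grid) invocation are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scaled_exp_kernel(x_ref, o_ref):
    # Upcast the lower-precision operand to 32 bits before elementwise math.
    x = x_ref[...].astype(jnp.float32)
    y = jnp.exp(x) * 2.0
    # Downcast only when storing the result back.
    o_ref[...] = y.astype(jnp.bfloat16)

def scaled_exp(x: jax.Array) -> jax.Array:
    return pl.pallas_call(
        scaled_exp_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, jnp.bfloat16),
    )(x)

x = jnp.ones((8, 128), dtype=jnp.bfloat16)
print(scaled_exp(x).dtype)  # bfloat16
```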

docs/pallas/tpu/pipelining.ipynb

Lines changed: 3 additions & 3 deletions
@@ -128,7 +128,7 @@
 "source": [
 "We've written two functions: `add_matrices_kernel` and `add_matrices`.\n",
 "\n",
-"`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on then to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.\n",
+"`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.\n",
 "\n",
 "The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`."
 ]
@@ -497,7 +497,7 @@
 "id": "Kv9qJYJY4jbK"
 },
 "source": [
-"Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n",
+"Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like to squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.\n",
 "\n",
 "`out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can it? Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value!\n",
 "\n",
@@ -582,7 +582,7 @@
 "\n",
 "Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?\n",
 "\n",
-"The basic idea is that if we have embarassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`."
+"The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`."
 ]
 },
 {
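
As a companion to the "`o_ref` is initially garbage" discussion in the second hunk, a rough sketch of the zero-initialize-then-accumulate pattern; the shapes, names, and the quickstart-era `pl.BlockSpec(index_map, block_shape)` argument order are assumptions (on CPU this would need Pallas interpret mode).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def sum_kernel(x_ref, o_ref):
    # o_ref starts out uninitialized, so zero it on the first grid step
    # before accumulating into it on every step.
    @pl.when(pl.program_id(axis=0) == 0)
    def _():
        o_ref[...] = jnp.zeros_like(o_ref[...])

    o_ref[...] += x_ref[...]

def blockwise_sum(x: jax.Array) -> jax.Array:
    # Sum over the leading axis, one (512, 512) block per grid step,
    # reusing the same output block every step (index_map -> (0, 0)).
    return pl.pallas_call(
        sum_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
        grid=(x.shape[0],),
        in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (None, *x.shape[1:]))],
        out_specs=pl.BlockSpec(lambda i: (0, 0), x.shape[1:]),
    )(x)

x = jnp.ones((8, 512, 512), jnp.float32)
print(blockwise_sum(x)[0, 0])  # 8.0
```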

docs/pallas/tpu/pipelining.md

Lines changed: 3 additions & 3 deletions
@@ -88,7 +88,7 @@ add_matrices(x, y)
 
 We've written two functions: `add_matrices_kernel` and `add_matrices`.
 
-`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on then to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.
+`add_matrices_kernel` operates using `Ref`s that live in VMEM. Loading from a VMEM `Ref` produces a value that lives in VREGs. Values in VREGs behave like `jax.Array`s in that we can use `jnp` and `jax.lax` operations on them to produce new values that live in VREGs. When we produce the values we'd like to return, we store them in the output VMEM `Ref`.
 
 The `add_matrices` function acts on `jax.Array`s and returns a `jax.Array`. Inside it, we pass `x` and `y` into `pallas_call`. `pallas_call` is responsible for copying `x` and `y` into VMEM and for allocating the VMEM buffers that the kernel operates on (including allocating `z_vmem_ref`, the output VMEM buffer). After the kernel function is finished running, `pallas_call` will also copy the value in `z_vmem_ref` to HBM, resulting in an output `jax.Array`.
 
@@ -311,7 +311,7 @@ naive_sum(x)
 
 +++ {"id": "Kv9qJYJY4jbK"}
 
-Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.
+Notice how we've set up the `BlockSpec`s: we're loading the entirety of the `(512, 512)` dimension into VMEM (no pipelining there) but selecting the `i`-th dimension of `x` each iteration in the `index_map`. We are using a `None` for that dimension in the block shape, which indicates that we are selecting a singleton dimension from `x` that we would like to squeeze away in the kernel. Therefore, `x_ref` is `(512, 512)`-shaped in VMEM as well.
 
 `out_spec` uses `lambda i: (0, 0)` as its `index_map`, indicating that `o_ref` is unchanged over the course of the pipeline. This means that we can update its value each iteration by reading from and writing to it. Or can it? Actually there is one catch: *`o_ref` is initially garbage*, meaning we'll be accumulating into garbage. This will result in the overall function outputting the incorrect value!
 
@@ -359,7 +359,7 @@ Some TPU chips have two TensorCores but appear as one device to JAX users. This
 
 Conceptually, TPUs in Megacore behave like very simple GPUs, i.e. they have only two threads. How do we modify our kernels to utilize both TensorCores simultaneously?
 
-The basic idea is that if we have embarassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`.
+The basic idea is that if we have embarrassingly parallel dimensions in our computation, we can split up those dimensions across the TensorCores. We can indicate which dimensions are parallelizable by providing an annotation to `pallas_call` called `dimension_semantics`.
 
 ```{code-cell}
 :id: nQNa8RaQ-TR1

docs/tutorials/advanced-autodiff.md

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ JAX's autodiff makes it easy to compute higher-order derivatives, because the fu
 
 The single-variable case was covered in the {ref}`automatic-differentiation` tutorial, where the example showed how to use {func}`jax.grad` to compute the the derivative of $f(x) = x^3 + 2x^2 - 3x + 1$.
 
-In the multi-variable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to:
+In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to:
 
 $$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$
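
For readers of this hunk, the usual JAX way to build that Hessian is to compose forward- and reverse-mode Jacobians; a minimal sketch with an illustrative function:

```python
import jax
import jax.numpy as jnp

def hessian(f):
    # Forward-over-reverse: jacfwd of grad is typically the efficient composition.
    return jax.jacfwd(jax.grad(f))

f = lambda x: jnp.dot(x, x) ** 2  # illustrative scalar-valued function
x = jnp.array([1.0, 2.0, 3.0])
print(hessian(f)(x))  # a symmetric (3, 3) matrix
```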

docs/tutorials/jax-primitives.md

Lines changed: 1 addition & 1 deletion
@@ -406,7 +406,7 @@ If you attempt now to use reverse differentiation, you'll notice that JAX starts
 When computing the reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code `multiply_add_value_and_jvp` to obtain a trace of primitives that compute the output tangent.
 
 - Observe that JAX performs this abstract evaluation with concrete values for the differentiation point, and abstract values for the tangents.
-- Notice that JAX uses the special abstract tangent value `Zero` for the tangent corresponding to the 3rd argument of `ma`. This reflects the fact that you do not differentiate w.r.t. the secibd argument to `square_add_prim`, which flows to the third argument to `multiply_add_prim`.
+- Notice that JAX uses the special abstract tangent value `Zero` for the tangent corresponding to the third argument of `ma`. This reflects the fact that you do not differentiate w.r.t. the second argument to `square_add_prim`, which flows to the third argument to `multiply_add_prim`.
 - Notice also that during the abstract evaluation of the tangent you pass the value `0.0` as the tangent for the third argument. This is because of the use of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
 
 ```{code-cell}
