
Commit 2a65cd3

Merge pull request #20806 from jakevdp:tweak-docs

Committed by: jax authors
PiperOrigin-RevId: 625869820
Parents: 1a8aae0, 48e8457

10 files changed: 101 additions (+), 84 deletions (-)

**docs/tutorials/advanced-autodiff.md** (1 addition, 1 deletion)

````diff
@@ -13,7 +13,7 @@ kernelspec:
 ---
 
 (advanced-autodiff)=
-# Advanced automatic differentiation 201
+# Advanced automatic differentiation
 
 In this tutorial, you will learn about complex applications of automatic differentiation (autodiff) in JAX and gain a better understanding of how taking derivatives in JAX can be both easy and powerful.
````

**docs/tutorials/automatic-differentiation.md** (6 additions, 6 deletions)

````diff
@@ -13,9 +13,9 @@ kernelspec:
 ---
 
 (automatic-differentiation)=
-# Automatic differentiation 101
+# Automatic differentiation
 
-In this tutorial, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
+In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
 
 - {ref}`automatic-differentiation-taking-gradients`
 - {ref}`automatic-differentiation-linear logistic regression`
@@ -28,9 +28,9 @@ Make sure to also check out the {ref}`advanced-autodiff` tutorial for more advan
 While understanding how automatic differentiation works "under the hood" isn't crucial for using JAX in most contexts, you are encouraged to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.
 
 (automatic-differentiation-taking-gradients)=
-## 1.Taking gradients with `jax.grad`
+## 1. Taking gradients with `jax.grad`
 
-In JAX, you can differentiate a function with the {func}`jax.grad` transformation:
+In JAX, you can differentiate a scalar-valued function with the {func}`jax.grad` transformation:
 
 ```{code-cell}
 import jax
@@ -162,7 +162,7 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for
 (automatic-differentiation-nested-lists-tuples-and-dicts)=
 ## 3. Differentiating with respect to nested lists, tuples, and dicts
 
-Due to JAX's PyTree abstraction (see {ref}`pytrees-what-is-a-pytree`), differentiating with
+Due to JAX's PyTree abstraction (see {ref}`working-with-pytrees`), differentiating with
 respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
 
 Continuing the previous example:
@@ -176,7 +176,7 @@ def loss2(params_dict):
 print(grad(loss2)({'W': W, 'b': b}))
 ```
 
-You can {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on).
+You can create {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad` but other JAX transformations ({func}`jax.jit`, {func}`jax.vmap`, and so on).
 
 
 (automatic-differentiation-evaluating-using-jax-value_and_grad)=
````
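Aside (not part of the commit): the hunks above concern {func}`jax.grad` on scalar-valued functions and on pytree (dict) inputs. A minimal illustrative sketch, with invented function and parameter names:

```python
import jax
import jax.numpy as jnp

# jax.grad differentiates a scalar-valued function.
def f(x):
    return x ** 3 + 2.0 * x

df = jax.grad(f)
print(df(2.0))  # d/dx (x^3 + 2x) at x=2 is 3*4 + 2 = 14.0

# Pytree inputs (dicts, tuples, lists) work directly: the gradient
# comes back with the same container structure as the input.
def loss(params):
    return jnp.sum((params['w'] * 3.0 + params['b']) ** 2)

grads = jax.grad(loss)({'w': jnp.float32(1.0), 'b': jnp.float32(0.5)})
print(grads['w'], grads['b'])  # 21.0  7.0
```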

**docs/tutorials/automatic-vectorization.md** (6 additions, 5 deletions)

````diff
@@ -13,11 +13,12 @@ kernelspec:
 ---
 
 (automatic-vectorization)=
-# Automatic Vectorization in JAX
+# Automatic vectorization
 
-In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`.
+In the previous section we discussed JIT compilation via the {func}`jax.jit` function.
+This notebook discusses another of JAX's transforms: vectorization via {func}`jax.vmap`.
 
-## Manual Vectorization
+## Manual vectorization
 
 Consider the following simple code that computes the convolution of two one-dimensional vectors:
 
@@ -72,9 +73,9 @@ def manually_vectorized_convolve(xs, ws):
 manually_vectorized_convolve(xs, ws)
 ```
 
-Such re-implementation is messy and error-prone; fortunately JAX provides another way.
+Such re-implementation can be messy and error-prone as the complexity of a function increases; fortunately JAX provides another way.
 
-## Automatic Vectorization
+## Automatic vectorization
 
 In JAX, the {func}`jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:
````
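Aside (not part of the commit): the tutorial edited above builds on a 1D `convolve` example and {func}`jax.vmap`. A self-contained sketch in that spirit, with invented batch values:

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # 1D windowed convolution of one vector x with a length-3 kernel w.
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i-1:i+2], w))
    return jnp.array(output)

x = jnp.arange(5.0)
w = jnp.array([2.0, 3.0, 4.0])

xs = jnp.stack([x, x])  # a batch of inputs
ws = jnp.stack([w, w])  # a batch of kernels

# jax.vmap maps convolve over the leading batch axis of both arguments,
# generating the vectorized implementation automatically.
batched = jax.vmap(convolve)(xs, ws)
print(batched.shape)  # (2, 3)
```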

**docs/tutorials/debugging.md** (6 additions, 6 deletions)

````diff
@@ -13,22 +13,22 @@ kernelspec:
 ---
 
 (debugging)=
-# Debugging 101
+# Introduction to debugging
 
-This tutorial introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.
+This section introduces you to a set of built-in JAX debugging methods — {func}`jax.debug.print`, {func}`jax.debug.breakpoint`, and {func}`jax.debug.callback` — that you can use with various JAX transformations.
 
 Let's begin with {func}`jax.debug.print`.
 
-## JAX `debug.print` for high-level debugging
+## JAX `debug.print` for high-level
 
 **TL;DR** Here is a rule of thumb:
 
 - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
-- Use Python `print` for static values, such as dtypes and array shapes.
+- Use Python {func}`print` for static values, such as dtypes and array shapes.
 
 Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
 the Python code is executed with abstract tracers in place of your arrays. Because of this,
-the Python `print` statement will only print this tracer value:
+the Python {func}`print` function will only print this tracer value:
 
 ```{code-cell}
 import jax
@@ -82,7 +82,7 @@ result = jax.lax.map(f, xs)
 
 Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
 
-Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's `print`, but it's consistent if you apply {func}`jax.jit` during the call.
+Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only prints the forward pass. In this case, the behavior is similar to Python's {func}`print`, but it's consistent if you apply {func}`jax.jit` during the call.
 
 ```{code-cell}
 def f(x):
````
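Aside (not part of the commit): the rule of thumb in the diff above — `jax.debug.print` for traced values, Python `print` for static ones — can be seen in a minimal sketch (function and values invented):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Python print runs at trace time and sees only the abstract tracer...
    print("python print:", x)
    # ...while jax.debug.print reports the runtime value, even under jit.
    jax.debug.print("jax.debug.print: {x}", x=x)
    return x * 2

result = f(jnp.float32(2.0))
```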

**docs/tutorials/jax-primitives.md** (1 addition, 1 deletion)

````diff
@@ -13,7 +13,7 @@ kernelspec:
 ---
 
 (jax-internals-jax-primitives)=
-# JAX internals 301: JAX primitives
+# JAX Internals: primitives
 
 ## Introduction to JAX primitives
````

**docs/tutorials/jaxpr.md** (1 addition, 1 deletion)

````diff
@@ -13,7 +13,7 @@ kernelspec:
 ---
 
 (jax-internals-jaxpr)=
-# JAX internals 301: The jaxpr language
+# JAX internals: The jaxpr language
 
 Jaxprs are JAX’s internal intermediate representation (IR) of programs. They are explicitly typed, functional, first-order, and in algebraic normal form (ANF).
````
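Aside (not part of the commit): the quickest way to see this IR is {func}`jax.make_jaxpr`, which the edited docs lean on. A tiny sketch (function invented):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# make_jaxpr traces f on the given argument and returns its jaxpr:
# an explicitly typed, functional record of the primitive operations.
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)
```

The printed jaxpr lists one equation per primitive (here `sin` followed by a power), with the types of every intermediate value.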

**docs/tutorials/jit-compilation.md** (26 additions, 21 deletions)
````diff
@@ -29,10 +29,10 @@ compilation of a JAX Python function so it can be executed efficiently in XLA.
 ## How JAX transformations work
 
 In the previous section, we discussed that JAX allows us to transform Python functions.
-This is done by first converting the Python function into a simple intermediate language called jaxpr.
-The transformations then work on the jaxpr representation.
+JAX accomplishes this by reducing each function into a sequence of {term}`primitive` operations, each
+representing one fundamental unit of computation.
 
-We can show a representation of the jaxpr of a function by using {func}`jax.make_jaxpr`:
+One way to see the sequence of primitives behind a function is using {func}`jax.make_jaxpr`:
 
 ```{code-cell}
 import jax
````
````diff
@@ -51,9 +51,11 @@ print(jax.make_jaxpr(log2)(3.0))
 
 The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output.
 
-Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
+Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
+This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
+If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
 
-Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.
+Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour under transformations. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JIT-compiled function to run once (during the first call), and never again, due to JAX's traced execution model.
 
 When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.
````
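Aside (not part of the commit): the "side-effects run once, at trace time" rule of thumb described in this hunk can be sketched as follows (function and log invented; as the docs themselves warn, don't rely on this behaviour):

```python
import jax

trace_log = []

@jax.jit
def f(x):
    # This side effect happens while JAX traces the function; it is not
    # recorded in the jaxpr, so it does not rerun on cached calls.
    trace_log.append("traced!")
    return x + 1

f(1.0)  # first call with this shape/dtype: traces, side effect runs
f(2.0)  # same signature: reuses the compiled version, no side effect
print(len(trace_log))
```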

````diff
@@ -73,7 +75,8 @@ See how the printed `x` is a `Traced` object? That's the JAX internals at work.
 
 The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation.
 
-A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:
+A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it.
+For example, if we have a Python conditional, the jaxpr will only know about the branch we take:
 
 ```{code-cell}
 def log2_if_rank_2(x):
````
````diff
@@ -143,8 +146,7 @@ def f(x):
   else:
     return 2 * x
 
-f_jit = jax.jit(f)
-f_jit(10) # Should raise an error.
+jax.jit(f)(10) # Raises an error
 ```
 
 ```{code-cell}
````
````diff
@@ -158,19 +160,17 @@ def g(x, n):
     i += 1
   return x + i
 
-g_jit = jax.jit(g)
-g_jit(10, 20) # Should raise an error.
+jax.jit(g)(10, 20) # Raises an error
 ```
 
-The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it.
+The problem in both cases is that we tried to condition the trace-time flow of the program using runtime values.
+Traced values within JIT, like `x` and `n` here, can only affect control flow via their static attributes: such as
+`shape` or `dtype`, and not via their values.
+For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
 
-The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes.
-
-For {func}`jax.jit`, the default level is {class}`~jax.core.ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.
-
-In {func}`jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
-
-One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
+One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special {ref}`lax-control-flow` like {func}`jax.lax.cond`. However, sometimes that is not possible or practical.
+In that case, you can consider JIT-compiling only part of the function.
+For example, if the most computationally expensive part of the function is inside the loop, we can JIT-compile just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
 
 ```{code-cell}
 # While loop conditioned on x and n with a jitted body.
````
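Aside (not part of the commit): the hunk above points at {func}`jax.lax.cond` as the structured alternative to a Python `if` on a traced value. A minimal sketch (function and values invented):

```python
import jax
import jax.numpy as jnp

# A Python `if x > 0:` on a traced value fails under jit; jax.lax.cond
# instead stages both branches into the compiled program and selects
# between them at runtime.
@jax.jit
def f(x):
    return jax.lax.cond(x > 0, lambda x: 2 * x, lambda x: -x, x)

print(f(jnp.float32(3.0)))   # 6.0
print(f(jnp.float32(-3.0)))  # 3.0
```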
````diff
@@ -188,7 +188,11 @@ def g_inner_jitted(x, n):
 g_inner_jitted(10, 20)
 ```
 
-If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values.
+## Marking arguments as static
+
+If we really need to JIT-compile a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`.
+The cost of this is that the resulting jaxpr and compiled artifact depends on the particular value passed, and so JAX will have to re-compile the function for every new value of the specified static input.
+It is only a good strategy if the function is guaranteed to see a limited set of static values.
 
 ```{code-cell}
 f_jit_correct = jax.jit(f, static_argnums=0)
````
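Aside (not part of the commit): the static-arguments mechanism added above can be sketched with `static_argnames` (function and values invented; each new static value triggers a recompile):

```python
import jax

def f(x, neg):
    # `neg` drives Python-level control flow, so it must be static.
    return -x if neg else x

f_jit = jax.jit(f, static_argnames=['neg'])
print(f_jit(1.0, neg=True))   # traced and compiled for neg=True
print(f_jit(1.0, neg=False))  # recompiled for the new static value
```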
````diff
@@ -227,9 +231,10 @@ print("g:")
 %timeit g(10, 20)
 ```
 
-This is because {func}`jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
+This is because {func}`jax.jit` introduces some overhead itself, and so it usually only saves time if the compiled function is nontrivial, or if you will run it numerous times.
+Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
 
-Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
+Generally, you want to JIT-compile the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
 
 ## JIT and caching
````

**docs/tutorials/key-concepts.md** (3 additions, 4 deletions)

````diff
@@ -20,10 +20,9 @@ This section briefly introduces some key concepts of the JAX package.
 (key-concepts-jax-arrays)=
 ## JAX arrays ({class}`jax.Array`)
 
-- `jax.Array` is the default array implementation in JAX.
-- `jax.Array` objects are never created directly, but rather using familiar
-  array creation APIs.
-- JAX arrays may be stored on a single device, or sharded across many devices.
+The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to
+the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it
+has some important differences.
 
 ### Array creation
````
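Aside (not part of the commit): the rewritten prose above says `jax.Array` objects come from NumPy-style APIs and behave much like `numpy.ndarray`. A small sketch of both the similarity and one key difference (immutability):

```python
import jax
import jax.numpy as jnp

# Arrays are created through familiar NumPy-style APIs rather than
# by constructing jax.Array directly.
x = jnp.arange(5)
print(type(x), x.dtype, x.shape)

# Like numpy.ndarray, it supports vectorized arithmetic...
print((x * 2.0).sum())  # 20.0

# ...but unlike NumPy, jax.Array is immutable: updates return new arrays.
y = x.at[0].set(10)
print(x[0], y[0])  # x is unchanged
```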
